diff --git a/spring-core/src/main/java/org/springframework/util/DigestUtils.java b/spring-core/src/main/java/org/springframework/util/DigestUtils.java index a4ca743bf9..29dcb54105 100644 --- a/spring-core/src/main/java/org/springframework/util/DigestUtils.java +++ b/spring-core/src/main/java/org/springframework/util/DigestUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.util; +import java.io.IOException; +import java.io.InputStream; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -45,6 +47,16 @@ public abstract class DigestUtils { return digest(MD5_ALGORITHM_NAME, bytes); } + /** + * Calculate the MD5 digest of the given InputStream. + * @param inputStream the inputStream to calculate the digest over + * @return the digest + * @since 4.2 + */ + public static byte[] md5Digest(InputStream inputStream) throws IOException{ + return digest(MD5_ALGORITHM_NAME, inputStream); + } + /** * Return a hexadecimal string representation of the MD5 digest of the given * bytes. @@ -55,6 +67,17 @@ public abstract class DigestUtils { return digestAsHexString(MD5_ALGORITHM_NAME, bytes); } + /** + * Return a hexadecimal string representation of the MD5 digest of the given + * inputStream. + * @param inputStream the inputStream to calculate the digest over + * @return a hexadecimal digest string + * @since 4.2 + */ + public static String md5DigestAsHex(InputStream inputStream) throws IOException{ + return digestAsHexString(MD5_ALGORITHM_NAME, inputStream); + } + /** * Append a hexadecimal string representation of the MD5 digest of the given * bytes to the given {@link StringBuilder}. @@ -66,6 +89,18 @@ public abstract class DigestUtils { return appendDigestAsHex(MD5_ALGORITHM_NAME, bytes, builder); } + /** + * Append a hexadecimal string representation of the MD5 digest of the given + * inputStream to the given {@link StringBuilder}. + * @param inputStream the inputStream to calculate the digest over + * @param builder the string builder to append the digest to + * @return the given string builder + * @since 4.2 + */ + public static StringBuilder appendMd5DigestAsHex(InputStream inputStream, StringBuilder builder) throws IOException{ + return appendDigestAsHex(MD5_ALGORITHM_NAME, inputStream, builder); + } + /** * Creates a new {@link MessageDigest} with the given algorithm. Necessary * because {@code MessageDigest} is not thread-safe. @@ -83,21 +118,46 @@ public abstract class DigestUtils { return getDigest(algorithm).digest(bytes); } + private static byte[] digest(String algorithm, InputStream inputStream) throws IOException{ + MessageDigest messageDigest = getDigest(algorithm); + if(inputStream instanceof UpdateMessageDigestInputStream){ + ((UpdateMessageDigestInputStream) inputStream).updateMessageDigest(messageDigest); + return messageDigest.digest(); + }else{ + return messageDigest.digest(StreamUtils.copyToByteArray(inputStream)); + } + } + private static String digestAsHexString(String algorithm, byte[] bytes) { char[] hexDigest = digestAsHexChars(algorithm, bytes); return new String(hexDigest); } + private static String digestAsHexString(String algorithm, InputStream inputStream) throws IOException{ + char[] hexDigest = digestAsHexChars(algorithm, inputStream); + return new String(hexDigest); + } + private static StringBuilder appendDigestAsHex(String algorithm, byte[] bytes, StringBuilder builder) { char[] hexDigest = digestAsHexChars(algorithm, bytes); return builder.append(hexDigest); } + private static StringBuilder appendDigestAsHex(String algorithm, InputStream inputStream, StringBuilder builder) throws IOException{ + char[] hexDigest = digestAsHexChars(algorithm, inputStream); + return builder.append(hexDigest); + } + private static char[] digestAsHexChars(String algorithm, byte[] bytes) { byte[] digest = digest(algorithm, bytes); return encodeHex(digest); } + private static char[] digestAsHexChars(String algorithm, InputStream inputStream) throws IOException{ + byte[] digest = digest(algorithm, inputStream); + return encodeHex(digest); + } + private static char[] encodeHex(byte[] bytes) { char chars[] = new char[32]; for (int i = 0; i < chars.length; i = i + 2) { diff --git a/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java b/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java new file mode 100644 index 0000000000..6c6eb46e3a --- /dev/null +++ b/spring-core/src/main/java/org/springframework/util/FastByteArrayOutputStream.java @@ -0,0 +1,544 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.MessageDigest; +import java.util.Iterator; +import java.util.LinkedList; + +/** + * A speedy alternative to {@link java.io.ByteArrayOutputStream}. + *
Unlike {@link java.io.ByteArrayOutputStream}, this implementation is backed by a + * {@link java.util.LinkedList} of byte[] instead of 1 constantly resizing byte[]. + * It does not copy buffers when it's expanded.
+ * + *The initial buffer is only created when the stream is first written. + * There's also no copying of the internal buffer if its contents is extracted with the + * {@link #writeTo(OutputStream)} method. + * Instances of this class are NOT THREAD SAFE.
+ * + * @author Craig Andrews + * @since 4.2 + */ +public final class FastByteArrayOutputStream extends OutputStream { + + private static final int DEFAULT_BLOCK_SIZE = 256; + + // the buffers used to store the content bytes + private final LinkedListFastByteArrayOutputStream
+ * with the default initial capacity of {@value #DEFAULT_BLOCK_SIZE} bytes.
+ */
+ public FastByteArrayOutputStream() {
+ this(DEFAULT_BLOCK_SIZE);
+ }
+
+ /**
+ * Create a new FastByteArrayOutputStream
+ * with the specified initial capacity.
+ * @param initialBlockSize the initial buffer size in bytes
+ */
+ public FastByteArrayOutputStream(int initialBlockSize) {
+ Assert.isTrue(initialBlockSize > 0, "Initial block size must be greater than 0");
+ this.initialBlockSize = initialBlockSize;
+ nextBlockSize = initialBlockSize;
+ }
+
+ @Override
+ public void write(int datum) throws IOException {
+ if (closed) {
+ throw new IOException("Stream closed");
+ }
+ else {
+ if (buffers.peekLast() == null || buffers.getLast().length == index) {
+ addBuffer(1);
+ }
+ // store the byte
+ buffers.getLast()[index++] = (byte) datum;
+ }
+ }
+
+ @Override
+ public void write(byte[] data, int offset, int length) throws IOException {
+ if (data == null) {
+ throw new NullPointerException();
+ }
+ else if ((offset < 0) || ((offset + length) > data.length) || (length < 0)) {
+ throw new IndexOutOfBoundsException();
+ }
+ else if (closed) {
+ throw new IOException("Stream closed");
+ }
+ else {
+ if (buffers.peekLast() == null || buffers.getLast().length == index) {
+ addBuffer(length);
+ }
+ if ((index + length) > buffers.getLast().length) {
+ do {
+ if (index == buffers.getLast().length) {
+ addBuffer(length);
+ }
+ int copyLength = buffers.getLast().length - index;
+ if (length < copyLength) {
+ copyLength = length;
+ }
+ System.arraycopy(data, offset, buffers.getLast(), index, copyLength);
+ offset += copyLength;
+ index += copyLength;
+ length -= copyLength;
+ } while (length > 0);
+ }
+ else {
+ // Copy in the subarray
+ System.arraycopy(data, offset, buffers.getLast(), index, length);
+ index += length;
+ }
+ }
+ }
+
+ @Override
+ public void close() {
+ closed = true;
+ }
+
+ /**
+ * Returns the number of bytes stored in this FastByteArrayOutputStream
+ */
+ public int size() {
+ return alreadyBufferedSize + index;
+ }
+
+ /**
+ * Convert the stream's data to a byte array and return the byte array.
+ *
+ * Also replaces the internal structures with the byte array to conserve memory: + * if the byte array is being made anyways, mind as well as use it. + * This approach also means that if this method is called twice without any writes in between, + * the second call is a no-op. + * This method is "unsafe" as it returns the internal buffer - callers should not modify the returned buffer.
+ * + * @return the current contents of this output stream, as a byte array. + * @see #size() + * @see #toByteArray() + */ + public byte[] toByteArrayUnsafe() { + int totalSize = size(); + if (totalSize == 0) { + return new byte[0]; + } + resize(totalSize); + return buffers.getFirst(); + } + + /** + * Creates a newly allocated byte array. + * + *Its size is the current + * size of this output stream and the valid contents of the buffer + * have been copied into it.
+ * + * @return the current contents of this output stream, as a byte array. + * @see #size() + * @see #toByteArrayUnsafe() + */ + public byte[] toByteArray() { + byte[] bytesUnsafe = toByteArrayUnsafe(); + byte[] ret = new byte[bytesUnsafe.length]; + System.arraycopy(bytesUnsafe, 0, ret, 0, bytesUnsafe.length); + return ret; + } + + /** + * Resets the contents of thisFastByteArrayOutputStreamInputStream
+ * All currently accumulated output in the output stream is discarded. + * The output stream can be used again.
+ */ + public void reset() { + buffers.clear(); + nextBlockSize = initialBlockSize; + closed = false; + index = 0; + alreadyBufferedSize = 0; + } + + /** + * Get an {@link java.io.InputStream} to retrieve the data in this OutputStream + * + *Note that if any methods are called on the OutputStream + * (including, but not limited to, any of the write methods, {@link #reset()}, + * {@link #toByteArray()}, and {@link #toByteArrayUnsafe()}) then the {@link java.io.InputStream}'s + * behavior is undefined.
+ * + * @return {@link java.io.InputStream} of the contents of thisFastByteArrayOutputStreamInputStream
+ */
+ public InputStream getInputStream() {
+ return new FastByteArrayOutputStreamInputStream(this);
+ }
+
+ /**
+ * Write the buffers content to the given OutputStream
+ *
+ * @param out the OutputStream to write to
+ */
+ public void writeTo(OutputStream out) throws IOException {
+ IteratorAdds a new buffer that can store at least {@code minCapacity} bytes
+ */ + private void addBuffer(int minCapacity) { + if (buffers.peekLast() != null) { + alreadyBufferedSize += index; + index = 0; + } + if (nextBlockSize < minCapacity) { + nextBlockSize = nextPowerOf2(minCapacity); + } + buffers.add(new byte[nextBlockSize]); + nextBlockSize *= 2; // block size doubles each time + } + + /** + * Get the next power of 2 of a number (ex, the next power of 2 of 119 is 128) + */ + private static final int nextPowerOf2(int val) { + val--; + val = (val >> 1) | val; + val = (val >> 2) | val; + val = (val >> 4) | val; + val = (val >> 8) | val; + val = (val >> 16) | val; + val++; + return val; + } + + /** + * Converts the buffer's contents into a string decoding bytes using the + * platform's default character set. The length of the new String + * is a function of the character set, and hence may not be equal to the + * size of the buffer. + * + *This method always replaces malformed-input and unmappable-character + * sequences with the default replacement string for the platform's + * default character set. The {@linkplain java.nio.charset.CharsetDecoder} + * class should be used when more control over the decoding process is + * required.
+ * + * @return String decoded from the buffer's contents. + */ + @Override + public String toString() { + return new String(toByteArrayUnsafe()); + } + + + /** + * An implementation of {@link java.io.InputStream} that reads fromFastByteArrayOutputStream
+ * Instances of this class are NOT THREAD SAFE.
+ */
+ private static final class FastByteArrayOutputStreamInputStream extends UpdateMessageDigestInputStream {
+ int totalBytesRead = 0;
+
+ int nextIndexInCurrentBuffer = 0;
+
+ final IteratorFastByteArrayOutputStreamInputStream backed
+ * by the given FastByteArrayOutputStream
+ */
+ public FastByteArrayOutputStreamInputStream(FastByteArrayOutputStream fastByteArrayOutputStream) {
+ this.fastByteArrayOutputStream = fastByteArrayOutputStream;
+ buffersIterator = fastByteArrayOutputStream.buffers.iterator();
+ if (buffersIterator.hasNext()) {
+ currentBuffer = buffersIterator.next();
+ if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) {
+ currentBufferLength = fastByteArrayOutputStream.index;
+ }
+ else {
+ currentBufferLength = currentBuffer.length;
+ }
+ }
+ else {
+ currentBuffer = null;
+ }
+ }
+
+ @Override
+ public int read() {
+ if (currentBuffer == null) {
+ // this stream doesn't have any data in it
+ return -1;
+ }
+ else {
+ if (nextIndexInCurrentBuffer < currentBufferLength) {
+ totalBytesRead++;
+ return currentBuffer[nextIndexInCurrentBuffer++];
+ }
+ else {
+ if (buffersIterator.hasNext()) {
+ currentBuffer = buffersIterator.next();
+ if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) {
+ currentBufferLength = fastByteArrayOutputStream.index;
+ }
+ else {
+ currentBufferLength = currentBuffer.length;
+ }
+ nextIndexInCurrentBuffer = 0;
+ }
+ else {
+ currentBuffer = null;
+ }
+ return read();
+ }
+ }
+ }
+
+ @Override
+ public int read(byte[] b) {
+ return read(b, 0, b.length);
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) {
+ if (b == null) {
+ throw new NullPointerException();
+ }
+ else if (off < 0 || len < 0 || len > b.length - off) {
+ throw new IndexOutOfBoundsException();
+ }
+ else if (len == 0) {
+ return 0;
+ }
+ else if (len < 0) {
+ throw new IllegalArgumentException("len must be 0 or greater: " + len);
+ }
+ else if (off < 0) {
+ throw new IllegalArgumentException("off must be 0 or greater: " + off);
+ }
+ else {
+ if (currentBuffer == null) {
+ // this stream doesn't have any data in it
+ return 0;
+ }
+ else {
+ if (nextIndexInCurrentBuffer < currentBufferLength) {
+ int bytesToCopy = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer);
+ System.arraycopy(currentBuffer, nextIndexInCurrentBuffer, b, off, bytesToCopy);
+ totalBytesRead += bytesToCopy;
+ nextIndexInCurrentBuffer += bytesToCopy;
+ return bytesToCopy + read(b, off + bytesToCopy, len - bytesToCopy);
+ }
+ else {
+ if (buffersIterator.hasNext()) {
+ currentBuffer = buffersIterator.next();
+ if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) {
+ currentBufferLength = fastByteArrayOutputStream.index;
+ }
+ else {
+ currentBufferLength = currentBuffer.length;
+ }
+ nextIndexInCurrentBuffer = 0;
+ }
+ else {
+ currentBuffer = null;
+ }
+ return read(b, off, len);
+ }
+ }
+ }
+ }
+
+ @Override
+ public long skip(long n) throws IOException {
+ if (n > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("n exceeds maximum (" +
+ Integer.MAX_VALUE + "): " + n);
+ }
+ else if (n == 0) {
+ return 0;
+ }
+ else if (n < 0) {
+ throw new IllegalArgumentException("n must be 0 or greater: " + n);
+ }
+ int len = (int) n;
+ if (currentBuffer == null) {
+ // this stream doesn't have any data in it
+ return 0;
+ }
+ else {
+ if (nextIndexInCurrentBuffer < currentBufferLength) {
+ int bytesToSkip = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer);
+ totalBytesRead += bytesToSkip;
+ nextIndexInCurrentBuffer += bytesToSkip;
+ return bytesToSkip + skip(len - bytesToSkip);
+ }
+ else {
+ if (buffersIterator.hasNext()) {
+ currentBuffer = buffersIterator.next();
+ if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) {
+ currentBufferLength = fastByteArrayOutputStream.index;
+ }
+ else {
+ currentBufferLength = currentBuffer.length;
+ }
+ nextIndexInCurrentBuffer = 0;
+ }
+ else {
+ currentBuffer = null;
+ }
+ return skip(len);
+ }
+ }
+ }
+
+ @Override
+ public int available() {
+ return fastByteArrayOutputStream.size() - totalBytesRead;
+ }
+
+ /**
+ * Update the message digest with the remaining bytes in this stream.
+ *
+ * @param messageDigest The message digest to update
+ */
+ public void updateMessageDigest(MessageDigest messageDigest) {
+ updateMessageDigest(messageDigest, available());
+ }
+
+ /**
+ * Update the message digest with the next len bytes in this stream.
+ * Avoids creating new byte arrays and use internal buffers for performance.
+ * @param messageDigest The message digest to update
+ * @param len how many bytes to read from this stream and use to update the message digest
+ */
+ public void updateMessageDigest(MessageDigest messageDigest, int len) {
+ if (currentBuffer == null) {
+ // this stream doesn't have any data in it
+ return;
+ }
+ else if (len == 0) {
+ return;
+ }
+ else if (len < 0) {
+ throw new IllegalArgumentException("len must be 0 or greater: " + len);
+ }
+ else {
+ if (nextIndexInCurrentBuffer < currentBufferLength) {
+ int bytesToCopy = Math.min(len, currentBufferLength - nextIndexInCurrentBuffer);
+ messageDigest.update(currentBuffer, nextIndexInCurrentBuffer, bytesToCopy);
+ nextIndexInCurrentBuffer += bytesToCopy;
+ updateMessageDigest(messageDigest, len - bytesToCopy);
+ }
+ else {
+ if (buffersIterator.hasNext()) {
+ currentBuffer = buffersIterator.next();
+ if (currentBuffer == fastByteArrayOutputStream.buffers.getLast()) {
+ currentBufferLength = fastByteArrayOutputStream.index;
+ }
+ else {
+ currentBufferLength = currentBuffer.length;
+ }
+ nextIndexInCurrentBuffer = 0;
+ }
+ else {
+ currentBuffer = null;
+ }
+ updateMessageDigest(messageDigest, len);
+ }
+ }
+ }
+ }
+}
diff --git a/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java b/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java
new file mode 100644
index 0000000000..8ae173627e
--- /dev/null
+++ b/spring-core/src/main/java/org/springframework/util/UpdateMessageDigestInputStream.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.util;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.security.MessageDigest;
+
+/**
+ * Extension of {@link java.io.InputStream} that allows for optimized
+ * implementations of message digesting.
+ *
+ * @author Craig Andrews
+ * @since 4.2
+ */
+public abstract class UpdateMessageDigestInputStream extends InputStream {
+
+ /**
+ * Update the message digest with the rest of the bytes in this stream
+ *
+ * Using this method is more optimized since it avoids creating new byte arrays for each call.
+ * + * @param messageDigest The message digest to update + * @throws IOException + */ + public void updateMessageDigest(MessageDigest messageDigest) throws IOException{ + int data; + while((data = read()) != -1){ + messageDigest.update((byte)data); + } + } + + /** + * Update the message digest with the next len bytes in this stream + * + *Using this method is more optimized since it avoids creating new byte arrays for each call.
+ * + * @param messageDigest The message digest to update + * @param len how many bytes to read from this stream and use to update the message digest + * @throws IOException + */ + public void updateMessageDigest(MessageDigest messageDigest, int len) throws IOException{ + int data; + int bytesRead = 0; + while(bytesRead < len && (data = read()) != -1){ + messageDigest.update((byte)data); + bytesRead++; + } + } +} diff --git a/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java b/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java new file mode 100644 index 0000000000..9eedeca8ee --- /dev/null +++ b/spring-core/src/test/java/org/springframework/util/FastByteArrayOutputStreamTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.util; + +import static org.junit.Assert.*; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.junit.Before; +import org.junit.Test; + +/** + * Test suite for {@link FastByteArrayOutputStream} + * @author Craig Andrews + */ +public class FastByteArrayOutputStreamTests { + + private static final int INITIAL_CAPACITY = 256; + + private FastByteArrayOutputStream os; + + private byte[] helloBytes; + + @Before + public void setUp() throws Exception { + this.os = new FastByteArrayOutputStream(INITIAL_CAPACITY); + this.helloBytes = "Hello World".getBytes("UTF-8"); + } + + @Test + public void size() throws Exception { + this.os.write(helloBytes); + assertEquals(this.os.size(), helloBytes.length); + } + + @Test + public void resize() throws Exception { + this.os.write(helloBytes); + int sizeBefore = this.os.size(); + this.os.resize(64); + assertByteArrayEqualsString(this.os); + assertEquals(sizeBefore, this.os.size()); + } + + @Test + public void autoGrow() throws IOException { + this.os.resize(1); + for (int i = 0; i < 10; i++) { + this.os.write(1); + } + assertEquals(10, this.os.size()); + assertArrayEquals(this.os.toByteArray(), new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + } + + @Test + public void write() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + } + + @Test + public void reset() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + this.os.reset(); + assertEquals(0, this.os.size()); + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + } + + @Test(expected = IOException.class) + public void close() throws Exception { + this.os.close(); + this.os.write(helloBytes); + } + + @Test + public void toByteArrayUnsafe() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + assertSame(this.os.toByteArrayUnsafe(), this.os.toByteArrayUnsafe()); + assertArrayEquals(this.os.toByteArray(), helloBytes); + } + + @Test + public void writeTo() throws Exception { + this.os.write(helloBytes); + assertByteArrayEqualsString(this.os); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + this.os.writeTo(baos); + assertArrayEquals(baos.toByteArray(), helloBytes); + } + + @Test(expected = IllegalArgumentException.class) + public void failResize() throws Exception { + this.os.write(helloBytes); + this.os.resize(5); + } + + @Test + public void getInputStream() throws Exception { + this.os.write(helloBytes); + assertNotNull(this.os.getInputStream()); + } + + @Test + public void getInputStreamAvailable() throws Exception { + this.os.write(helloBytes); + assertEquals(this.os.getInputStream().available(), helloBytes.length); + } + + @Test + public void getInputStreamRead() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.read(), helloBytes[0]); + assertEquals(inputStream.read(), helloBytes[1]); + assertEquals(inputStream.read(), helloBytes[2]); + assertEquals(inputStream.read(), helloBytes[3]); + } + + @Test + public void getInputStreamReadAll() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + byte[] actual = new byte[inputStream.available()]; + int bytesRead = inputStream.read(actual); + assertEquals(bytesRead, helloBytes.length); + assertArrayEquals(actual, helloBytes); + assertEquals(0, inputStream.available()); + } + + @Test + public void getInputStreamSkip() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.read(), helloBytes[0]); + assertEquals(inputStream.skip(1), 1); + assertEquals(inputStream.read(), helloBytes[2]); + assertEquals(helloBytes.length - 3, inputStream.available()); + } + + @Test + public void getInputStreamSkipAll() throws Exception { + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + assertEquals(inputStream.skip(1000), helloBytes.length); + assertEquals(0, inputStream.available()); + } + + @Test + public void updateMessageDigest() throws Exception { + StringBuilder builder = new StringBuilder("\"0"); + this.os.write(helloBytes); + InputStream inputStream = this.os.getInputStream(); + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + builder.append("\""); + String actual = builder.toString(); + assertEquals("\"0b10a8db164e0754105b7a99be72e3fe5\"", actual); + } + + @Test + public void updateMessageDigestManyBuffers() throws Exception { + StringBuilder builder = new StringBuilder("\"0"); + // filling at least one 256 buffer + for( int i=0; i < 30; i++) { + this.os.write(helloBytes); + } + InputStream inputStream = this.os.getInputStream(); + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + builder.append("\""); + String actual = builder.toString(); + assertEquals("\"06225ca1e4533354c516e74512065331d\"", actual); + } + + private void assertByteArrayEqualsString(FastByteArrayOutputStream actual) { + assertArrayEquals(helloBytes, actual.toByteArray()); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index 8e8fa9d028..31e66eddf9 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.filter; import java.io.IOException; +import java.io.InputStream; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; @@ -26,7 +27,6 @@ import org.springframework.http.HttpMethod; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.DigestUtils; -import org.springframework.util.StreamUtils; import org.springframework.web.util.ContentCachingResponseWrapper; import org.springframework.web.util.WebUtils; @@ -93,15 +93,12 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { HttpServletResponse rawResponse = (HttpServletResponse) responseWrapper.getResponse(); int statusCode = responseWrapper.getStatusCode(); - byte[] body = responseWrapper.getContentAsByteArray(); if (rawResponse.isCommitted()) { - if (body.length > 0) { - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } - else if (isEligibleForEtag(request, responseWrapper, statusCode, body)) { - String responseETag = generateETagHeaderValue(body); + else if (isEligibleForEtag(request, responseWrapper, statusCode, responseWrapper.getContentInputStream())) { + String responseETag = generateETagHeaderValue(responseWrapper.getContentInputStream()); rawResponse.setHeader(HEADER_ETAG, responseETag); String requestETag = request.getHeader(HEADER_IF_NONE_MATCH); if (responseETag.equals(requestETag)) { @@ -115,20 +112,14 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { logger.trace("ETag [" + responseETag + "] not equal to If-None-Match [" + requestETag + "], sending normal response"); } - if (body.length > 0) { - rawResponse.setContentLength(body.length); - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } } else { if (logger.isTraceEnabled()) { logger.trace("Response with status code [" + statusCode + "] not eligible for ETag"); } - if (body.length > 0) { - rawResponse.setContentLength(body.length); - StreamUtils.copy(body, rawResponse.getOutputStream()); - } + responseWrapper.copyBodyToResponse(); } } @@ -143,11 +134,11 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { * @param request the HTTP request * @param response the HTTP response * @param responseStatusCode the HTTP response status code - * @param responseBody the response body + * @param inputStream the response body * @return {@code true} if eligible for ETag generation; {@code false} otherwise */ protected boolean isEligibleForEtag(HttpServletRequest request, HttpServletResponse response, - int responseStatusCode, byte[] responseBody) { + int responseStatusCode, InputStream inputStream) { if (responseStatusCode >= 200 && responseStatusCode < 300 && HttpMethod.GET.name().equals(request.getMethod())) { @@ -162,13 +153,18 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { /** * Generate the ETag header value from the given response body byte array. *The default implementation generates an MD5 hash. - * @param bytes the response body as byte array + * @param inputStream the response body as an InputStream * @return the ETag header value * @see org.springframework.util.DigestUtils */ - protected String generateETagHeaderValue(byte[] bytes) { + protected String generateETagHeaderValue(InputStream inputStream) { StringBuilder builder = new StringBuilder("\"0"); - DigestUtils.appendMd5DigestAsHex(bytes, builder); + try { + DigestUtils.appendMd5DigestAsHex(inputStream, builder); + } + catch (IOException e) { + throw new RuntimeException(e); + } builder.append('"'); return builder.toString(); } diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 3eee67e72c..f28d5b2b52 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.web.util; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; @@ -24,8 +25,7 @@ import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; -import org.springframework.util.ResizableByteArrayOutputStream; -import org.springframework.util.StreamUtils; +import org.springframework.util.FastByteArrayOutputStream; /** * {@link javax.servlet.http-HttpServletResponse} wrapper that caches all content written to @@ -39,7 +39,7 @@ import org.springframework.util.StreamUtils; */ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { - private final ResizableByteArrayOutputStream content = new ResizableByteArrayOutputStream(1024); + private final FastByteArrayOutputStream content = new FastByteArrayOutputStream(1024); private final ServletOutputStream outputStream = new ResponseServletOutputStream(); @@ -107,9 +107,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Override public void setContentLength(int len) { - if (len > this.content.capacity()) { - this.content.resize(len); - } + this.content.resize(len); } // Overrides Servlet 3.1 setContentLengthLong(long) at runtime @@ -118,16 +116,12 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { throw new IllegalArgumentException("Content-Length exceeds ShallowEtagHeaderFilter's maximum (" + Integer.MAX_VALUE + "): " + len); } - if (len > this.content.capacity()) { - this.content.resize((int) len); - } + this.content.resize((int) len); } @Override public void setBufferSize(int size) { - if (size > this.content.capacity()) { - this.content.resize(size); - } + this.content.resize(size); } @Override @@ -142,7 +136,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { } /** - * Return the status code as specifed on the response. + * Return the status code as specified on the response. */ public int getStatusCode() { return this.statusCode; @@ -155,14 +149,24 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { return this.content.toByteArray(); } - private void copyBodyToResponse() throws IOException { + public void copyBodyToResponse() throws IOException { if (this.content.size() > 0) { - getResponse().setContentLength(this.content.size()); - StreamUtils.copy(this.content.toByteArray(), getResponse().getOutputStream()); + HttpServletResponse rawResponse = (HttpServletResponse) getResponse(); + if(! rawResponse.isCommitted()){ + rawResponse.setContentLength(this.content.size()); + } + this.content.writeTo(rawResponse.getOutputStream()); this.content.reset(); } } + public int getContentSize(){ + return this.content.size(); + } + + public InputStream getContentInputStream(){ + return this.content.getInputStream(); + } private class ResponseServletOutputStream extends ServletOutputStream { diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index 74d501147b..da796ac3f8 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.web.filter; +import java.io.ByteArrayInputStream; import java.io.IOException; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -51,15 +52,15 @@ public class ShallowEtagHeaderFilterTests { MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); - assertTrue(filter.isEligibleForEtag(request, response, 200, new byte[0])); - assertFalse(filter.isEligibleForEtag(request, response, 300, new byte[0])); + assertTrue(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); + assertFalse(filter.isEligibleForEtag(request, response, 300, new ByteArrayInputStream(new byte[0]))); request = new MockHttpServletRequest("POST", "/hotels"); - assertFalse(filter.isEligibleForEtag(request, response, 200, new byte[0])); + assertFalse(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); request = new MockHttpServletRequest("POST", "/hotels"); request.addHeader("Cache-Control","must-revalidate, no-store"); - assertFalse(filter.isEligibleForEtag(request, response, 200, new byte[0])); + assertFalse(filter.isEligibleForEtag(request, response, 200, new ByteArrayInputStream(new byte[0]))); } @Test