Parse parts in MockMultipartHttpServletRequestBuilder

Closes gh-26261
This commit is contained in:
Rossen Stoyanchev
2020-12-14 21:13:01 +00:00
parent 17e6cf1cc1
commit bcfbde9848
3 changed files with 68 additions and 34 deletions

View File

@@ -35,7 +35,6 @@ import org.springframework.web.context.request.async.CallableProcessingIntercept
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;
@@ -68,10 +67,6 @@ final class TestDispatcherServlet extends DispatcherServlet {
protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
if (!request.getParts().isEmpty()) {
request = new StandardMultipartHttpServletRequest(request);
}
registerAsyncResultInterceptors(request);
super.service(request, response);

View File

@@ -17,7 +17,10 @@
package org.springframework.test.web.servlet.request;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@@ -25,17 +28,17 @@ import java.util.List;
import javax.servlet.ServletContext;
import javax.servlet.http.Part;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.mock.web.MockPart;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.multipart.MultipartFile;
/**
* Default builder for {@link MockMultipartHttpServletRequest}.
@@ -141,26 +144,47 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
this.files.forEach(file -> request.addPart(toMockPart(file)));
this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart);
this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
MultipartFile file = asMultipartFile(part);
if (file != null) {
request.addFile(file);
return;
}
String value = toParameterValue(part);
if (value != null) {
request.addParameter(part.getName(), toParameterValue(part));
}
}
catch (IOException ex) {
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});
return request;
}
private MockPart toMockPart(MockMultipartFile file) {
byte[] bytes = null;
if (!file.isEmpty()) {
try {
bytes = file.getBytes();
}
catch (IOException ex) {
throw new IllegalStateException("Unexpected IOException", ex);
}
@Nullable
private MultipartFile asMultipartFile(Part part) throws IOException {
String name = part.getName();
String filename = part.getSubmittedFileName();
if (filename != null) {
return new MockMultipartFile(name, filename, part.getContentType(), part.getInputStream());
}
MockPart part = new MockPart(file.getName(), file.getOriginalFilename(), bytes);
if (file.getContentType() != null) {
part.getHeaders().set(HttpHeaders.CONTENT_TYPE, file.getContentType());
return null;
}
@Nullable
private String toParameterValue(Part part) throws IOException {
String rawType = part.getContentType();
MediaType mediaType = (rawType != null ? MediaType.parseMediaType(rawType) : MediaType.TEXT_PLAIN);
if (!mediaType.isCompatibleWith(MediaType.TEXT_PLAIN)) {
return null;
}
return part;
Charset charset = (mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8);
InputStreamReader reader = new InputStreamReader(part.getInputStream(), charset);
return FileCopyUtils.copyToString(reader);
}
}