Support for @RequestParam Map declared with MultipartFile/Part values

Issue: SPR-17405
This commit is contained in:
Juergen Hoeller
2018-10-24 20:44:58 +02:00
parent 488a1d4561
commit f0f1979fc5
4 changed files with 247 additions and 96 deletions

View File

@@ -16,10 +16,14 @@
package org.springframework.web.method.annotation;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
@@ -29,22 +33,29 @@ import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartRequest;
import org.springframework.web.multipart.support.MultipartResolutionDelegate;
/**
* Resolves {@link Map} method arguments annotated with an @{@link RequestParam}
* where the annotation does not specify a request parameter name.
* See {@link RequestParamMethodArgumentResolver} for resolving {@link Map}
* method arguments with a request parameter name.
*
* <p>The created {@link Map} contains all request parameter name/value pairs.
* If the method parameter type is {@link MultiValueMap} instead, the created
* map contains all request parameters and all there values for cases where
* request parameters have multiple values.
* <p>The created {@link Map} contains all request parameter name/value pairs,
* or all multipart files for a given parameter name if specifically declared
* with {@link MultipartFile} as the value type. If the method parameter type is
* {@link MultiValueMap} instead, the created map contains all request parameters
* and all their values for cases where request parameters have multiple values
* (or multiple multipart files of the same name).
*
* @author Arjen Poutsma
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @since 3.1
* @see RequestParamMethodArgumentResolver
* @see HttpServletRequest#getParameterMap()
* @see MultipartRequest#getMultiFileMap()
* @see MultipartRequest#getFileMap()
*/
public class RequestParamMapMethodArgumentResolver implements HandlerMethodArgumentResolver {
@@ -59,26 +70,71 @@ public class RequestParamMapMethodArgumentResolver implements HandlerMethodArgum
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {
Class<?> paramType = parameter.getParameterType();
ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
Map<String, String[]> parameterMap = webRequest.getParameterMap();
if (MultiValueMap.class.isAssignableFrom(paramType)) {
MultiValueMap<String, String> result = new LinkedMultiValueMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
for (String value : values) {
result.add(key, value);
if (MultiValueMap.class.isAssignableFrom(parameter.getParameterType())) {
// MultiValueMap
Class<?> valueType = resolvableType.as(MultiValueMap.class).getGeneric(1).resolve();
if (valueType == MultipartFile.class) {
MultipartRequest multipartRequest = MultipartResolutionDelegate.resolveMultipartRequest(webRequest);
return (multipartRequest != null ? multipartRequest.getMultiFileMap() : new LinkedMultiValueMap<>(0));
}
else if (valueType == Part.class) {
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
if (servletRequest != null && MultipartResolutionDelegate.isMultipartRequest(servletRequest)) {
Collection<Part> parts = servletRequest.getParts();
LinkedMultiValueMap<String, Part> result = new LinkedMultiValueMap<>(parts.size());
for (Part part : parts) {
result.add(part.getName(), part);
}
return result;
}
});
return result;
return new LinkedMultiValueMap<>(0);
}
else {
Map<String, String[]> parameterMap = webRequest.getParameterMap();
MultiValueMap<String, String> result = new LinkedMultiValueMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
for (String value : values) {
result.add(key, value);
}
});
return result;
}
}
else {
Map<String, String> result = new LinkedHashMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
result.put(key, values[0]);
// Regular Map
Class<?> valueType = resolvableType.asMap().getGeneric(1).resolve();
if (valueType == MultipartFile.class) {
MultipartRequest multipartRequest = MultipartResolutionDelegate.resolveMultipartRequest(webRequest);
return (multipartRequest != null ? multipartRequest.getFileMap() : new LinkedHashMap<>(0));
}
else if (valueType == Part.class) {
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
if (servletRequest != null && MultipartResolutionDelegate.isMultipartRequest(servletRequest)) {
Collection<Part> parts = servletRequest.getParts();
LinkedHashMap<String, Part> result = new LinkedHashMap<>(parts.size());
for (Part part : parts) {
if (!result.containsKey(part.getName())) {
result.put(part.getName(), part);
}
}
return result;
}
});
return result;
return new LinkedHashMap<>(0);
}
else {
Map<String, String[]> parameterMap = webRequest.getParameterMap();
Map<String, String> result = new LinkedHashMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
result.put(key, values[0]);
}
});
return result;
}
}
}
}

View File

@@ -25,8 +25,10 @@ import javax.servlet.http.Part;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.MultipartRequest;
import org.springframework.web.util.WebUtils;
/**
@@ -44,6 +46,19 @@ public abstract class MultipartResolutionDelegate {
public static final Object UNRESOLVABLE = new Object();
@Nullable
public static MultipartRequest resolveMultipartRequest(NativeWebRequest webRequest) {
MultipartRequest multipartRequest = webRequest.getNativeRequest(MultipartRequest.class);
if (multipartRequest != null) {
return multipartRequest;
}
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
if (servletRequest != null && isMultipartContent(servletRequest)) {
return new StandardMultipartHttpServletRequest(servletRequest);
}
return null;
}
public static boolean isMultipartRequest(HttpServletRequest request) {
return (WebUtils.getNativeRequest(request, MultipartHttpServletRequest.class) != null ||
isMultipartContent(request));
@@ -103,13 +118,13 @@ public abstract class MultipartResolutionDelegate {
}
}
else if (Part.class == parameter.getNestedParameterType()) {
return (isMultipart ? resolvePart(request, name) : null);
return (isMultipart ? request.getPart(name): null);
}
else if (isPartCollection(parameter)) {
return (isMultipart ? resolvePartList(request, name) : null);
}
else if (isPartArray(parameter)) {
return (isMultipart ? resolvePartArray(request, name) : null);
return (isMultipart ? resolvePartList(request, name).toArray(new Part[0]) : null);
}
else {
return UNRESOLVABLE;
@@ -144,12 +159,8 @@ public abstract class MultipartResolutionDelegate {
return null;
}
private static Part resolvePart(HttpServletRequest servletRequest, String name) throws Exception {
return servletRequest.getPart(name);
}
private static List<Part> resolvePartList(HttpServletRequest servletRequest, String name) throws Exception {
Collection<Part> parts = servletRequest.getParts();
private static List<Part> resolvePartList(HttpServletRequest request, String name) throws Exception {
Collection<Part> parts = request.getParts();
List<Part> result = new ArrayList<>(parts.size());
for (Part part : parts) {
if (part.getName().equals(name)) {
@@ -159,15 +170,4 @@ public abstract class MultipartResolutionDelegate {
return result;
}
private static Part[] resolvePartArray(HttpServletRequest servletRequest, String name) throws Exception {
Collection<Part> parts = servletRequest.getParts();
List<Part> result = new ArrayList<>(parts.size());
for (Part part : parts) {
if (part.getName().equals(name)) {
result.add(part);
}
}
return result.toArray(new Part[0]);
}
}