diff --git a/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java b/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java index d80945e4bd..36fa283085 100644 --- a/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java +++ b/spring-web/src/main/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequest.java @@ -16,10 +16,13 @@ package org.springframework.web.multipart.support; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Enumeration; import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import javax.servlet.http.HttpServletRequest; @@ -27,6 +30,7 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; import org.springframework.web.multipart.MultipartFile; /** @@ -90,11 +94,20 @@ public class DefaultMultipartHttpServletRequest extends AbstractMultipartHttpSer @Override public String[] getParameterValues(String name) { - String[] values = getMultipartParameters().get(name); - if (values != null) { - return values; + String[] parameterValues = super.getParameterValues(name); + String[] mpValues = getMultipartParameters().get(name); + if (mpValues == null) { + return parameterValues; + } + if (parameterValues == null || getQueryString() == null) { + return mpValues; + } + else { + String[] result = new String[mpValues.length + parameterValues.length]; + System.arraycopy(mpValues, 0, result, 0, mpValues.length); + System.arraycopy(parameterValues, 0, result, mpValues.length, parameterValues.length); + return result; } - return super.getParameterValues(name); } @Override @@ -105,25 +118,20 @@ public class DefaultMultipartHttpServletRequest extends AbstractMultipartHttpSer } Set paramNames = new LinkedHashSet<>(); - Enumeration paramEnum = super.getParameterNames(); - while (paramEnum.hasMoreElements()) { - paramNames.add(paramEnum.nextElement()); - } + paramNames.addAll(Collections.list(super.getParameterNames())); paramNames.addAll(multipartParameters.keySet()); return Collections.enumeration(paramNames); } @Override public Map getParameterMap() { - Map multipartParameters = getMultipartParameters(); - if (multipartParameters.isEmpty()) { - return super.getParameterMap(); + Map result = new LinkedHashMap<>(); + Enumeration names = getParameterNames(); + while (names.hasMoreElements()) { + String name = names.nextElement(); + result.put(name, getParameterValues(name)); } - - Map paramMap = new LinkedHashMap<>(); - paramMap.putAll(super.getParameterMap()); - paramMap.putAll(multipartParameters); - return paramMap; + return result; } @Override diff --git a/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java b/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java new file mode 100644 index 0000000000..fc2b964c88 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/multipart/support/DefaultMultipartHttpServletRequestTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.multipart.support; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.mock.web.test.MockHttpServletRequest; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link DefaultMultipartHttpServletRequest}. + * + * @author Rossen Stoyanchev + */ +public class DefaultMultipartHttpServletRequestTests { + + private final MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + + private final Map multipartParams = new LinkedHashMap<>(); + + private final MultiValueMap queryParams = new LinkedMultiValueMap<>(); + + + @Test // SPR-16590 + public void parameterValues() { + + this.multipartParams.put("key", new String[] {"p"}); + this.queryParams.add("key", "q"); + + String[] values = createMultipartRequest().getParameterValues("key"); + + assertArrayEquals(new String[] {"p", "q"}, values); + } + + @Test // SPR-16590 + public void parameterMap() { + + this.multipartParams.put("key1", new String[] {"p1"}); + this.multipartParams.put("key2", new String[] {"p2"}); + + this.queryParams.add("key1", "q1"); + this.queryParams.add("key3", "q3"); + + Map map = createMultipartRequest().getParameterMap(); + + assertEquals(3, map.size()); + assertArrayEquals(new String[] {"p1", "q1"}, map.get("key1")); + assertArrayEquals(new String[] {"p2"}, map.get("key2")); + assertArrayEquals(new String[] {"q3"}, map.get("key3")); + } + + private DefaultMultipartHttpServletRequest createMultipartRequest() { + insertQueryParams(); + return new DefaultMultipartHttpServletRequest(this.servletRequest, new LinkedMultiValueMap<>(), + this.multipartParams, new HashMap<>()); + } + + private void insertQueryParams() { + StringBuilder query = new StringBuilder(); + for (String key : this.queryParams.keySet()) { + for (String value : this.queryParams.get(key)) { + this.servletRequest.addParameter(key, value); + query.append(query.length() > 0 ? "&" : "").append(key).append("=").append(value); + } + } + this.servletRequest.setQueryString(query.toString()); + } + +}