diff --git a/spring-restdocs-mockmvc/src/main/java/org/springframework/restdocs/mockmvc/MockMvcOperationRequestFactory.java b/spring-restdocs-mockmvc/src/main/java/org/springframework/restdocs/mockmvc/MockMvcOperationRequestFactory.java index d3236e08..4d1c4a9f 100644 --- a/spring-restdocs-mockmvc/src/main/java/org/springframework/restdocs/mockmvc/MockMvcOperationRequestFactory.java +++ b/spring-restdocs-mockmvc/src/main/java/org/springframework/restdocs/mockmvc/MockMvcOperationRequestFactory.java @@ -86,37 +86,57 @@ class MockMvcOperationRequestFactory { private List extractParts(MockHttpServletRequest servletRequest) throws IOException, ServletException { List parts = new ArrayList<>(); - for (Part part : servletRequest.getParts()) { - HttpHeaders partHeaders = extractHeaders(part); - List contentTypeHeader = partHeaders.get(HttpHeaders.CONTENT_TYPE); - if (part.getContentType() != null && contentTypeHeader == null) { - partHeaders - .setContentType(MediaType.parseMediaType(part.getContentType())); - } - parts.add(new StandardOperationRequestPart(part.getName(), StringUtils - .hasText(part.getSubmittedFileName()) ? part.getSubmittedFileName() - : null, FileCopyUtils.copyToByteArray(part.getInputStream()), - partHeaders)); - } + parts.addAll(extractServletRequestParts(servletRequest)); if (servletRequest instanceof MockMultipartHttpServletRequest) { - for (Entry> entry : ((MockMultipartHttpServletRequest) servletRequest) - .getMultiFileMap().entrySet()) { - for (MultipartFile file : entry.getValue()) { - HttpHeaders partHeaders = new HttpHeaders(); - if (StringUtils.hasText(file.getContentType())) { - partHeaders.setContentType(MediaType.parseMediaType(file - .getContentType())); - } - parts.add(new StandardOperationRequestPart(file.getName(), - StringUtils.hasText(file.getOriginalFilename()) ? file - .getOriginalFilename() : null, file.getBytes(), - partHeaders)); - } + parts.addAll(extractMultipartRequestParts((MockMultipartHttpServletRequest) servletRequest)); + } + return parts; + } + + private List extractServletRequestParts( + MockHttpServletRequest servletRequest) throws IOException, ServletException { + List parts = new ArrayList<>(); + for (Part part : servletRequest.getParts()) { + parts.add(createOperationRequestPart(part)); + } + return parts; + } + + private StandardOperationRequestPart createOperationRequestPart(Part part) + throws IOException { + HttpHeaders partHeaders = extractHeaders(part); + List contentTypeHeader = partHeaders.get(HttpHeaders.CONTENT_TYPE); + if (part.getContentType() != null && contentTypeHeader == null) { + partHeaders.setContentType(MediaType.parseMediaType(part.getContentType())); + } + return new StandardOperationRequestPart(part.getName(), StringUtils.hasText(part + .getSubmittedFileName()) ? part.getSubmittedFileName() : null, + FileCopyUtils.copyToByteArray(part.getInputStream()), partHeaders); + } + + private List extractMultipartRequestParts( + MockMultipartHttpServletRequest multipartRequest) throws IOException { + List parts = new ArrayList<>(); + for (Entry> entry : multipartRequest + .getMultiFileMap().entrySet()) { + for (MultipartFile file : entry.getValue()) { + parts.add(createOperationRequestPart(file)); } } return parts; } + private StandardOperationRequestPart createOperationRequestPart(MultipartFile file) + throws IOException { + HttpHeaders partHeaders = new HttpHeaders(); + if (StringUtils.hasText(file.getContentType())) { + partHeaders.setContentType(MediaType.parseMediaType(file.getContentType())); + } + return new StandardOperationRequestPart(file.getName(), StringUtils.hasText(file + .getOriginalFilename()) ? file.getOriginalFilename() : null, + file.getBytes(), partHeaders); + } + private HttpHeaders extractHeaders(Part part) { HttpHeaders partHeaders = new HttpHeaders(); for (String headerName : part.getHeaderNames()) {