Skip to content

Commit 8e6b6c5

Browse files
committed
Refine error handling in MultipartParser
Closes gh-36947
1 parent 7b31e0c commit 8e6b6c5

4 files changed

Lines changed: 96 additions & 35 deletions

File tree

spring-web/src/main/java/org/springframework/http/converter/multipart/MultipartParser.java

Lines changed: 78 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.http.converter.multipart;
1818

19-
import java.io.IOException;
2019
import java.io.InputStream;
2120
import java.nio.charset.Charset;
2221
import java.nio.charset.StandardCharsets;
@@ -81,10 +80,28 @@ private MultipartParser(
8180
}
8281

8382

83+
/**
84+
* Simple delegation to {@link State#data(DataBuffer)}.
85+
*/
8486
void handleData(DataBuffer dataBuffer) {
8587
this.state.data(dataBuffer);
8688
}
8789

90+
/**
91+
* Handle a parsing error cleaning resources in the state and in the listener
92+
* unless the error is {@link HttpMessageConversionException} in which case
93+
* it is simply propagated.
94+
*/
95+
void handleException(String message, @Nullable Throwable cause) {
96+
if (cause instanceof HttpMessageConversionException ex) {
97+
throw ex;
98+
}
99+
changeState(DisposedState.INSTANCE, null);
100+
HttpMessageConversionException ex = new HttpMessageConversionException(message, cause);
101+
this.listener.onError(ex);
102+
throw ex;
103+
}
104+
88105
private void changeState(State newState, @Nullable DataBuffer remainder) {
89106
if (logger.isTraceEnabled()) {
90107
logger.trace("Changed state: " + this.state + " -> " + newState);
@@ -118,6 +135,7 @@ private static byte[] concat(byte[]... byteArrays) {
118135
return result;
119136
}
120137

138+
121139
/**
122140
* Parse the given stream of bytes into events published to the given {@link PartListener}.
123141
* @param input the input stream
@@ -141,9 +159,8 @@ public static void parse(InputStream input, byte[] boundary, Charset headersChar
141159
}
142160
parser.state.complete();
143161
}
144-
catch (IOException ex) {
145-
parser.state.dispose();
146-
listener.onError(new HttpMessageConversionException("Could not decode multipart message", ex));
162+
catch (Throwable ex) {
163+
parser.handleException("Could not decode multipart message", ex);
147164
}
148165
}
149166

@@ -155,13 +172,19 @@ interface PartListener {
155172

156173
/**
157174
* Handle {@link HttpHeaders} for a part.
175+
* <p>Expectations for exception handling are the same as for {@link #onBody}.
158176
*/
159177
void onHeaders(HttpHeaders headers);
160178

161179
/**
162-
* Handle a piece of data for a body part.
180+
* Handle the next chunk of body data.
181+
* <p>Implementations must release the input buffer.
182+
* <p>Implementations must handle all exceptions, cleaning up resources,
183+
* and wrapping the exception as {@link HttpMessageConversionException}.
163184
* @param buffer a chunk of body
164-
* @param last whether this is the last chunk for the part
185+
* @param last whether this is the last chunk for the part
186+
* @throws HttpMessageConversionException if the buffer could not be
187+
* handled due to exceeded limits or for any other reason
165188
*/
166189
void onBody(DataBuffer buffer, boolean last);
167190

@@ -171,7 +194,12 @@ interface PartListener {
171194
void onComplete();
172195

173196
/**
174-
* Handle any error thrown during the parsing phase.
197+
* Handle any error thrown during the parsing phase. The purpose of the
198+
* method call is to allow cleaning up of resources. The listener does
199+
* not need to throw or wrap and throw the error.
200+
* <p>{@link #onHeaders} and {@link #onBody} are expected to handle their
201+
* own exceptions, i.e. any exception those methods throw will not be
202+
* passed here.
175203
*/
176204
void onError(Throwable error);
177205

@@ -194,10 +222,26 @@ interface PartListener {
194222
*/
195223
private interface State {
196224

225+
/**
226+
* Handle the next chunk of data.
227+
* <p>If this method raises any exception other than
228+
* {@link HttpMessageConversionException}, it will be
229+
* {@link MultipartParser#handleException handled} by the parser.
230+
* An {@link HttpMessageConversionException} on the other hand is
231+
* considered fully handled and allowed to propagate as is.
232+
*/
197233
void data(DataBuffer buf);
198234

235+
/**
236+
* Called when the current part is fully parsed.
237+
* <p>Expecations for exception handling are the same as for {@link #data}.
238+
*/
199239
void complete();
200240

241+
/**
242+
* Clean up resources held by the state. Called in case of errors or
243+
* when switching to a new state.
244+
*/
201245
default void dispose() {
202246
}
203247

@@ -241,8 +285,7 @@ public void data(DataBuffer buf) {
241285

242286
@Override
243287
public void complete() {
244-
changeState(DisposedState.INSTANCE, null);
245-
MultipartParser.this.listener.onError(new HttpMessageConversionException("Could not find first boundary"));
288+
handleException("Could not find first boundary", null);
246289
}
247290

248291
@Override
@@ -312,7 +355,14 @@ private void emitHeaders() {
312355
if (logger.isTraceEnabled()) {
313356
logger.trace("Emitting headers: " + headers);
314357
}
315-
MultipartParser.this.listener.onHeaders(headers);
358+
try {
359+
MultipartParser.this.listener.onHeaders(headers);
360+
}
361+
catch (Throwable ex) {
362+
// PartListener should have cleaned its state, clean our own
363+
dispose();
364+
throw ex;
365+
}
316366
}
317367

318368
/**
@@ -338,12 +388,9 @@ private boolean belowMaxHeaderSize(long count) {
338388
if (count <= MultipartParser.this.maxHeadersSize) {
339389
return true;
340390
}
341-
else {
342-
MultipartParser.this.listener.onError(
343-
new HttpMessageConversionException("Part headers exceeded the memory usage limit of " +
344-
MultipartParser.this.maxHeadersSize + " bytes"));
345-
return false;
346-
}
391+
MultipartParser.this.handleException(
392+
"Part headers exceeded the limit of " + MultipartParser.this.maxHeadersSize + " bytes", null);
393+
return false;
347394
}
348395

349396
/**
@@ -377,8 +424,7 @@ private HttpHeaders parseHeaders() {
377424

378425
@Override
379426
public void complete() {
380-
changeState(DisposedState.INSTANCE, null);
381-
MultipartParser.this.listener.onError(new HttpMessageConversionException("Could not find end of headers"));
427+
MultipartParser.this.handleException("Could not find end of headers", null);
382428
}
383429

384430
@Override
@@ -493,25 +539,32 @@ private void enqueue(DataBuffer buf) {
493539
}
494540
len += previous.readableByteCount();
495541
}
496-
emit.forEach(buffer -> MultipartParser.this.listener.onBody(buffer, false));
542+
emit.forEach(buffer -> invokeListener(buffer, false));
497543
}
498544

499545
private void flush() {
500546
for (Iterator<DataBuffer> iterator = this.queue.iterator(); iterator.hasNext(); ) {
501547
DataBuffer buffer = iterator.next();
502548
boolean last = !iterator.hasNext();
503-
MultipartParser.this.listener.onBody(buffer, last);
549+
invokeListener(buffer, last);
504550
}
505551
this.queue.clear();
506552
}
507553

554+
private void invokeListener(DataBuffer buffer, boolean last) {
555+
try {
556+
MultipartParser.this.listener.onBody(buffer, last);
557+
}
558+
catch (Throwable ex) {
559+
dispose();
560+
throw ex;
561+
}
562+
}
563+
508564
@Override
509565
public void complete() {
510-
changeState(DisposedState.INSTANCE, null);
511-
String msg = "Could not find end of body (␍␊--" +
512-
new String(MultipartParser.this.boundary, StandardCharsets.UTF_8) +
513-
")";
514-
MultipartParser.this.listener.onError(new HttpMessageConversionException(msg));
566+
MultipartParser.this.handleException("Could not find end of body (␍␊--" +
567+
new String(MultipartParser.this.boundary, StandardCharsets.UTF_8) + ")", null);
515568
}
516569

517570
@Override

spring-web/src/main/java/org/springframework/http/converter/multipart/PartGenerator.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ public void onHeaders(HttpHeaders headers) {
9696

9797
private static boolean isFormField(HttpHeaders headers) {
9898
MediaType contentType = headers.getContentType();
99-
return (contentType == null || MediaType.TEXT_PLAIN.equalsTypeAndSubtype(contentType)) &&
100-
headers.getContentDisposition().getFilename() == null;
99+
return ((contentType == null || MediaType.TEXT_PLAIN.equalsTypeAndSubtype(contentType)) &&
100+
headers.getContentDisposition().getFilename() == null);
101101
}
102102

103103
@Override
@@ -133,9 +133,8 @@ public void onComplete() {
133133
}
134134

135135
@Override
136-
public void onError(Throwable error) {
136+
public void onError(Throwable ex) {
137137
deleteParts();
138-
throw new HttpMessageConversionException("Cannot decode multipart body", error);
139138
}
140139

141140
void addPart(Part part) {
@@ -168,7 +167,7 @@ void addPart(Part part) {
168167
private interface State {
169168

170169
/**
171-
* Invoked when a {@link MultipartParser.PartListener#onBody(DataBuffer, boolean)} is received.
170+
* Invoked when the parser receives additional data.
172171
*/
173172
void onBody(DataBuffer dataBuffer, boolean last);
174173

spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartHttpMessageConverterTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ void readMultipartBrowser() throws Exception {
200200
void readMultipartInvalid() throws Exception {
201201
MockHttpInputMessage response = createMultipartResponse("garbage-1.multipart", "boundary");
202202
assertThatThrownBy(() -> converter.read(ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class), response, null))
203-
.isInstanceOf(HttpMessageConversionException.class).hasMessage("Cannot decode multipart body");
203+
.isInstanceOf(HttpMessageConversionException.class).hasMessage("Could not find first boundary");
204204
}
205205

206206
@Test

spring-web/src/test/java/org/springframework/http/converter/multipart/MultipartParserTests.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import static java.nio.charset.StandardCharsets.UTF_8;
4040
import static org.assertj.core.api.Assertions.assertThat;
41+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4142

4243
/**
4344
* Tests for {@link MultipartParser}.
@@ -74,31 +75,39 @@ void noHeaders() throws Exception {
7475
@Test
7576
void noEndBoundary() throws Exception {
7677
TestListener listener = new TestListener();
77-
parse("no-end-boundary.multipart", "boundary", listener);
78+
assertThatThrownBy(() -> parse("no-end-boundary.multipart", "boundary", listener))
79+
.isInstanceOf(HttpMessageConversionException.class)
80+
.hasMessage("Could not find end of headers");
7881

7982
assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class);
8083
}
8184

8285
@Test
8386
void garbage() throws Exception {
8487
TestListener listener = new TestListener();
85-
parse("garbage-1.multipart", "boundary", listener);
88+
assertThatThrownBy(() -> parse("garbage-1.multipart", "boundary", listener))
89+
.isInstanceOf(HttpMessageConversionException.class)
90+
.hasMessage("Could not find first boundary");
8691

8792
assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class);
8893
}
8994

9095
@Test
9196
void noEndHeader() throws Exception {
9297
TestListener listener = new TestListener();
93-
parse("no-end-header.multipart", "boundary", listener);
98+
assertThatThrownBy(() -> parse("no-end-header.multipart", "boundary", listener))
99+
.isInstanceOf(HttpMessageConversionException.class)
100+
.hasMessage("Could not find end of headers");
94101

95102
assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class);
96103
}
97104

98105
@Test
99106
void noEndBody() throws Exception {
100107
TestListener listener = new TestListener();
101-
parse("no-end-body.multipart", "boundary", listener);
108+
assertThatThrownBy(() -> parse("no-end-body.multipart", "boundary", listener))
109+
.isInstanceOf(HttpMessageConversionException.class)
110+
.hasMessage("Could not find end of body (␍␊--boundary)");
102111

103112
assertThat(listener.error).isInstanceOf(HttpMessageConversionException.class);
104113
}

0 commit comments

Comments
 (0)