Skip to content

Commit d1af855

Browse files
committed
fix(decode): preserve dictionary + consume to EOF in read_to_end
Three correctness fixes in the `read_to_end` decode-in-place paths:

- Concatenated frames now re-initialise via `init_with_dict_handle` when the `StreamingDecoder` was built with a dictionary, so a forced dict is preserved for following frames (plain `init` resolves dicts by frame id only, losing the dict for frames that omit the id). `StreamingDecoder` retains the (cheap `Arc`/`Rc`) dict handle for this; the shared init+decode loop is factored into `FrameDecoder::decode_concatenated_frames_to_vec`.
- The mid-frame fallback (mixed `read` + `read_to_end`) now drains the partial current frame and then decodes the remaining concatenated frames to true EOF, matching the fast path and the `Read::read_to_end` contract.
- A direct-path decode error truncates the just-grown output tail before propagating, so callers never observe zeroed non-decoded bytes.
1 parent 7bda2a2 commit d1af855

2 files changed

Lines changed: 99 additions & 19 deletions

File tree

zstd/src/decoding/frame_decoder.rs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2340,27 +2340,53 @@ impl FrameDecoder {
23402340
&mut self,
23412341
mut input: &[u8],
23422342
output: &mut Vec<u8>,
2343+
dict: Option<&DictionaryHandle>,
23432344
) -> Result<usize, FrameDecoderError> {
23442345
let start_len = output.len();
23452346
// The current frame is already initialised (its header consumed by the
2346-
// caller). Decode it, then decode any FOLLOWING concatenated / skippable
2347-
// frames in `input` so the whole source is consumed to EOF and nothing
2348-
// is dropped (matching `read_to_end` semantics).
2347+
// caller, WITH `dict` applied if the decoder was constructed with one).
2348+
// Decode it, then decode any FOLLOWING concatenated / skippable frames
2349+
// in `input` so the whole source is consumed to EOF and nothing is
2350+
// dropped (matching `read_to_end` semantics).
23492351
self.decode_one_frame_to_vec(&mut input, output)?;
2352+
self.decode_concatenated_frames_to_vec(&mut input, output, dict)?;
2353+
Ok(output.len() - start_len)
2354+
}
2355+
2356+
/// Initialise and decode every frame remaining in `input` (concatenated /
2357+
/// skippable), APPENDING to `output`. `input` is advanced as frames are
2358+
/// consumed; on return it is empty. Re-initialisation honours `dict`: when
2359+
/// `Some`, each following frame is initialised via
2360+
/// [`Self::init_with_dict_handle`] so a forced dictionary is preserved even
2361+
/// for frames that omit the dictionary id (plain [`Self::init`] would
2362+
/// resolve dictionaries by id only). Backs the `read_to_end` fast path (the
2363+
/// frames after the current one) and its mid-frame fallback (the frames
2364+
/// after the partially-read one).
2365+
pub(crate) fn decode_concatenated_frames_to_vec(
2366+
&mut self,
2367+
input: &mut &[u8],
2368+
output: &mut Vec<u8>,
2369+
dict: Option<&DictionaryHandle>,
2370+
) -> Result<usize, FrameDecoderError> {
2371+
let start_len = output.len();
23502372
while !input.is_empty() {
2351-
match self.init(&mut input) {
2373+
let init_result = match dict {
2374+
Some(d) => self.init_with_dict_handle(&mut *input, d),
2375+
None => self.init(&mut *input),
2376+
};
2377+
match init_result {
23522378
Ok(_) => {}
23532379
Err(FrameDecoderError::ReadFrameHeaderError(
23542380
crate::decoding::errors::ReadFrameHeaderError::SkipFrame { length, .. },
23552381
)) => {
2356-
input = input
2382+
*input = input
23572383
.get(length as usize..)
23582384
.ok_or(FrameDecoderError::FailedToSkipFrame)?;
23592385
continue;
23602386
}
23612387
Err(e) => return Err(e),
23622388
}
2363-
self.decode_one_frame_to_vec(&mut input, output)?;
2389+
self.decode_one_frame_to_vec(&mut *input, output)?;
23642390
}
23652391
Ok(output.len() - start_len)
23662392
}
@@ -2394,8 +2420,17 @@ impl FrameDecoder {
23942420
// `content_size` bytes (erroring otherwise), so the grown region is
23952421
// fully written.
23962422
output.resize(frame_end, 0);
2423+
// On error, drop the just-grown (zeroed) tail before propagating so
2424+
// callers never observe bytes that were never decoded.
23972425
let written =
2398-
self.run_direct_decode(&mut *input, &mut output[frame_start..], content_size)?;
2426+
match self.run_direct_decode(&mut *input, &mut output[frame_start..], content_size)
2427+
{
2428+
Ok(n) => n,
2429+
Err(e) => {
2430+
output.truncate(frame_start);
2431+
return Err(e);
2432+
}
2433+
};
23992434
output.truncate(frame_start + written);
24002435
#[cfg(feature = "hash")]
24012436
self.verify_content_checksum()?;

zstd/src/decoding/streaming_decoder.rs

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ use crate::io::{Error, Read};
4646
pub struct StreamingDecoder<READ: Read, DEC: BorrowMut<FrameDecoder>> {
4747
pub decoder: DEC,
4848
source: READ,
49+
/// Dictionary the decoder was constructed with, if any. Retained so the
50+
/// `read_to_end` paths can re-initialise FOLLOWING concatenated frames with
51+
/// the same forced dictionary (a plain re-init resolves dictionaries by
52+
/// frame id only and would lose a forced dict for frames omitting the id).
53+
/// Cheap to hold: `DictionaryHandle` is an `Arc`/`Rc` handle.
54+
dict: Option<DictionaryHandle>,
4955
}
5056

5157
impl<READ: Read, DEC: BorrowMut<FrameDecoder>> StreamingDecoder<READ, DEC> {
@@ -54,7 +60,11 @@ impl<READ: Read, DEC: BorrowMut<FrameDecoder>> StreamingDecoder<READ, DEC> {
5460
mut decoder: DEC,
5561
) -> Result<StreamingDecoder<READ, DEC>, FrameDecoderError> {
5662
decoder.borrow_mut().init(&mut source)?;
57-
Ok(StreamingDecoder { decoder, source })
63+
Ok(StreamingDecoder {
64+
decoder,
65+
source,
66+
dict: None,
67+
})
5868
}
5969
}
6070

@@ -64,7 +74,11 @@ impl<READ: Read> StreamingDecoder<READ, FrameDecoder> {
6474
) -> Result<StreamingDecoder<READ, FrameDecoder>, FrameDecoderError> {
6575
let mut decoder = FrameDecoder::new();
6676
decoder.init(&mut source)?;
67-
Ok(StreamingDecoder { decoder, source })
77+
Ok(StreamingDecoder {
78+
decoder,
79+
source,
80+
dict: None,
81+
})
6882
}
6983

7084
/// Create a streaming decoder using a pre-parsed dictionary handle.
@@ -82,7 +96,11 @@ impl<READ: Read> StreamingDecoder<READ, FrameDecoder> {
8296
) -> Result<StreamingDecoder<READ, FrameDecoder>, FrameDecoderError> {
8397
let mut decoder = FrameDecoder::new();
8498
decoder.init_with_dict_handle(&mut source, dict)?;
85-
Ok(StreamingDecoder { decoder, source })
99+
Ok(StreamingDecoder {
100+
decoder,
101+
source,
102+
dict: Some(dict.clone()),
103+
})
86104
}
87105

88106
/// Create a streaming decoder using a serialized dictionary blob.
@@ -229,24 +247,28 @@ impl<READ: Read, DEC: BorrowMut<FrameDecoder>> Read for StreamingDecoder<READ, D
229247
/// recreate-the-decoder pattern instead.
230248
#[cfg(feature = "std")]
231249
fn read_to_end(&mut self, output: &mut alloc::vec::Vec<u8>) -> Result<usize, Error> {
250+
let start_total = output.len();
232251
// `new()` already read the frame header, so the fast path applies when
233252
// the decoder sits at the start of that frame with nothing decoded yet.
234253
let at_start = {
235254
let d = self.decoder.borrow_mut();
236255
d.is_at_frame_start() && d.can_collect() == 0
237256
};
257+
// Clone the (cheap Arc/Rc) dict handle out so the `decoder` borrow below
258+
// does not conflict with borrowing `self.dict`.
259+
let dict = self.dict.clone();
238260
if at_start {
239261
let mut compressed = alloc::vec::Vec::new();
240262
self.source.read_to_end(&mut compressed)?;
241-
let written = self
242-
.decoder
263+
self.decoder
243264
.borrow_mut()
244-
.decode_current_frame_to_vec(&compressed, output)
265+
.decode_current_frame_to_vec(&compressed, output, dict.as_ref())
245266
.map_err(Error::other)?;
246-
return Ok(written);
267+
return Ok(output.len() - start_total);
247268
}
248-
// Mid-frame fallback: grow `output` and drain through the generic path.
249-
let mut total = 0;
269+
// Mid-frame fallback: drain the partially-read CURRENT frame through the
270+
// generic path, then decode any FOLLOWING concatenated frames so
271+
// read_to_end still consumes the source to true EOF.
250272
loop {
251273
let start = output.len();
252274
output.resize(start + MAX_BLOCK_SIZE as usize, 0);
@@ -255,9 +277,18 @@ impl<READ: Read, DEC: BorrowMut<FrameDecoder>> Read for StreamingDecoder<READ, D
255277
if n == 0 {
256278
break;
257279
}
258-
total += n;
259280
}
260-
Ok(total)
281+
// Current frame fully drained; `source` is positioned at the next frame.
282+
let mut rest = alloc::vec::Vec::new();
283+
self.source.read_to_end(&mut rest)?;
284+
if !rest.is_empty() {
285+
let mut input = rest.as_slice();
286+
self.decoder
287+
.borrow_mut()
288+
.decode_concatenated_frames_to_vec(&mut input, output, dict.as_ref())
289+
.map_err(Error::other)?;
290+
}
291+
Ok(output.len() - start_total)
261292
}
262293

263294
/// no_std counterpart of the decode-in-place `read_to_end` fast path above
@@ -268,15 +299,20 @@ impl<READ: Read, DEC: BorrowMut<FrameDecoder>> Read for StreamingDecoder<READ, D
268299
let d = self.decoder.borrow_mut();
269300
d.is_at_frame_start() && d.can_collect() == 0
270301
};
302+
// Cheap Arc/Rc clone so the `decoder` borrow does not conflict with
303+
// borrowing `self.dict`.
304+
let dict = self.dict.clone();
271305
if at_start {
272306
let mut compressed = alloc::vec::Vec::new();
273307
self.source.read_to_end(&mut compressed)?;
274308
self.decoder
275309
.borrow_mut()
276-
.decode_current_frame_to_vec(&compressed, output)
310+
.decode_current_frame_to_vec(&compressed, output, dict.as_ref())
277311
.map_err(|e| Error::new(ErrorKind::Other, alloc::boxed::Box::new(e)))?;
278312
return Ok(());
279313
}
314+
// Mid-frame fallback: drain the partial CURRENT frame, then decode the
315+
// FOLLOWING concatenated frames so the source is consumed to true EOF.
280316
loop {
281317
let start = output.len();
282318
output.resize(start + MAX_BLOCK_SIZE as usize, 0);
@@ -286,6 +322,15 @@ impl<READ: Read, DEC: BorrowMut<FrameDecoder>> Read for StreamingDecoder<READ, D
286322
break;
287323
}
288324
}
325+
let mut rest = alloc::vec::Vec::new();
326+
self.source.read_to_end(&mut rest)?;
327+
if !rest.is_empty() {
328+
let mut input = rest.as_slice();
329+
self.decoder
330+
.borrow_mut()
331+
.decode_concatenated_frames_to_vec(&mut input, output, dict.as_ref())
332+
.map_err(|e| Error::new(ErrorKind::Other, alloc::boxed::Box::new(e)))?;
333+
}
289334
Ok(())
290335
}
291336
}

0 commit comments

Comments
 (0)