Skip to content

Commit fdd6465

Browse files
committed
fix(attention): computing sub sequence lengths on correct input
1 parent cd00777 commit fdd6465

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/modalities/models/gpt2/collator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
5050
targets = {self.target_key: sample_tensor[:, 1:]}
5151
if self.sub_seq_lengths_key is not None:
5252
# Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
53-
sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(sample_tensor)
53+
sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(samples[self.sample_key])
5454
samples[self.sub_seq_lengths_key] = sub_seq_lengths
5555
return DatasetBatch(targets=targets, samples=samples)
5656

0 commit comments

Comments
 (0)