Skip to content

Commit 8d71928

Browse files
fix: directly use correct device + dtype for eos positions extensions
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 115259c commit 8d71928

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/modalities/models/gpt2/collator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _compute_subsequence_length(self, seq: torch.Tensor, eos_positions: torch.Te
7474
# the last sequence is just padding.
7575
last_eos_pos = eos_positions[-1].item()
7676
if self._has_cutoff_final_sequence(seq, last_eos_pos):
77-
eos_positions = torch.cat([eos_positions, torch.tensor([len(seq) - 1])])
77+
eos_positions = torch.cat([eos_positions, eos_positions.new_tensor([len(seq) - 1])])
7878
# Compute length of each subsequence and add to lengths list.
7979
subseq_lengths = []
8080
prev_pos = 0

0 commit comments

Comments
 (0)