@@ -12,25 +12,25 @@ def __init__(
1212 sample_key : str ,
1313 target_key : str ,
1414 sub_seq_lengths_key : str | None = None ,
15- eod_token_id : int | None = None ,
15+ eos_token_id : int | None = None ,
1616 padding_token_id : int | None = None ,
1717 ):
1818 """
1919 Initializes the Collator object.
20- If the eod document token ID and the sub_seq_lengths_key are provided,
20+ If the eos token ID and the sub_seq_lengths_key are provided,
2121 a list[list[int]] representing the sub-sequence lengths will be created.
2222
2323 Args:
2424 sample_key (str): The key for accessing the sample data.
2525 target_key (str): The key for accessing the target data.
2626 sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
27- eod_token_id (int | None): The end-of-document token ID.
27+ eos_token_id (int | None): The end-of-sequence token ID.
2828 padding_token_id (int | None): The padding token ID.
2929 """
3030 self .sample_key = sample_key
3131 self .target_key = target_key
3232 self .sub_seq_lengths_key = sub_seq_lengths_key
33- self .eod_token_id = eod_token_id
33+ self .eos_token_id = eos_token_id
3434 self .padding_token_id = padding_token_id
3535
3636 def __call__ (self , batch : list [dict [str , torch .Tensor ]]) -> DatasetBatch :
@@ -49,42 +49,42 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
4949 samples = {self .sample_key : sample_tensor [:, :- 1 ]}
5050 targets = {self .target_key : sample_tensor [:, 1 :]}
5151 if self .sub_seq_lengths_key is not None :
52- # Determine sub sequence lengths by finding the eod tokens in each sequence in the batch.
52+ # Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
5353 sub_seq_lengths = self ._compute_sub_sequence_lengths_for_each_sequence (sample_tensor )
5454 samples [self .sub_seq_lengths_key ] = sub_seq_lengths
5555 return DatasetBatch (targets = targets , samples = samples )
5656
5757 def _compute_sub_sequence_lengths_for_each_sequence (self , sample_tensor : torch .Tensor ) -> list [list [int ]]:
5858 sub_seq_lengths = []
5959 for seq in sample_tensor :
60- eod_positions = (seq == self .eod_token_id ).nonzero (as_tuple = True )[0 ]
61- if len (eod_positions ) == 0 :
60+ eos_positions = (seq == self .eos_token_id ).nonzero (as_tuple = True )[0 ]
61+ if len (eos_positions ) == 0 :
6262 assert (
6363 self .padding_token_id is None or seq [0 ] != self .padding_token_id
6464 ), "Sequence starts with padding token"
6565 sub_seq_lengths .append ([len (seq )])
6666 else :
67- subseq_lengths = self ._compute_subsequence_length (seq , eod_positions )
67+ subseq_lengths = self ._compute_subsequence_length (seq , eos_positions )
6868 sub_seq_lengths .append (subseq_lengths )
6969 return sub_seq_lengths
7070
71- def _compute_subsequence_length (self , seq : torch .Tensor , eod_positions : torch .Tensor ) -> list [int ]:
72- # If the last sequence is cut, i.e. does not end on an eod token,
71+ def _compute_subsequence_length (self , seq : torch .Tensor , eos_positions : torch .Tensor ) -> list [int ]:
72+ # If the last sequence is cut, i.e. does not end on an eos token,
7373 # it should also be included unless the padding token is set and
7474 # the last sequence is just padding.
75- last_eod_pos = eod_positions [- 1 ].item ()
76- if self ._has_cutoff_final_sequence (seq , last_eod_pos ):
77- eod_positions = torch .cat ([eod_positions , torch .tensor ([len (seq ) - 1 ])])
75+ last_eos_pos = eos_positions [- 1 ].item ()
76+ if self ._has_cutoff_final_sequence (seq , last_eos_pos ):
77+ eos_positions = torch .cat ([eos_positions , torch .tensor ([len (seq ) - 1 ])])
7878 # Compute length of each subsequence and add to lengths list.
7979 subseq_lengths = []
8080 prev_pos = 0
81- for pos in eod_positions :
81+ for pos in eos_positions :
8282 subseq_lengths .append (pos .item () - prev_pos + 1 )
8383 prev_pos = pos .item () + 1
8484 return subseq_lengths
8585
86- def _has_cutoff_final_sequence (self , seq : torch .Tensor , last_eod_pos : int ) -> bool :
86+ def _has_cutoff_final_sequence (self , seq : torch .Tensor , last_eos_pos : int ) -> bool :
8787 # Assumption: If the first token of the last sequence is padding, so is the rest.
88- return last_eod_pos < len (seq ) - 1 and (
89- self .padding_token_id is None or seq [last_eod_pos + 1 ] != self .padding_token_id
88+ return last_eos_pos < len (seq ) - 1 and (
89+ self .padding_token_id is None or seq [last_eos_pos + 1 ] != self .padding_token_id
9090 )
0 commit comments