Skip to content

Commit 3d583c1

Browse files
committed
feat(attention): added sub_seq_lengths_key to GPT2LLMConfig and renamed eod_token_id to eos_token_id
1 parent eba9c5b commit 3d583c1

4 files changed

Lines changed: 32 additions & 31 deletions

File tree

src/modalities/config/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,15 +444,15 @@ class GPT2LLMCollateFnConfig(BaseModel):
444444
sample_key: str
445445
target_key: str
446446
sub_seq_lengths_key: str | None = None
447-
eod_token_id: int | None = None
447+
eos_token_id: int | None = None
448448
padding_token_id: int | None = None
449449

450450
@model_validator(mode="before")
451-
def check_sub_seq_lengths_and_eod_token(cls, values):
451+
def check_sub_seq_lengths_and_eos_token(cls, values):
452452
sub_seq_lengths_key = values.get("sub_seq_lengths_key")
453-
eod_token_id = values.get("eod_token_id")
454-
if (sub_seq_lengths_key is None) != (eod_token_id is None):
455-
raise ValueError("Either both or neither of sub_seq_lengths_key and eod_token_id must be provided.")
453+
eos_token_id = values.get("eos_token_id")
454+
if (sub_seq_lengths_key is None) != (eos_token_id is None):
455+
raise ValueError("Either both or neither of sub_seq_lengths_key and eos_token_id must be provided.")
456456
return values
457457

458458
@model_validator(mode="before")

src/modalities/models/gpt2/collator.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,25 @@ def __init__(
1212
sample_key: str,
1313
target_key: str,
1414
sub_seq_lengths_key: str | None = None,
15-
eod_token_id: int | None = None,
15+
eos_token_id: int | None = None,
1616
padding_token_id: int | None = None,
1717
):
1818
"""
1919
Initializes the Collator object.
20-
If the eod document token ID and the sub_seq_lengths_key are provided,
20+
If the eos token ID and the sub_seq_lengths_key are provided,
2121
a list[list[int]] representing the sub-sequence lengths will be created.
2222
2323
Args:
2424
sample_key (str): The key for accessing the sample data.
2525
target_key (str): The key for accessing the target data.
2626
sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
27-
eod_token_id (int | None): The end-of-document token ID.
27+
eos_token_id (int | None): The end-of-sequence token ID.
2828
padding_token_id (int | None): The padding token ID.
2929
"""
3030
self.sample_key = sample_key
3131
self.target_key = target_key
3232
self.sub_seq_lengths_key = sub_seq_lengths_key
33-
self.eod_token_id = eod_token_id
33+
self.eos_token_id = eos_token_id
3434
self.padding_token_id = padding_token_id
3535

3636
def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
@@ -49,42 +49,42 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
4949
samples = {self.sample_key: sample_tensor[:, :-1]}
5050
targets = {self.target_key: sample_tensor[:, 1:]}
5151
if self.sub_seq_lengths_key is not None:
52-
# Determine sub sequence lengths by finding the eod tokens in each sequence in the batch.
52+
# Determine sub sequence lengths by finding the eos tokens in each sequence in the batch.
5353
sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(sample_tensor)
5454
samples[self.sub_seq_lengths_key] = sub_seq_lengths
5555
return DatasetBatch(targets=targets, samples=samples)
5656

5757
def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
5858
sub_seq_lengths = []
5959
for seq in sample_tensor:
60-
eod_positions = (seq == self.eod_token_id).nonzero(as_tuple=True)[0]
61-
if len(eod_positions) == 0:
60+
eos_positions = (seq == self.eos_token_id).nonzero(as_tuple=True)[0]
61+
if len(eos_positions) == 0:
6262
assert (
6363
self.padding_token_id is None or seq[0] != self.padding_token_id
6464
), "Sequence starts with padding token"
6565
sub_seq_lengths.append([len(seq)])
6666
else:
67-
subseq_lengths = self._compute_subsequence_length(seq, eod_positions)
67+
subseq_lengths = self._compute_subsequence_length(seq, eos_positions)
6868
sub_seq_lengths.append(subseq_lengths)
6969
return sub_seq_lengths
7070

71-
def _compute_subsequence_length(self, seq: torch.Tensor, eod_positions: torch.Tensor) -> list[int]:
72-
# If the last sequence is cut, i.e. does not end on an eod token,
71+
def _compute_subsequence_length(self, seq: torch.Tensor, eos_positions: torch.Tensor) -> list[int]:
72+
# If the last sequence is cut, i.e. does not end on an eos token,
7373
# it should also be included unless the padding token is set and
7474
# the last sequence is just padding.
75-
last_eod_pos = eod_positions[-1].item()
76-
if self._has_cutoff_final_sequence(seq, last_eod_pos):
77-
eod_positions = torch.cat([eod_positions, torch.tensor([len(seq) - 1])])
75+
last_eos_pos = eos_positions[-1].item()
76+
if self._has_cutoff_final_sequence(seq, last_eos_pos):
77+
eos_positions = torch.cat([eos_positions, torch.tensor([len(seq) - 1])])
7878
# Compute length of each subsequence and add to lengths list.
7979
subseq_lengths = []
8080
prev_pos = 0
81-
for pos in eod_positions:
81+
for pos in eos_positions:
8282
subseq_lengths.append(pos.item() - prev_pos + 1)
8383
prev_pos = pos.item() + 1
8484
return subseq_lengths
8585

86-
def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eod_pos: int) -> bool:
86+
def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eos_pos: int) -> bool:
8787
# Assumption: If the first token of the last sequence is padding, so is the rest.
88-
return last_eod_pos < len(seq) - 1 and (
89-
self.padding_token_id is None or seq[last_eod_pos + 1] != self.padding_token_id
88+
return last_eos_pos < len(seq) - 1 and (
89+
self.padding_token_id is None or seq[last_eos_pos + 1] != self.padding_token_id
9090
)

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ class GPT2LLMConfig(BaseModel):
373373
use_weight_tying: bool
374374
seed: Optional[int] = None
375375
enforce_swiglu_hidden_dim_multiple_of: int = 256
376+
sub_seq_lengths_key: str | None = None
376377

377378
@model_validator(mode="after")
378379
def check_divisibility(self) -> "GPT2LLMConfig":

tests/models/test_gpt2_collator.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ def test_gpt2_collate_shifts_samples_and_targets():
1717
assert result.targets["labels"].tolist() == [[2, 3, 4], [6, 7, 8]]
1818

1919

20-
def test_gpt2_collate_sub_seq_lengths_without_eod():
20+
def test_gpt2_collate_sub_seq_lengths_without_eos():
2121
collator = GPT2LLMCollateFn(
2222
sample_key="input_ids",
2323
target_key="labels",
2424
sub_seq_lengths_key="sub_seq_lengths",
25-
eod_token_id=99,
25+
eos_token_id=99,
2626
)
2727
batch = [
2828
{"input_ids": torch.tensor([10, 11, 12, 13, 14])},
@@ -34,12 +34,12 @@ def test_gpt2_collate_sub_seq_lengths_without_eod():
3434
assert result.samples["sub_seq_lengths"] == [[5], [5]]
3535

3636

37-
def test_gpt2_collate_sub_seq_lengths_with_eod():
37+
def test_gpt2_collate_sub_seq_lengths_with_eos():
3838
collator = GPT2LLMCollateFn(
3939
sample_key="input_ids",
4040
target_key="labels",
4141
sub_seq_lengths_key="sub_seq_lengths",
42-
eod_token_id=99,
42+
eos_token_id=99,
4343
)
4444
batch = [
4545
{"input_ids": torch.tensor([1, 99, 2, 3, 99])},
@@ -51,12 +51,12 @@ def test_gpt2_collate_sub_seq_lengths_with_eod():
5151
assert result.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
5252

5353

54-
def test_gpt2_collate_sub_seq_lengths_with_eod_and_padding():
54+
def test_gpt2_collate_sub_seq_lengths_with_eos_and_padding():
5555
collator = GPT2LLMCollateFn(
5656
sample_key="input_ids",
5757
target_key="labels",
5858
sub_seq_lengths_key="sub_seq_lengths",
59-
eod_token_id=99,
59+
eos_token_id=99,
6060
padding_token_id=0,
6161
)
6262
batch = [
@@ -74,7 +74,7 @@ def test_gpt2_collate_sub_seq_lengths_adds_tail_when_not_padding():
7474
sample_key="input_ids",
7575
target_key="labels",
7676
sub_seq_lengths_key="sub_seq_lengths",
77-
eod_token_id=5,
77+
eos_token_id=5,
7878
padding_token_id=0,
7979
)
8080
batch = [{"input_ids": torch.tensor([1, 5, 9, 8])}]
@@ -84,12 +84,12 @@ def test_gpt2_collate_sub_seq_lengths_adds_tail_when_not_padding():
8484
assert result.samples["sub_seq_lengths"] == [[2, 2]]
8585

8686

87-
def test_gpt2_collate_raises_when_sequence_starts_with_padding_and_no_eod():
87+
def test_gpt2_collate_raises_when_sequence_starts_with_padding_and_no_eos():
8888
collator = GPT2LLMCollateFn(
8989
sample_key="input_ids",
9090
target_key="labels",
9191
sub_seq_lengths_key="sub_seq_lengths",
92-
eod_token_id=99,
92+
eos_token_id=99,
9393
padding_token_id=0,
9494
)
9595
batch = [{"input_ids": torch.tensor([0, 1, 2, 3])}]

0 commit comments

Comments
 (0)