Skip to content

Commit a5354c4

Browse files
committed
feat(data loading): GPT2LLMCollateFn can now determine the sub sequence lengths required for inter document attention masking.
1 parent 5e039df commit a5354c4

3 files changed

Lines changed: 172 additions & 1 deletion

File tree

src/modalities/config/config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,25 @@ class BatchSamplerConfig(BaseModel):
443443
class GPT2LLMCollateFnConfig(BaseModel):
    """Configuration for the GPT2 LLM collate function.

    Constraints enforced by the validators of this class:

    * ``sub_seq_lengths_key`` and ``eod_token_id`` must be provided together
      or both be omitted.
    * ``padding_token_id`` may only be provided together with
      ``sub_seq_lengths_key``.
    """

    sample_key: str
    target_key: str
    # The three optional fields below all default to None; their allowed
    # combinations are checked by the model validators.
    sub_seq_lengths_key: str | None = None
    eod_token_id: int | None = None
    padding_token_id: int | None = None

    @model_validator(mode="before")
    @classmethod
    def check_sub_seq_lengths_and_eod_token(cls, values):
        """Require sub_seq_lengths_key and eod_token_id to be set together.

        Raises:
            ValueError: If exactly one of the two fields is provided.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict; leave anything else for pydantic's own validation.
        if not isinstance(values, dict):
            return values
        sub_seq_lengths_key = values.get("sub_seq_lengths_key")
        eod_token_id = values.get("eod_token_id")
        if (sub_seq_lengths_key is None) != (eod_token_id is None):
            raise ValueError("Either both or neither of sub_seq_lengths_key and eod_token_id must be provided.")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_padding_token_and_sub_seq_lengths(cls, values):
        """Require sub_seq_lengths_key whenever padding_token_id is set.

        Raises:
            ValueError: If padding_token_id is provided without sub_seq_lengths_key.
        """
        # Same raw-input caveat as above: only inspect dict input.
        if not isinstance(values, dict):
            return values
        padding_token_id = values.get("padding_token_id")
        sub_seq_lengths_key = values.get("sub_seq_lengths_key")
        if padding_token_id is not None and sub_seq_lengths_key is None:
            raise ValueError("If padding_token_id is provided, sub_seq_lengths_key must also be provided.")
        return values
446465

447466

448467
class LLMDataLoaderConfig(BaseModel):

src/modalities/models/gpt2/collator.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,31 @@
77
class GPT2LLMCollateFn(CollateFnIF):
88
"""GPT2LLMCollateFn class to define a collate function for GPT2 language model."""
99

10-
def __init__(self, sample_key: str, target_key: str):
10+
def __init__(
11+
self,
12+
sample_key: str,
13+
target_key: str,
14+
sub_seq_lengths_key: str | None = None,
15+
eod_token_id: int | None = None,
16+
padding_token_id: int | None = None,
17+
):
1118
"""
1219
Initializes the Collator object.
20+
If the eod document token ID and the sub_seq_lengths_key are provided,
21+
a list[list[int]] representing the sub-sequence lengths will be created.
1322
1423
Args:
1524
sample_key (str): The key for accessing the sample data.
1625
target_key (str): The key for accessing the target data.
26+
sub_seq_lengths_key (str | None): The key for accessing the sub-sequence lengths.
27+
eod_token_id (int | None): The end-of-document token ID.
28+
padding_token_id (int | None): The padding token ID.
1729
"""
1830
self.sample_key = sample_key
1931
self.target_key = target_key
32+
self.sub_seq_lengths_key = sub_seq_lengths_key
33+
self.eod_token_id = eod_token_id
34+
self.padding_token_id = padding_token_id
2035

2136
def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
2237
"""
@@ -33,4 +48,43 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
3348
sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch])
3449
samples = {self.sample_key: sample_tensor[:, :-1]}
3550
targets = {self.target_key: sample_tensor[:, 1:]}
51+
if self.sub_seq_lengths_key is not None:
52+
# Determine sub sequence lengths by finding the eod tokens in each sequence in the batch.
53+
sub_seq_lengths = self._compute_sub_sequence_lengths_for_each_sequence(sample_tensor)
54+
samples[self.sub_seq_lengths_key] = sub_seq_lengths
3655
return DatasetBatch(targets=targets, samples=samples)
56+
57+
def _compute_sub_sequence_lengths_for_each_sequence(self, sample_tensor: torch.Tensor) -> list[list[int]]:
58+
sub_seq_lengths = []
59+
for seq in sample_tensor:
60+
eod_positions = (seq == self.eod_token_id).nonzero(as_tuple=True)[0]
61+
if len(eod_positions) == 0:
62+
assert (
63+
self.padding_token_id is None or seq[0] != self.padding_token_id
64+
), "Sequence starts with padding token"
65+
sub_seq_lengths.append([len(seq)])
66+
else:
67+
subseq_lengths = self._compute_subsequence_length(seq, eod_positions)
68+
sub_seq_lengths.append(subseq_lengths)
69+
return sub_seq_lengths
70+
71+
def _compute_subsequence_length(self, seq: torch.Tensor, eod_positions: torch.Tensor) -> list[int]:
72+
# If the last sequence is cut, i.e. does not end on an eod token,
73+
# it should also be included unless the padding token is set and
74+
# the last sequence is just padding.
75+
last_eod_pos = eod_positions[-1].item()
76+
if self._has_cutoff_final_sequence(seq, last_eod_pos):
77+
eod_positions = torch.cat([eod_positions, torch.tensor([len(seq) - 1])])
78+
# Compute length of each subsequence and add to lengths list.
79+
subseq_lengths = []
80+
prev_pos = 0
81+
for pos in eod_positions:
82+
subseq_lengths.append(pos.item() - prev_pos + 1)
83+
prev_pos = pos.item() + 1
84+
return subseq_lengths
85+
86+
def _has_cutoff_final_sequence(self, seq: torch.Tensor, last_eod_pos: int) -> bool:
87+
# Assumption: If the first token of the last sequence is padding, so is the rest.
88+
return last_eod_pos < len(seq) - 1 and (
89+
self.padding_token_id is None or seq[last_eod_pos + 1] != self.padding_token_id
90+
)

tests/models/test_gpt2_collator.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pytest
2+
import torch
3+
4+
from modalities.models.gpt2.collator import GPT2LLMCollateFn
5+
6+
7+
def test_gpt2_collate_shifts_samples_and_targets():
    """Samples drop the last token of each sequence, targets drop the first."""
    token_rows = [[1, 2, 3, 4], [5, 6, 7, 8]]
    collate_fn = GPT2LLMCollateFn(sample_key="input_ids", target_key="labels")

    collated = collate_fn([{"input_ids": torch.tensor(row)} for row in token_rows])

    assert collated.samples["input_ids"].tolist() == [[1, 2, 3], [5, 6, 7]]
    assert collated.targets["labels"].tolist() == [[2, 3, 4], [6, 7, 8]]
18+
19+
20+
def test_gpt2_collate_sub_seq_lengths_without_eod():
    """A sequence without any EOD token yields one sub-sequence of its full length."""
    collate_kwargs = {
        "sample_key": "input_ids",
        "target_key": "labels",
        "sub_seq_lengths_key": "sub_seq_lengths",
        "eod_token_id": 99,
    }
    collate_fn = GPT2LLMCollateFn(**collate_kwargs)
    first = {"input_ids": torch.tensor([10, 11, 12, 13, 14])}
    second = {"input_ids": torch.tensor([20, 21, 22, 23, 24])}

    collated = collate_fn([first, second])

    assert collated.samples["sub_seq_lengths"] == [[5], [5]]
35+
36+
37+
def test_gpt2_collate_sub_seq_lengths_with_eod():
    """Sequences are split after each EOD token; a trailing cut-off part is counted too."""
    collate_kwargs = {
        "sample_key": "input_ids",
        "target_key": "labels",
        "sub_seq_lengths_key": "sub_seq_lengths",
        "eod_token_id": 99,
    }
    collate_fn = GPT2LLMCollateFn(**collate_kwargs)
    ends_on_eod = {"input_ids": torch.tensor([1, 99, 2, 3, 99])}
    has_cut_tail = {"input_ids": torch.tensor([7, 8, 9, 99, 10])}

    collated = collate_fn([ends_on_eod, has_cut_tail])

    assert collated.samples["sub_seq_lengths"] == [[2, 3], [4, 1]]
52+
53+
54+
def test_gpt2_collate_sub_seq_lengths_with_eod_and_padding():
    """A tail that starts with the padding token after the last EOD token is not counted."""
    collate_kwargs = {
        "sample_key": "input_ids",
        "target_key": "labels",
        "sub_seq_lengths_key": "sub_seq_lengths",
        "eod_token_id": 99,
        "padding_token_id": 0,
    }
    collate_fn = GPT2LLMCollateFn(**collate_kwargs)
    unpadded = {"input_ids": torch.tensor([1, 99, 2, 3, 4, 5])}
    padded = {"input_ids": torch.tensor([7, 8, 99, 0, 0, 0])}

    collated = collate_fn([unpadded, padded])

    assert collated.samples["sub_seq_lengths"] == [[2, 4], [3]]
70+
71+
72+
def test_gpt2_collate_sub_seq_lengths_adds_tail_when_not_padding():
    """With a padding token configured, a non-padding tail after the last EOD is still counted."""
    collate_kwargs = {
        "sample_key": "input_ids",
        "target_key": "labels",
        "sub_seq_lengths_key": "sub_seq_lengths",
        "eod_token_id": 5,
        "padding_token_id": 0,
    }
    collate_fn = GPT2LLMCollateFn(**collate_kwargs)
    single_sequence = {"input_ids": torch.tensor([1, 5, 9, 8])}

    collated = collate_fn([single_sequence])

    assert collated.samples["sub_seq_lengths"] == [[2, 2]]
85+
86+
87+
def test_gpt2_collate_raises_when_sequence_starts_with_padding_and_no_eod():
    """A sequence without EOD tokens that begins with the padding token is rejected."""
    collate_kwargs = {
        "sample_key": "input_ids",
        "target_key": "labels",
        "sub_seq_lengths_key": "sub_seq_lengths",
        "eod_token_id": 99,
        "padding_token_id": 0,
    }
    collate_fn = GPT2LLMCollateFn(**collate_kwargs)
    starts_with_padding = [{"input_ids": torch.tensor([0, 1, 2, 3])}]

    with pytest.raises(AssertionError, match="Sequence starts with padding token"):
        collate_fn(starts_with_padding)

0 commit comments

Comments
 (0)