Skip to content

Commit eba9c5b

Browse files
committed
fix(attention): NaNs when using padding + inter document masking with manual attention.
- Also applied some review comments.
1 parent a5354c4 commit eba9c5b

2 files changed

Lines changed: 26 additions & 17 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,6 @@ def _execute_dao_flash_with_inter_document_masking(
636636

637637
indices, cu_seqlens, max_seqlen = attention_masking_information
638638

639-
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
640-
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
641-
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
642-
643639
batch_size, seq_len, n_head_q, head_dim = q.shape
644640
n_head_kv = k.shape[2]
645641

@@ -822,11 +818,11 @@ def execute_attention(
822818
# Note, that the library is not required for the CPU-only tests.
823819
if flash_attn_func is None:
824820
raise NotImplementedError("ERROR! Dao Flash Attention is not installed.")
825-
# the next three lines are only needed for flash-attn from Daio Lab
821+
# the next three lines are only needed for flash-attn from Dao Lab
822+
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
823+
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
824+
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
826825
if attention_masking_information is None:
827-
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
828-
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
829-
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
830826
y = flash_attn_func(
831827
q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
832828
) # (B, T, nh_q, hd)
@@ -991,12 +987,18 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
991987
f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`."
992988
)
993989

994-
def forward(self, x: torch.Tensor, attention_masking_information: torch.Tensor | None = None) -> torch.Tensor:
990+
def forward(
991+
self,
992+
x: torch.Tensor,
993+
attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None,
994+
) -> torch.Tensor:
995995
"""
996996
Forward pass of the GPT2Block.
997997
998998
Args:
999999
x (torch.Tensor): Input tensor.
1000+
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
1001+
Attention masking information.
10001002
10011003
Returns:
10021004
torch.Tensor: Output tensor.
@@ -1140,13 +1142,14 @@ def __init__(
11401142
) # https://paperswithcode.com/method/weight-tying
11411143

11421144
@overload
1143-
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
1145+
def forward(self, inputs: dict[str, torch.Tensor | list[list[int]]]) -> dict[str, torch.Tensor]:
11441146
"""
11451147
Forward pass of the GPT2LLM module.
11461148
11471149
Args:
1148-
inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
1150+
inputs (dict[str, torch.Tensor | list[list[int]]]): A dictionary containing input tensors.
11491151
- sample_key (str): Key for the input tensor containing token ids.
1152+
- sub_seq_lengths_key (str, optional): Key for the input tensor containing subsequence lengths.
11501153
11511154
Returns:
11521155
dict[str, torch.Tensor]: A dictionary containing output tensors.
@@ -1167,12 +1170,14 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
11671170
"""
11681171
...
11691172

1170-
def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor:
1173+
def forward(
1174+
self, inputs: dict[str, torch.Tensor | list[list[int]]] | torch.Tensor
1175+
) -> dict[str, torch.Tensor] | torch.Tensor:
11711176
"""
11721177
Forward pass of the GPT2LLM module.
11731178
11741179
Args:
1175-
inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.
1180+
inputs (dict[str, torch.Tensor | list[list[int]]] | torch.Tensor): Input data.
11761181
11771182
Returns:
11781183
dict[str, torch.Tensor] | torch.Tensor: Model output.
@@ -1258,6 +1263,7 @@ def manual_scaled_dot_product_attention(
12581263
attn_bias = torch.zeros(
12591264
L, S, dtype=query.dtype, device=query.device
12601265
) # device added (not part of the original code)
1266+
fully_masked = None
12611267
if attn_mask is not None and attn_mask.dim() == 3:
12621268
attn_bias = attn_bias.unsqueeze(0).repeat(attn_mask.size(0), 1, 1)
12631269
if is_causal:
@@ -1269,22 +1275,25 @@ def manual_scaled_dot_product_attention(
12691275
combined_mask = temp_mask.unsqueeze(0) & attn_mask
12701276
else:
12711277
combined_mask = temp_mask & attn_mask
1278+
fully_masked = ~combined_mask.any(dim=-1)
12721279
attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
12731280
else:
12741281
if attn_mask.dim() == 3:
12751282
temp_mask = temp_mask.unsqueeze(0)
12761283
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
12771284
attn_bias += attn_mask
1278-
attn_bias.to(query.dtype)
12791285
elif attn_mask is not None:
12801286
if attn_mask.dtype == torch.bool:
12811287
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
12821288
else:
12831289
attn_bias += attn_mask
1290+
attn_bias = attn_bias.to(query.dtype)
12841291
attn_weight = query @ key.transpose(-2, -1) * scale_factor
12851292
if attn_bias.dim() == 3:
12861293
attn_bias = attn_bias.unsqueeze(1)
12871294
attn_weight += attn_bias
12881295
attn_weight = torch.softmax(attn_weight, dim=-1)
12891296
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
1297+
if fully_masked is not None and attn_weight.dim() == 4:
1298+
attn_weight = attn_weight.masked_fill(fully_masked.unsqueeze(1).unsqueeze(-1), 0.0)
12901299
return attn_weight @ value

tests/models/test_causal_self_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,9 @@ def test_inter_document_masking_manual_matches_dao_flash_with_masks():
549549
torch.manual_seed(0)
550550
dao_layer, manual_layer = _build_matching_dao_and_manual_attention()
551551

552-
inputs = _get_random_input_seq((1, 5, 16))
553-
dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3]], max_seq_len=5)
554-
manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3]], max_seq_len=5)
552+
inputs = _get_random_input_seq((2, 5, 16))
553+
dao_mask = dao_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3], [1, 1, 2]], max_seq_len=5)
554+
manual_mask = manual_layer.prepare_inter_document_masking(in_batch_seq_lens=[[2, 3], [1, 1, 2]], max_seq_len=5)
555555

556556
output_dao = dao_layer(inputs, attention_masking_information=dao_mask)
557557
output_manual = manual_layer(inputs, attention_masking_information=manual_mask)

0 commit comments

Comments
 (0)