Skip to content

Commit 5e039df

Browse files
committed
feat(attention): Added inter document masking for manual and flash attention.
1 parent 0596085 commit 5e039df

2 files changed

Lines changed: 827 additions & 43 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 244 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
from modalities.util import parse_enum_by_name
2121

2222
try:
23-
from flash_attn import flash_attn_func
23+
from flash_attn import flash_attn_func, flash_attn_varlen_func
2424
except ModuleNotFoundError:
2525
flash_attn_func = None
26+
flash_attn_varlen_func = None
2627

2728
# Logger configuration
2829
logger = logging.getLogger(__name__)
@@ -501,6 +502,178 @@ def __init__(
501502
self.q_norm = None
502503
self.k_norm = None
503504

505+
def prepare_inter_document_masking(
    self, in_batch_seq_lens: list[list[int]], max_seq_len: int
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepares the inter-document masking information for the configured attention implementation.

    For manual attention, a boolean mask of shape (batch_size, max_seq_len, max_seq_len) is returned.
    Entry [b, i, j] is True if positions i and j of batch item b lie in the same document.
    Causality is not encoded in this mask. Positions that are not covered by any document
    (the document lengths sum to less than max_seq_len) are only allowed to attend to themselves.
    For dao flash attention, the indices of valid tokens, the cumulative sequence lengths (cu_seqlens)
    and the maximum document length in the batch are returned.
    For pytorch flash attention, an exception is raised for now.

    Args:
        in_batch_seq_lens (list[list[int]]): A list of lists containing the sequence
            lengths for each document in the batch.
        max_seq_len (int): The maximum sequence length in the batch.

    Returns:
        torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: The inter-document masking information.

    Raises:
        NotImplementedError: If the attention implementation does not support inter-document masking.
        ValueError: For dao flash attention, if the document lengths of a batch item are invalid
            (raised by the helper methods).
    """
    device = self.c_proj.weight.device
    if self.attention_impl == AttentionImplementation.MANUAL:
        batch_size = len(in_batch_seq_lens)
        attn_mask = torch.zeros((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
        for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
            # Document boundaries are tracked as plain Python ints, so no per-document
            # device tensor has to be created and read back for slicing.
            start_idx = 0
            for doc_len in doc_seq_lens:
                end_idx = start_idx + doc_len
                attn_mask[batch_idx, start_idx:end_idx, start_idx:end_idx] = True
                start_idx = end_idx
        # A fully masked row becomes an all -inf bias row in the manual attention and
        # softmax over it yields NaN. Letting every position attend to itself keeps
        # uncovered (padding) positions finite; covered positions already have a True diagonal.
        attn_mask.diagonal(dim1=-2, dim2=-1).fill_(True)
        return attn_mask
    if self.attention_impl == AttentionImplementation.DAO_FLASH:
        concatenated_lengths = self._build_concatenated_lengths_tensor(
            in_batch_seq_lens=in_batch_seq_lens,
            max_seq_len=max_seq_len,
            device=device,
        )
        return self._get_unpad_data_for_concatenated_sequences(concatenated_lengths)
    if self.attention_impl == AttentionImplementation.PYTORCH_FLASH:
        raise NotImplementedError(
            "Inter-document masking is not supported for `pytorch_flash`. Use `manual` or `dao_flash`."
        )
    raise NotImplementedError(
        f"Attention implementation {self.attention_impl} is not supported for inter-document masking."
    )
548+
549+
@staticmethod
def _build_concatenated_lengths_tensor(
    in_batch_seq_lens: list[list[int]], max_seq_len: int, device: torch.device
) -> torch.Tensor:
    """
    Build a tensor of concatenated subsequence lengths for each batch item.

    Args:
        in_batch_seq_lens: A list of per-batch lists, where each inner list contains
            the lengths of subsequences for that batch item.
        max_seq_len: The maximum allowed sequence length (number of subsequences and
            total length constraints are validated against this value).
        device: The torch device on which to allocate the output tensor.

    Returns:
        An int32 tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
        for each batch item, padded with zeros beyond the number of subsequences.

    Raises:
        ValueError: If a batch item has more subsequences than max_seq_len, if it contains
            a negative subsequence length, or if the sum of its subsequence lengths
            exceeds max_seq_len.
    """
    batch_size = len(in_batch_seq_lens)
    # Fill the tensor on the host and move it to the target device once at the end,
    # instead of issuing one small host-to-device copy per batch item.
    concatenated_lengths = torch.zeros((batch_size, max_seq_len), dtype=torch.int32)
    for batch_idx, doc_seq_lens in enumerate(in_batch_seq_lens):
        num_docs = len(doc_seq_lens)
        if num_docs > max_seq_len:
            raise ValueError(
                f"Number of subsequences ({num_docs}) exceeds max_seq_len ({max_seq_len}) "
                f"for batch index {batch_idx}."
            )
        # A negative length could slip past the sum check below and would corrupt
        # the cumulative sequence lengths derived from this tensor.
        if any(doc_len < 0 for doc_len in doc_seq_lens):
            raise ValueError(f"Negative subsequence length found for batch index {batch_idx}.")
        total_len = sum(doc_seq_lens)
        if total_len > max_seq_len:
            raise ValueError(
                f"Sum of subsequence lengths ({total_len}) exceeds max_seq_len ({max_seq_len}) "
                f"for batch index {batch_idx}."
            )
        if num_docs > 0:
            concatenated_lengths[batch_idx, :num_docs] = torch.tensor(doc_seq_lens, dtype=torch.int32)
    return concatenated_lengths.to(device)
586+
587+
@staticmethod
def _get_unpad_data_for_concatenated_sequences(
    attention_mask_in_length: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Compute unpadded indices and cumulative sequence lengths for concatenated sequences.

    Each row of `attention_mask_in_length` holds the subsequence lengths of one batch item,
    zero-padded to the row width. The sum of a row is the number of valid tokens of that
    batch item; valid tokens are assumed to occupy the leading positions of the row.

    Args:
        attention_mask_in_length (torch.Tensor): Integer tensor of shape (batch_size, seq_len)
            containing the subsequence lengths of each batch item, padded with zeros.

    Returns:
        tuple[torch.Tensor, torch.Tensor, int]:
            - indices: 1D tensor of indices of all valid (unpadded) tokens in the
              flattened (batch_size * seq_len) layout.
            - cu_seqlens: 1D int32 tensor of cumulative sequence lengths over all non-zero
              subsequences in the batch with a leading zero (shape: num_subsequences + 1).
            - max_seqlen_in_batch: Maximum subsequence length as an int.

    Raises:
        ValueError: If the input is not 2D, if no subsequence lengths are provided (all zeros),
            if a length is negative, or if the lengths of a row sum to more than seq_len.
    """
    if attention_mask_in_length.dim() != 2:
        raise ValueError(
            "Expected a 2D tensor of shape (batch_size, seq_len) with subsequence lengths, "
            f"but got a tensor with {attention_mask_in_length.dim()} dimension(s)."
        )

    # Non-zero entries in row-major order are the subsequence lengths of the whole batch.
    seqlens_in_batch = attention_mask_in_length[attention_mask_in_length > 0]
    if seqlens_in_batch.numel() == 0:
        raise ValueError("No subsequence lengths provided for inter-document masking.")
    if bool((attention_mask_in_length < 0).any()):
        raise ValueError("Subsequence lengths must be non-negative.")

    seqlen = attention_mask_in_length.size(-1)
    tokens_per_row = attention_mask_in_length.sum(dim=-1)
    # If a row claimed more tokens than it has positions, the token mask below would be
    # truncated and `indices` would no longer be consistent with `cu_seqlens`.
    if int(tokens_per_row.max().item()) > seqlen:
        raise ValueError(f"Sum of subsequence lengths in a batch item exceeds the sequence length ({seqlen}).")

    # Position p of a row is a valid token iff p < number of valid tokens in that row.
    positions = torch.arange(seqlen, device=tokens_per_row.device, dtype=tokens_per_row.dtype)
    attention_mask_2d = positions.unsqueeze(0) < tokens_per_row.unsqueeze(1)

    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
622+
623+
@classmethod
def _execute_dao_flash_with_inter_document_masking(
    cls,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float,
    attention_masking_information: tuple[torch.Tensor, torch.Tensor, int],
) -> torch.Tensor:
    """
    Runs the variable-length Dao flash attention kernel with inter-document masking.

    The inputs are brought into a token-major layout, the valid tokens are gathered,
    attention is computed per document and the result is scattered back into a
    zero-initialized tensor, so positions that are not listed in the indices stay zero.

    Args:
        q (torch.Tensor): The query tensor, head-major, i.e. (B, nh_q, T, hd).
        k (torch.Tensor): The key tensor, head-major, i.e. (B, nh_kv, T, hd).
        v (torch.Tensor): The value tensor, head-major, i.e. (B, nh_kv, T, hd).
        dropout (float): The dropout rate passed to the kernel.
        attention_masking_information (tuple[torch.Tensor, torch.Tensor, int]): The indices of the
            valid tokens in the flattened (B * T) layout, the cumulative sequence lengths and
            the maximum sequence length.

    Returns:
        torch.Tensor: The attention output of shape (B, T, nh_q, hd).

    Raises:
        NotImplementedError: If the flash attention varlen kernel is not available.
    """
    if flash_attn_varlen_func is None:
        raise NotImplementedError(
            "ERROR! Dao Flash Attention varlen kernel is not available. " "Install flash-attn with varlen support."
        )

    token_indices, cu_seqlens, max_seqlen = attention_masking_information

    # The shape is read before transposing: the head axis precedes the time axis on the way in.
    batch_size, n_head_q, seq_len, head_dim = q.shape
    total_tokens = batch_size * seq_len

    def _gather_valid_tokens(x: torch.Tensor) -> torch.Tensor:
        # (B, nh, T, hd) -> (B, T, nh, hd) -> (B * T, nh, hd) -> (num_valid_tokens, nh, hd)
        token_major = x.transpose(1, 2).contiguous()
        flattened = token_major.reshape(total_tokens, token_major.shape[2], head_dim)
        return flattened.index_select(0, token_indices)

    attn_out_unpad = flash_attn_varlen_func(
        _gather_valid_tokens(q),
        _gather_valid_tokens(k),
        _gather_valid_tokens(v),
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=dropout,
        causal=True,
        softmax_scale=None,
        window_size=(-1, -1),
    )

    # Scatter the per-token results back to their positions in the flattened batch.
    attn_out = attn_out_unpad.new_zeros((total_tokens, n_head_q, head_dim))
    attn_out.index_copy_(0, token_indices, attn_out_unpad)
    return attn_out.reshape(batch_size, seq_len, n_head_q, head_dim)
676+
504677
def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
505678
"""
506679
Applies projections to the input tensor to get queries, keys, and values.
@@ -600,6 +773,7 @@ def execute_attention(
600773
v: torch.Tensor,
601774
dropout: float,
602775
attention_impl: AttentionImplementation,
776+
attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None,
603777
) -> torch.Tensor:
604778
"""
605779
Executes attention mechanism based on the specified implementation.
@@ -611,6 +785,8 @@ def execute_attention(
611785
v (torch.Tensor): The value tensor.
612786
dropout (float): The dropout rate.
613787
attention_impl (AttentionImplementation): The attention implementation to use.
788+
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
789+
Optional tensor containing masking information for inter-document attention.
614790
615791
Returns:
616792
torch.Tensor: The output tensor.
@@ -624,7 +800,7 @@ def execute_attention(
624800
query=q,
625801
key=k,
626802
value=v,
627-
attn_mask=None,
803+
attn_mask=attention_masking_information,
628804
dropout_p=dropout,
629805
is_causal=True,
630806
) # (B, nh_q, T, hd)
@@ -647,22 +823,37 @@ def execute_attention(
647823
if flash_attn_func is None:
648824
raise NotImplementedError("ERROR! Dao Flash Attention is not installed.")
649825
# the transposes to (B, T, nh, hd) below are only needed for flash-attn from Dao-AILab
650-
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
651-
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
652-
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
653-
y = flash_attn_func(
654-
q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
655-
) # (B, T, nh_q, hd)
826+
if attention_masking_information is None:
827+
q = q.transpose(1, 2).contiguous() # (B, T, nh_q, hd)
828+
k = k.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
829+
v = v.transpose(1, 2).contiguous() # (B, T, nh_kv, hd)
830+
y = flash_attn_func(
831+
q, k, v, dropout_p=dropout, causal=True, softmax_scale=None, window_size=(-1, -1)
832+
) # (B, T, nh_q, hd)
833+
else:
834+
y = cls._execute_dao_flash_with_inter_document_masking(
835+
q=q,
836+
k=k,
837+
v=v,
838+
dropout=dropout,
839+
attention_masking_information=attention_masking_information,
840+
)
656841
else:
657842
raise NotImplementedError(f"Attention implementation {attention_impl} not supported")
658843
return y # (B, T, nh_q, hd)
659844

660-
def forward(self, x: torch.Tensor) -> torch.Tensor:
845+
def forward(
846+
self,
847+
x: torch.Tensor,
848+
attention_masking_information: torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None = None,
849+
) -> torch.Tensor:
661850
"""
662851
Forward pass of the CausalSelfAttention module.
663852
664853
Args:
665854
x (torch.Tensor): Input tensor of shape (B, T, n_embd)
855+
attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
856+
Optional tensor containing masking information for inter-document attention.
666857
667858
Returns:
668859
torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.
@@ -675,7 +866,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
675866
if self.q_norm is not None and self.k_norm is not None:
676867
q = self.q_norm(q)
677868
k = self.k_norm(k)
678-
y = CausalSelfAttention.execute_attention(q, k, v, self.dropout, self.attention_impl) # (B, T, nh_q, hd)
869+
y = CausalSelfAttention.execute_attention(
870+
q, k, v, self.dropout, self.attention_impl, attention_masking_information
871+
) # (B, T, nh_q, hd)
679872
y = y.reshape(B, T, -1) # (B, T, n_embd), re-assemble all head outputs side by side
680873
return self.resid_dropout(self.c_proj(y)) # (B, T, n_embd), output projection
681874

@@ -798,7 +991,7 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
798991
f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`."
799992
)
800993

801-
def forward(self, x: torch.Tensor) -> torch.Tensor:
994+
def forward(self, x: torch.Tensor, attention_masking_information: torch.Tensor | None = None) -> torch.Tensor:
802995
"""
803996
Forward pass of the GPT2Block.
804997
@@ -808,7 +1001,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8081001
Returns:
8091002
torch.Tensor: Output tensor.
8101003
"""
811-
x = x + self.attn(self.attention_norm(x))
1004+
x = x + self.attn(self.attention_norm(x), attention_masking_information=attention_masking_information)
8121005
x = x + self.mlp(self.ffn_norm(x))
8131006
return x
8141007

@@ -839,6 +1032,7 @@ def __init__(
8391032
use_weight_tying: bool,
8401033
seed: Optional[int] = None,
8411034
enforce_swiglu_hidden_dim_multiple_of: int = 256,
1035+
sub_seq_lengths_key: str | None = None,
8421036
):
8431037
"""
8441038
Initializes the GPT2LLM object.
@@ -867,6 +1061,8 @@ def __init__(
8671061
enforce_swiglu_hidden_dim_multiple_of (int): Enforces
8681062
the hidden dimension in the SwiGLU layer to be a multiple of this value.
8691063
Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256.
1064+
sub_seq_lengths_key (str, optional): The key for sub sequence lengths to be
1065+
used for inter document masking.
8701066
"""
8711067
weight_decay_groups = {
8721068
"linear": [".attn", ".mlp", ".lm_head.weight"],
@@ -876,6 +1072,7 @@ def __init__(
8761072
super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
8771073
self.sample_key = sample_key
8781074
self.prediction_key = prediction_key
1075+
self.sub_seq_lengths_key = sub_seq_lengths_key
8791076
self.sequence_length = sequence_length
8801077
self.n_embd = n_embd
8811078
self.n_layer = n_layer
@@ -981,16 +1178,22 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t
9811178
dict[str, torch.Tensor] | torch.Tensor: Model output.
9821179
"""
9831180
if isinstance(inputs, dict):
984-
return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
1181+
return {
1182+
self.prediction_key: self.forward_impl(
1183+
inputs[self.sample_key], sub_seq_lengths=inputs.get(self.sub_seq_lengths_key)
1184+
)
1185+
}
9851186
else:
9861187
return self.forward_impl(inputs)
9871188

988-
def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
1189+
def forward_impl(self, inputs: torch.Tensor, sub_seq_lengths: list[list[int]] | None = None) -> torch.Tensor:
9891190
"""
9901191
Forward pass implementation of the GPT2LLM module.
9911192
9921193
Args:
9931194
inputs (torch.Tensor): A tensor containing input token ids.
1195+
sub_seq_lengths (list[list[int]], optional): The lengths of the subsequences of each sequence
1196+
in the batch. To be used for inter document masking.
9941197
9951198
Returns:
9961199
torch.Tensor: A tensor containing output logits.
@@ -1013,8 +1216,16 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
10131216
# TODO: use drop out also without absolute position embedding?
10141217
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h
10151218

1219+
# TODO: Handle this in case of pipeline parallelism.
1220+
if sub_seq_lengths is not None:
1221+
attention_masking_information = self.transformer.h["0"].attn.prepare_inter_document_masking(
1222+
in_batch_seq_lens=sub_seq_lengths, max_seq_len=seq_len
1223+
)
1224+
else:
1225+
attention_masking_information = None
1226+
10161227
for layer_idx in self.transformer.h:
1017-
h = self.transformer.h[layer_idx](h)
1228+
h = self.transformer.h[layer_idx](h, attention_masking_information=attention_masking_information)
10181229
h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h
10191230
h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h
10201231
return h
@@ -1047,18 +1258,32 @@ def manual_scaled_dot_product_attention(
10471258
attn_bias = torch.zeros(
10481259
L, S, dtype=query.dtype, device=query.device
10491260
) # device added (not part of the original code)
1261+
if attn_mask is not None and attn_mask.dim() == 3:
1262+
attn_bias = attn_bias.unsqueeze(0).repeat(attn_mask.size(0), 1, 1)
10501263
if is_causal:
1051-
assert attn_mask is None
10521264
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) # device added
1053-
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
1265+
if attn_mask is None:
1266+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
1267+
elif attn_mask.dtype == torch.bool:
1268+
if attn_mask.dim() == 3:
1269+
combined_mask = temp_mask.unsqueeze(0) & attn_mask
1270+
else:
1271+
combined_mask = temp_mask & attn_mask
1272+
attn_bias.masked_fill_(combined_mask.logical_not(), float("-inf"))
1273+
else:
1274+
if attn_mask.dim() == 3:
1275+
temp_mask = temp_mask.unsqueeze(0)
1276+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
1277+
attn_bias += attn_mask
10541278
attn_bias.to(query.dtype)
1055-
1056-
if attn_mask is not None:
1279+
elif attn_mask is not None:
10571280
if attn_mask.dtype == torch.bool:
10581281
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
10591282
else:
10601283
attn_bias += attn_mask
10611284
attn_weight = query @ key.transpose(-2, -1) * scale_factor
1285+
if attn_bias.dim() == 3:
1286+
attn_bias = attn_bias.unsqueeze(1)
10621287
attn_weight += attn_bias
10631288
attn_weight = torch.softmax(attn_weight, dim=-1)
10641289
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

0 commit comments

Comments
 (0)