2020from modalities .util import parse_enum_by_name
2121
2222try :
23- from flash_attn import flash_attn_func
23+ from flash_attn import flash_attn_func , flash_attn_varlen_func
2424except ModuleNotFoundError :
2525 flash_attn_func = None
26+ flash_attn_varlen_func = None
2627
2728# Logger configuration
2829logger = logging .getLogger (__name__ )
@@ -501,6 +502,178 @@ def __init__(
501502 self .q_norm = None
502503 self .k_norm = None
503504
505+ def prepare_inter_document_masking (
506+ self , in_batch_seq_lens : list [list [int ]], max_seq_len : int
507+ ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor , int ]:
508+ """
509+ Prepares the inter-document attention mask based on the input batch sequence lengths.
510+ For manual attention, a 3D attention mask of shape (batch_size, total_seq_len, total_seq_len) is returned.
511+ For flash attention, the cu_seqlens are computed and returned along with the indices
512+ of valid tokens and the maximum sequence length in the batch.
513+ For sdp attention, an exception is raised for now.
514+
515+ Args:
516+ in_batch_seq_lens (list[list[int]]): A list of lists containing the sequence
517+ lengths for each document in the batch.
518+ max_seq_len (int): The maximum sequence length in the batch.
519+
520+ Returns:
521+ torch.Tensor | tuple[torch.Tensor, torch.Tensor, int]: The inter-document masking information.
522+ """
523+ device = self .c_proj .weight .device
524+ if self .attention_impl == AttentionImplementation .MANUAL :
525+ batch_size = len (in_batch_seq_lens )
526+ attn_mask = torch .zeros ((batch_size , max_seq_len , max_seq_len ), dtype = torch .bool , device = device )
527+ for i , doc_seq_lens in enumerate (in_batch_seq_lens ):
528+ doc_boundaries = torch .cumsum (torch .tensor ([0 ] + doc_seq_lens , device = device ), dim = 0 )
529+ for j in range (len (doc_boundaries ) - 1 ):
530+ start_idx = doc_boundaries [j ]
531+ end_idx = doc_boundaries [j + 1 ]
532+ attn_mask [i , start_idx :end_idx , start_idx :end_idx ] = True
533+ return attn_mask
534+ if self .attention_impl == AttentionImplementation .DAO_FLASH :
535+ concatenated_lengths = self ._build_concatenated_lengths_tensor (
536+ in_batch_seq_lens = in_batch_seq_lens ,
537+ max_seq_len = max_seq_len ,
538+ device = device ,
539+ )
540+ return self ._get_unpad_data_for_concatenated_sequences (concatenated_lengths )
541+ if self .attention_impl == AttentionImplementation .PYTORCH_FLASH :
542+ raise NotImplementedError (
543+ "Inter-document masking is not supported for `pytorch_flash`. " "Use `manual` or `dao_flash`."
544+ )
545+ raise NotImplementedError (
546+ f"Attention implementation { self .attention_impl } is not supported for inter-document masking."
547+ )
548+
549+ @staticmethod
550+ def _build_concatenated_lengths_tensor (
551+ in_batch_seq_lens : list [list [int ]], max_seq_len : int , device : torch .device
552+ ) -> torch .Tensor :
553+ """
554+ Build a tensor of concatenated subsequence lengths for each batch item.
555+ Args:
556+ in_batch_seq_lens: A list of per-batch lists, where each inner list contains
557+ the lengths of subsequences for that batch item.
558+ max_seq_len: The maximum allowed sequence length (number of subsequences and
559+ total length constraints are validated against this value).
560+ device: The torch device on which to allocate the output tensor.
561+ Returns:
562+ A tensor of shape (batch_size, max_seq_len) containing the subsequence lengths
563+ for each batch item, padded with zeros beyond the number of subsequences.
564+ Raises:
565+ ValueError: If a batch item has more subsequences than max_seq_len or if the
566+ sum of its subsequence lengths exceeds max_seq_len.
567+ """
568+ batch_size = len (in_batch_seq_lens )
569+ concatenated_lengths = torch .zeros ((batch_size , max_seq_len ), dtype = torch .int32 , device = device )
570+ for batch_idx , doc_seq_lens in enumerate (in_batch_seq_lens ):
571+ if len (doc_seq_lens ) > max_seq_len :
572+ raise ValueError (
573+ f"Number of subsequences ({ len (doc_seq_lens )} ) exceeds max_seq_len ({ max_seq_len } ) "
574+ f"for batch index { batch_idx } ."
575+ )
576+ if sum (doc_seq_lens ) > max_seq_len :
577+ raise ValueError (
578+ f"Sum of subsequence lengths ({ sum (doc_seq_lens )} ) exceeds max_seq_len ({ max_seq_len } ) "
579+ f"for batch index { batch_idx } ."
580+ )
581+ if len (doc_seq_lens ) > 0 :
582+ concatenated_lengths [batch_idx , : len (doc_seq_lens )] = torch .tensor (
583+ doc_seq_lens , dtype = torch .int32 , device = device
584+ )
585+ return concatenated_lengths
586+
587+ @staticmethod
588+ def _get_unpad_data_for_concatenated_sequences (
589+ attention_mask_in_length : torch .Tensor ,
590+ ) -> tuple [torch .Tensor , torch .Tensor , int ]:
591+ """
592+ Compute unpadded indices and cumulative sequence lengths for concatenated sequences.
593+ Given a batch of per-subsequence lengths in `attention_mask_in_length`, this
594+ builds a boolean mask over the maximum sequence length, extracts flattened
595+ indices of valid (unpadded) tokens, and returns cumulative sequence lengths
596+ (CU) along with the maximum subsequence length in the batch.
597+ Args:
598+ attention_mask_in_length (torch.Tensor): Tensor of shape (num_subsequences,)
599+ containing the lengths of each subsequence in the concatenated batch.
600+ Returns:
601+ tuple[torch.Tensor, torch.Tensor, int]:
602+ - indices: 1D tensor of flattened indices for all valid (unpadded) tokens.
603+ - cu_seqlens: 1D int32 tensor of cumulative sequence lengths with a
604+ leading zero (shape: num_subsequences + 1).
605+ - max_seqlen_in_batch: Maximum subsequence length as an int.
606+ Raises:
607+ ValueError: If no subsequence lengths are provided (all zeros).
608+ """
609+
610+ length = attention_mask_in_length .sum (dim = - 1 )
611+ seqlen = attention_mask_in_length .size (- 1 )
612+ attention_mask_2d = torch .arange (seqlen , device = length .device , dtype = length .dtype ).expand (
613+ len (length ), seqlen
614+ ) < length .unsqueeze (1 )
615+ seqlens_in_batch = attention_mask_in_length [attention_mask_in_length > 0 ]
616+ if seqlens_in_batch .numel () == 0 :
617+ raise ValueError ("No subsequence lengths provided for inter-document masking." )
618+ indices = torch .nonzero (attention_mask_2d .flatten (), as_tuple = False ).flatten ()
619+ max_seqlen_in_batch = int (seqlens_in_batch .max ().item ())
620+ cu_seqlens = torch .nn .functional .pad (torch .cumsum (seqlens_in_batch , dim = 0 , dtype = torch .int32 ), (1 , 0 ))
621+ return indices , cu_seqlens , max_seqlen_in_batch
622+
623+ @classmethod
624+ def _execute_dao_flash_with_inter_document_masking (
625+ cls ,
626+ q : torch .Tensor ,
627+ k : torch .Tensor ,
628+ v : torch .Tensor ,
629+ dropout : float ,
630+ attention_masking_information : tuple [torch .Tensor , torch .Tensor , int ],
631+ ) -> torch .Tensor :
632+ if flash_attn_varlen_func is None :
633+ raise NotImplementedError (
634+ "ERROR! Dao Flash Attention varlen kernel is not available. " "Install flash-attn with varlen support."
635+ )
636+
637+ indices , cu_seqlens , max_seqlen = attention_masking_information
638+
639+ q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
640+ k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
641+ v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
642+
643+ batch_size , seq_len , n_head_q , head_dim = q .shape
644+ n_head_kv = k .shape [2 ]
645+
646+ q_flat = q .reshape (batch_size * seq_len , n_head_q , head_dim )
647+ k_flat = k .reshape (batch_size * seq_len , n_head_kv , head_dim )
648+ v_flat = v .reshape (batch_size * seq_len , n_head_kv , head_dim )
649+
650+ q_unpad = q_flat .index_select (0 , indices )
651+ k_unpad = k_flat .index_select (0 , indices )
652+ v_unpad = v_flat .index_select (0 , indices )
653+
654+ y_unpad = flash_attn_varlen_func (
655+ q_unpad ,
656+ k_unpad ,
657+ v_unpad ,
658+ cu_seqlens_q = cu_seqlens ,
659+ cu_seqlens_k = cu_seqlens ,
660+ max_seqlen_q = max_seqlen ,
661+ max_seqlen_k = max_seqlen ,
662+ dropout_p = dropout ,
663+ causal = True ,
664+ softmax_scale = None ,
665+ window_size = (- 1 , - 1 ),
666+ )
667+
668+ y = torch .zeros (
669+ (batch_size * seq_len , n_head_q , head_dim ),
670+ dtype = y_unpad .dtype ,
671+ device = y_unpad .device ,
672+ )
673+ y .index_copy_ (0 , indices , y_unpad )
674+ y = y .reshape (batch_size , seq_len , n_head_q , head_dim )
675+ return y
676+
504677 def projection (self , x : torch .Tensor ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
505678 """
506679 Applies projections to the input tensor to get queries, keys, and values.
@@ -600,6 +773,7 @@ def execute_attention(
600773 v : torch .Tensor ,
601774 dropout : float ,
602775 attention_impl : AttentionImplementation ,
776+ attention_masking_information : torch .Tensor | tuple [torch .Tensor , torch .Tensor , int ] | None = None ,
603777 ) -> torch .Tensor :
604778 """
605779 Executes attention mechanism based on the specified implementation.
@@ -611,6 +785,8 @@ def execute_attention(
611785 v (torch.Tensor): The value tensor.
612786 dropout (float): The dropout rate.
613787 attention_impl (AttentionImplementation): The attention implementation to use.
788+ attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
789+ Optional tensor containing masking information for inter-document attention.
614790
615791 Returns:
616792 torch.Tensor: The output tensor.
@@ -624,7 +800,7 @@ def execute_attention(
624800 query = q ,
625801 key = k ,
626802 value = v ,
627- attn_mask = None ,
803+ attn_mask = attention_masking_information ,
628804 dropout_p = dropout ,
629805 is_causal = True ,
630806 ) # (B, nh_q, T, hd)
@@ -647,22 +823,37 @@ def execute_attention(
647823 if flash_attn_func is None :
648824 raise NotImplementedError ("ERROR! Dao Flash Attention is not installed." )
649825 # the next three lines are only needed for flash-attn from Daio Lab
650- q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
651- k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
652- v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
653- y = flash_attn_func (
654- q , k , v , dropout_p = dropout , causal = True , softmax_scale = None , window_size = (- 1 , - 1 )
655- ) # (B, T, nh_q, hd)
826+ if attention_masking_information is None :
827+ q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
828+ k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
829+ v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
830+ y = flash_attn_func (
831+ q , k , v , dropout_p = dropout , causal = True , softmax_scale = None , window_size = (- 1 , - 1 )
832+ ) # (B, T, nh_q, hd)
833+ else :
834+ y = cls ._execute_dao_flash_with_inter_document_masking (
835+ q = q ,
836+ k = k ,
837+ v = v ,
838+ dropout = dropout ,
839+ attention_masking_information = attention_masking_information ,
840+ )
656841 else :
657842 raise NotImplementedError (f"Attention implementation { attention_impl } not supported" )
658843 return y # (B, T, nh_q, hd)
659844
660- def forward (self , x : torch .Tensor ) -> torch .Tensor :
845+ def forward (
846+ self ,
847+ x : torch .Tensor ,
848+ attention_masking_information : torch .Tensor | tuple [torch .Tensor , torch .Tensor , int ] | None = None ,
849+ ) -> torch .Tensor :
661850 """
662851 Forward pass of the CausalSelfAttention module.
663852
664853 Args:
665854 x (torch.Tensor): Input tensor of shape (B, T, n_embd)
855+ attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
856+ Optional tensor containing masking information for inter-document attention.
666857
667858 Returns:
668859 torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.
@@ -675,7 +866,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
675866 if self .q_norm is not None and self .k_norm is not None :
676867 q = self .q_norm (q )
677868 k = self .k_norm (k )
678- y = CausalSelfAttention .execute_attention (q , k , v , self .dropout , self .attention_impl ) # (B, T, nh_q, hd)
869+ y = CausalSelfAttention .execute_attention (
870+ q , k , v , self .dropout , self .attention_impl , attention_masking_information
871+ ) # (B, T, nh_q, hd)
679872 y = y .reshape (B , T , - 1 ) # (B, T, n_embd), re-assemble all head outputs side by side
680873 return self .resid_dropout (self .c_proj (y )) # (B, T, n_embd), output projection
681874
@@ -798,7 +991,7 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
798991 f"but got `n_embd = { n_embd } ` and `ffn_hidden = { ffn_hidden } `."
799992 )
800993
801- def forward (self , x : torch .Tensor ) -> torch .Tensor :
994+ def forward (self , x : torch .Tensor , attention_masking_information : torch . Tensor | None = None ) -> torch .Tensor :
802995 """
803996 Forward pass of the GPT2Block.
804997
@@ -808,7 +1001,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8081001 Returns:
8091002 torch.Tensor: Output tensor.
8101003 """
811- x = x + self .attn (self .attention_norm (x ))
1004+ x = x + self .attn (self .attention_norm (x ), attention_masking_information = attention_masking_information )
8121005 x = x + self .mlp (self .ffn_norm (x ))
8131006 return x
8141007
@@ -839,6 +1032,7 @@ def __init__(
8391032 use_weight_tying : bool ,
8401033 seed : Optional [int ] = None ,
8411034 enforce_swiglu_hidden_dim_multiple_of : int = 256 ,
1035+ sub_seq_lengths_key : str | None = None ,
8421036 ):
8431037 """
8441038 Initializes the GPT2LLM object.
@@ -867,6 +1061,8 @@ def __init__(
8671061 enforce_swiglu_hidden_dim_multiple_of (int): Enforces
8681062 the hidden dimension in the SwiGLU layer to be a multiple of this value.
8691063 Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256.
1064+ sub_seq_lengths_key (str, optional): The key for sub sequence lengths to be
1065+ used for inter document masking.
8701066 """
8711067 weight_decay_groups = {
8721068 "linear" : [".attn" , ".mlp" , ".lm_head.weight" ],
@@ -876,6 +1072,7 @@ def __init__(
8761072 super ().__init__ (weight_decay_groups = weight_decay_groups , seed = seed )
8771073 self .sample_key = sample_key
8781074 self .prediction_key = prediction_key
1075+ self .sub_seq_lengths_key = sub_seq_lengths_key
8791076 self .sequence_length = sequence_length
8801077 self .n_embd = n_embd
8811078 self .n_layer = n_layer
@@ -981,16 +1178,22 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t
9811178 dict[str, torch.Tensor] | torch.Tensor: Model output.
9821179 """
9831180 if isinstance (inputs , dict ):
984- return {self .prediction_key : self .forward_impl (inputs [self .sample_key ])}
1181+ return {
1182+ self .prediction_key : self .forward_impl (
1183+ inputs [self .sample_key ], sub_seq_lengths = inputs .get (self .sub_seq_lengths_key )
1184+ )
1185+ }
9851186 else :
9861187 return self .forward_impl (inputs )
9871188
988- def forward_impl (self , inputs : torch .Tensor ) -> torch .Tensor :
1189+ def forward_impl (self , inputs : torch .Tensor , sub_seq_lengths : list [ list [ int ]] | None = None ) -> torch .Tensor :
9891190 """
9901191 Forward pass implementation of the GPT2LLM module.
9911192
9921193 Args:
9931194 inputs (torch.Tensor): A tensor containing input token ids.
1195+ sub_seq_lengths (list[list[int]], optional): The lengths of the subsequences of each sequence
1196+ in the batch. To be used for inter document masking.
9941197
9951198 Returns:
9961199 torch.Tensor: A tensor containing output logits.
@@ -1013,8 +1216,16 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
10131216 # TODO: use drop out also without absolute position embedding?
10141217 h = self .transformer .drop (h ) if hasattr (self .transformer , "drop" ) else h
10151218
1219+ # TODO: Handle this in case of pipeline parallelism.
1220+ if sub_seq_lengths is not None :
1221+ attention_masking_information = self .transformer .h ["0" ].attn .prepare_inter_document_masking (
1222+ in_batch_seq_lens = sub_seq_lengths , max_seq_len = seq_len
1223+ )
1224+ else :
1225+ attention_masking_information = None
1226+
10161227 for layer_idx in self .transformer .h :
1017- h = self .transformer .h [layer_idx ](h )
1228+ h = self .transformer .h [layer_idx ](h , attention_masking_information = attention_masking_information )
10181229 h = self .transformer .lm_head_norm (h ) if hasattr (self .transformer , "lm_head_norm" ) else h
10191230 h = self .transformer .lm_head (h ) if hasattr (self .transformer , "lm_head" ) else h
10201231 return h
@@ -1047,18 +1258,32 @@ def manual_scaled_dot_product_attention(
10471258 attn_bias = torch .zeros (
10481259 L , S , dtype = query .dtype , device = query .device
10491260 ) # device added (not part of the original code)
1261+ if attn_mask is not None and attn_mask .dim () == 3 :
1262+ attn_bias = attn_bias .unsqueeze (0 ).repeat (attn_mask .size (0 ), 1 , 1 )
10501263 if is_causal :
1051- assert attn_mask is None
10521264 temp_mask = torch .ones (L , S , dtype = torch .bool , device = query .device ).tril (diagonal = 0 ) # device added
1053- attn_bias .masked_fill_ (temp_mask .logical_not (), float ("-inf" ))
1265+ if attn_mask is None :
1266+ attn_bias .masked_fill_ (temp_mask .logical_not (), float ("-inf" ))
1267+ elif attn_mask .dtype == torch .bool :
1268+ if attn_mask .dim () == 3 :
1269+ combined_mask = temp_mask .unsqueeze (0 ) & attn_mask
1270+ else :
1271+ combined_mask = temp_mask & attn_mask
1272+ attn_bias .masked_fill_ (combined_mask .logical_not (), float ("-inf" ))
1273+ else :
1274+ if attn_mask .dim () == 3 :
1275+ temp_mask = temp_mask .unsqueeze (0 )
1276+ attn_bias .masked_fill_ (temp_mask .logical_not (), float ("-inf" ))
1277+ attn_bias += attn_mask
10541278 attn_bias .to (query .dtype )
1055-
1056- if attn_mask is not None :
1279+ elif attn_mask is not None :
10571280 if attn_mask .dtype == torch .bool :
10581281 attn_bias .masked_fill_ (attn_mask .logical_not (), float ("-inf" ))
10591282 else :
10601283 attn_bias += attn_mask
10611284 attn_weight = query @ key .transpose (- 2 , - 1 ) * scale_factor
1285+ if attn_bias .dim () == 3 :
1286+ attn_bias = attn_bias .unsqueeze (1 )
10621287 attn_weight += attn_bias
10631288 attn_weight = torch .softmax (attn_weight , dim = - 1 )
10641289 attn_weight = torch .dropout (attn_weight , dropout_p , train = True )
0 commit comments