@@ -636,10 +636,6 @@ def _execute_dao_flash_with_inter_document_masking(
636636
637637 indices , cu_seqlens , max_seqlen = attention_masking_information
638638
639- q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
640- k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
641- v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
642-
643639 batch_size , seq_len , n_head_q , head_dim = q .shape
644640 n_head_kv = k .shape [2 ]
645641
@@ -822,11 +818,11 @@ def execute_attention(
822818 # Note that the library is not required for the CPU-only tests.
823819 if flash_attn_func is None :
824820 raise NotImplementedError ("ERROR! Dao Flash Attention is not installed." )
825- # the next three lines are only needed for flash-attn from Daio Lab
821+ # the next three lines are only needed for flash-attn from Dao Lab
822+ q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
823+ k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
824+ v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
826825 if attention_masking_information is None :
827- q = q .transpose (1 , 2 ).contiguous () # (B, T, nh_q, hd)
828- k = k .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
829- v = v .transpose (1 , 2 ).contiguous () # (B, T, nh_kv, hd)
830826 y = flash_attn_func (
831827 q , k , v , dropout_p = dropout , causal = True , softmax_scale = None , window_size = (- 1 , - 1 )
832828 ) # (B, T, nh_q, hd)
@@ -991,12 +987,18 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
991987 f"but got `n_embd = { n_embd } ` and `ffn_hidden = { ffn_hidden } `."
992988 )
993989
994- def forward (self , x : torch .Tensor , attention_masking_information : torch .Tensor | None = None ) -> torch .Tensor :
990+ def forward (
991+ self ,
992+ x : torch .Tensor ,
993+ attention_masking_information : torch .Tensor | tuple [torch .Tensor , torch .Tensor , int ] | None = None ,
994+ ) -> torch .Tensor :
995995 """
996996 Forward pass of the GPT2Block.
997997
998998 Args:
999999 x (torch.Tensor): Input tensor.
1000+ attention_masking_information (torch.Tensor | tuple[torch.Tensor, torch.Tensor, int] | None):
1001+ Attention masking information; the tuple form is unpacked downstream as (indices, cu_seqlens, max_seqlen).
10001002
10011003 Returns:
10021004 torch.Tensor: Output tensor.
@@ -1140,13 +1142,14 @@ def __init__(
11401142 ) # https://paperswithcode.com/method/weight-tying
11411143
11421144 @overload
1143- def forward (self , inputs : dict [str , torch .Tensor ]) -> dict [str , torch .Tensor ]:
1145+ def forward (self , inputs : dict [str , torch .Tensor | list [ list [ int ]] ]) -> dict [str , torch .Tensor ]:
11441146 """
11451147 Forward pass of the GPT2LLM module.
11461148
11471149 Args:
1148- inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
1150+ inputs (dict[str, torch.Tensor | list[list[int]] ]): A dictionary containing input tensors.
11491151 - sample_key (str): Key for the input tensor containing token ids.
1152+ - sub_seq_lengths_key (str, optional): Key for the entry containing the subsequence lengths (list[list[int]]).
11501153
11511154 Returns:
11521155 dict[str, torch.Tensor]: A dictionary containing output tensors.
@@ -1167,12 +1170,14 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
11671170 """
11681171 ...
11691172
1170- def forward (self , inputs : dict [str , torch .Tensor ] | torch .Tensor ) -> dict [str , torch .Tensor ] | torch .Tensor :
1173+ def forward (
1174+ self , inputs : dict [str , torch .Tensor | list [list [int ]]] | torch .Tensor
1175+ ) -> dict [str , torch .Tensor ] | torch .Tensor :
11711176 """
11721177 Forward pass of the GPT2LLM module.
11731178
11741179 Args:
1175- inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.
1180+ inputs (dict[str, torch.Tensor | list[list[int]] ] | torch.Tensor): Input data.
11761181
11771182 Returns:
11781183 dict[str, torch.Tensor] | torch.Tensor: Model output.
@@ -1258,6 +1263,7 @@ def manual_scaled_dot_product_attention(
12581263 attn_bias = torch .zeros (
12591264 L , S , dtype = query .dtype , device = query .device
12601265 ) # device added (not part of the original code)
1266+ fully_masked = None
12611267 if attn_mask is not None and attn_mask .dim () == 3 :
12621268 attn_bias = attn_bias .unsqueeze (0 ).repeat (attn_mask .size (0 ), 1 , 1 )
12631269 if is_causal :
@@ -1269,22 +1275,25 @@ def manual_scaled_dot_product_attention(
12691275 combined_mask = temp_mask .unsqueeze (0 ) & attn_mask
12701276 else :
12711277 combined_mask = temp_mask & attn_mask
1278+ fully_masked = ~ combined_mask .any (dim = - 1 )
12721279 attn_bias .masked_fill_ (combined_mask .logical_not (), float ("-inf" ))
12731280 else :
12741281 if attn_mask .dim () == 3 :
12751282 temp_mask = temp_mask .unsqueeze (0 )
12761283 attn_bias .masked_fill_ (temp_mask .logical_not (), float ("-inf" ))
12771284 attn_bias += attn_mask
1278- attn_bias .to (query .dtype )
12791285 elif attn_mask is not None :
12801286 if attn_mask .dtype == torch .bool :
12811287 attn_bias .masked_fill_ (attn_mask .logical_not (), float ("-inf" ))
12821288 else :
12831289 attn_bias += attn_mask
1290+ attn_bias = attn_bias .to (query .dtype )
12841291 attn_weight = query @ key .transpose (- 2 , - 1 ) * scale_factor
12851292 if attn_bias .dim () == 3 :
12861293 attn_bias = attn_bias .unsqueeze (1 )
12871294 attn_weight += attn_bias
12881295 attn_weight = torch .softmax (attn_weight , dim = - 1 )
12891296 attn_weight = torch .dropout (attn_weight , dropout_p , train = True )
1297+ if fully_masked is not None and attn_weight .dim () == 4 :
1298+ attn_weight = attn_weight .masked_fill (fully_masked .unsqueeze (1 ).unsqueeze (- 1 ), 0.0 )
12901299 return attn_weight @ value
0 commit comments