@@ -211,6 +211,85 @@ def __init__(
211211
212212 self .reset_parameters ()
213213
214+ def reset_parameters (self ):
215+ # If previously initialized on or moved to a device, reuse that device.
216+ # Otherwise, use the default device of the current environment.
217+ device = self .inv_freq .device if hasattr (self , "inv_freq" ) and isinstance (self .inv_freq , torch .Tensor ) else None
218+
219+ rope_type = self .rope_scaling .rope_type if self .rope_scaling is not None else "default"
220+
221+ if rope_type == "yarn" :
222+ inv_freq , self .attention_scaling = self ._compute_yarn_parameters (device = device )
223+ else :
224+ inv_freq = 1.0 / (
225+ self .base_freq ** (torch .arange (0 , self .dim_model , 2 , device = device ).float () / self .dim_model )
226+ )
227+ self .attention_scaling = 1.0
228+
229+ self .register_buffer ("inv_freq" , inv_freq )
230+
231+ self ._seq_len_cached = None
232+ self ._cos_cached = None
233+ self ._sin_cached = None
234+
235+ def rotate_half (self , x : torch .Tensor ):
236+ """
237+ Rearrange tensor elements.
238+
239+ Args:
240+ x (torch.Tensor): The input tensor.
241+
242+ Returns:
243+ torch.Tensor: The output tensor.
244+
245+ """
246+ x1 , x2 = x .chunk (2 , dim = - 1 )
247+ return torch .cat ((- x2 , x1 ), dim = - 1 )
248+
249+ def apply_rotary_pos_emb (self , x , cos , sin ):
250+ """
251+ Applies rotary positional embedding to the input tensor.
252+
253+ Args:
254+ x (torch.Tensor): Input tensor.
255+ cos (torch.Tensor): Cosine values for rotary positional embedding.
256+ sin (torch.Tensor): Sine values for rotary positional embedding.
257+
258+ Returns:
259+ torch.Tensor: Tensor after applying rotary positional embedding.
260+ """
261+ # NOTE: This could probably be moved to Triton
262+
263+ # Handle a possible sequence length mismatch in between q and k
264+ cos = cos [:, :, : x .shape [self .seq_length_dim ], :]
265+ sin = sin [:, :, : x .shape [self .seq_length_dim ], :]
266+
267+ # the rotation is not really a rotation in higher dimensions,
268+ # It merely swaps and negates certain dimensions to make
269+ # the rotation below work
270+ return (x * cos ) + (self .rotate_half (x ) * sin )
271+
272+ def forward (
273+ self , q : torch .Tensor , k : torch .Tensor , v : torch .Tensor
274+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
275+ """
276+ Forward pass of the RotaryTransform module.
277+
278+ Args:
279+ q (torch.Tensor): Query tensor.
280+ k (torch.Tensor): Key tensor.
281+ v (torch.Tensor): Value tensor.
282+
283+ Returns:
284+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
285+ Tuple containing the modified query tensor, key tensor, and value tensor.
286+ """
287+ self ._cos_cached , self ._sin_cached = self ._update_cos_sin_tables (k )
288+ q = self .apply_rotary_pos_emb (q , self ._cos_cached , self ._sin_cached )
289+ k = self .apply_rotary_pos_emb (k , self ._cos_cached , self ._sin_cached )
290+
291+ return q , k , v
292+
214293 def _compute_yarn_parameters (self , device : torch .device | None ) -> tuple [torch .Tensor , float ]:
215294 """Compute YaRN inverse frequencies and the attention scaling factor."""
216295 if not isinstance (self .rope_scaling , YarnRopeScalingConfig ):
@@ -299,41 +378,6 @@ def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Te
299378
300379 return inv_freq , float (attention_factor )
301380
302- def reset_parameters (self ):
303- # If previously initialized on or moved to a device, reuse that device.
304- # Otherwise, use the default device of the current environment.
305- device = self .inv_freq .device if hasattr (self , "inv_freq" ) and isinstance (self .inv_freq , torch .Tensor ) else None
306-
307- rope_type = self .rope_scaling .rope_type if self .rope_scaling is not None else "default"
308-
309- if rope_type == "yarn" :
310- inv_freq , self .attention_scaling = self ._compute_yarn_parameters (device = device )
311- else :
312- inv_freq = 1.0 / (
313- self .base_freq ** (torch .arange (0 , self .dim_model , 2 , device = device ).float () / self .dim_model )
314- )
315- self .attention_scaling = 1.0
316-
317- self .register_buffer ("inv_freq" , inv_freq )
318-
319- self ._seq_len_cached = None
320- self ._cos_cached = None
321- self ._sin_cached = None
322-
323- def rotate_half (self , x : torch .Tensor ):
324- """
325- Rearrange tensor elements.
326-
327- Args:
328- x (torch.Tensor): The input tensor.
329-
330- Returns:
331- torch.Tensor: The output tensor.
332-
333- """
334- x1 , x2 = x .chunk (2 , dim = - 1 )
335- return torch .cat ((- x2 , x1 ), dim = - 1 )
336-
337381 def _update_cos_sin_tables (self , x ):
338382 # Update the cosine and sine tables.
339383 seq_len = x .shape [self .seq_length_dim ]
@@ -358,50 +402,6 @@ def _update_cos_sin_tables(self, x):
358402
359403 return self ._cos_cached , self ._sin_cached
360404
361- def apply_rotary_pos_emb (self , x , cos , sin ):
362- """
363- Applies rotary positional embedding to the input tensor.
364-
365- Args:
366- x (torch.Tensor): Input tensor.
367- cos (torch.Tensor): Cosine values for rotary positional embedding.
368- sin (torch.Tensor): Sine values for rotary positional embedding.
369-
370- Returns:
371- torch.Tensor: Tensor after applying rotary positional embedding.
372- """
373- # NOTE: This could probably be moved to Triton
374-
375- # Handle a possible sequence length mismatch in between q and k
376- cos = cos [:, :, : x .shape [self .seq_length_dim ], :]
377- sin = sin [:, :, : x .shape [self .seq_length_dim ], :]
378-
379- # the rotation is not really a rotation in higher dimensions,
380- # It merely swaps and negates certain dimensions to make
381- # the rotation below work
382- return (x * cos ) + (self .rotate_half (x ) * sin )
383-
384- def forward (
385- self , q : torch .Tensor , k : torch .Tensor , v : torch .Tensor
386- ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
387- """
388- Forward pass of the RotaryTransform module.
389-
390- Args:
391- q (torch.Tensor): Query tensor.
392- k (torch.Tensor): Key tensor.
393- v (torch.Tensor): Value tensor.
394-
395- Returns:
396- tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
397- Tuple containing the modified query tensor, key tensor, and value tensor.
398- """
399- self ._cos_cached , self ._sin_cached = self ._update_cos_sin_tables (k )
400- q = self .apply_rotary_pos_emb (q , self ._cos_cached , self ._sin_cached )
401- k = self .apply_rotary_pos_emb (k , self ._cos_cached , self ._sin_cached )
402-
403- return q , k , v
404-
405405
406406class QueryKeyValueTransformType (Enum ):
407407 """
0 commit comments