22import math
33from abc import abstractmethod
44from enum import Enum
5- from typing import Annotated , Optional , overload
5+ from numbers import Real
6+ from typing import Annotated , Literal , Optional , overload
67
78import torch
89import torch .nn as nn
9- from pydantic import BaseModel , Field , model_validator , validator
10+ from pydantic import BaseModel , Field , field_validator , model_validator , validator
1011
1112from modalities .config .lookup_enum import LookupEnum
1213from modalities .config .utils import convert_base_model_config_to_dict
3132# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
3233
3334
35+ def _validate_numeric_field (field_name : str , value : object ) -> float :
36+ """Validate that a value is a real number (excluding bool) and cast to float."""
37+ if isinstance (value , bool ) or not isinstance (value , Real ):
38+ raise ValueError (f"rope_scaling.{ field_name } must be a float" )
39+ return float (value )
40+
41+
class DefaultRopeScalingConfig(BaseModel):
    """RoPE scaling configuration selecting the default (unscaled) behavior.

    The only field is the ``rope_type`` tag, fixed to ``"default"``.
    """

    # Tag value identifying this configuration variant.
    rope_type: Literal["default"] = "default"
46+
47+
class YarnRopeScalingConfig(BaseModel):
    """Configuration for YaRN RoPE scaling.

    Numeric fields are checked twice: a ``before`` validator rejects anything
    that is not a real number (including ``bool``), then the per-field
    ``Field`` bounds are applied. ``mscale`` and ``mscale_all_dim`` must be
    given together or not at all.
    """

    # Tag value identifying this configuration variant.
    rope_type: Literal["yarn"] = "yarn"
    # Context length used as the reference for the correction range.
    original_max_position_embeddings: Annotated[int, Field(strict=True, ge=1)]
    # Scaling factor; when None it is derived by the consumer as
    # max_position_embeddings / original_max_position_embeddings.
    factor: Optional[Annotated[float, Field(ge=1.0)]] = None
    # Explicit cos/sin scaling; when None it is derived from the mscale pair or from the factor.
    attention_factor: Optional[Annotated[float, Field(gt=0.0)]] = None
    mscale: Optional[Annotated[float, Field(ge=0.0)]] = None
    mscale_all_dim: Optional[Annotated[float, Field(ge=0.0)]] = None
    # Rotation counts bounding the correction range. They must be strictly
    # positive: the correction-dimension formula divides by
    # ``num_rotations * 2 * pi``, so a value of 0 would raise ZeroDivisionError
    # when the transform is built instead of failing here with a clear message.
    beta_fast: Annotated[float, Field(gt=0.0)] = 32.0
    beta_slow: Annotated[float, Field(gt=0.0)] = 1.0
    # When True, the correction bounds are rounded outwards (floor / ceil).
    truncate: bool = True

    @field_validator(
        "factor",
        "attention_factor",
        "mscale",
        "mscale_all_dim",
        "beta_fast",
        "beta_slow",
        mode="before",
    )
    @classmethod
    def validate_numeric_fields(cls, value: object, info):
        """Reject non-real inputs (and bools) before type coercion.

        ``None`` is passed through untouched so optional fields keep their
        "not provided" meaning.

        Raises:
            ValueError: If ``value`` is not ``None`` and is a ``bool`` or not a real number.
        """
        if value is None:
            return value
        return _validate_numeric_field(info.field_name, value)

    @model_validator(mode="after")
    def validate_mscale_pair(self) -> "YarnRopeScalingConfig":
        """Ensure ``mscale`` and ``mscale_all_dim`` are both set or both unset.

        Raises:
            ValueError: If exactly one of the two fields is provided.
        """
        if (self.mscale is None) != (self.mscale_all_dim is None):
            raise ValueError("rope_scaling.mscale and rope_scaling.mscale_all_dim must be provided together")
        return self
81+
82+
# Union of the supported RoPE scaling configs, discriminated by their ``rope_type`` field.
RopeScalingConfig = Annotated[DefaultRopeScalingConfig | YarnRopeScalingConfig, Field(discriminator="rope_type")]
87+
88+
3489class LayerNorms (LookupEnum ):
3590 """
3691 Enum lookup class for LayerNorms.
@@ -120,7 +175,15 @@ class RotaryTransform(QueryKeyValueTransform):
120175 XFormers implementation and removed in this implementation.#
121176 """
122177
123- def __init__ (self , n_embd : int , n_head : int , seq_length_dim : int = - 2 , base_freq : int = 10000 ):
178+ def __init__ (
179+ self ,
180+ n_embd : int ,
181+ n_head : int ,
182+ seq_length_dim : int = - 2 ,
183+ base_freq : int = 10000 ,
184+ max_position_embeddings : int | None = None ,
185+ rope_scaling : RopeScalingConfig | None = None ,
186+ ):
124187 """
125188 Initializes the RotaryTransform object.
126189
@@ -136,16 +199,33 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
136199 self .dim_model = n_embd // n_head
137200 self .seq_length_dim = seq_length_dim
138201 self .base_freq = base_freq
202+ self .max_position_embeddings = max_position_embeddings
203+
204+ if rope_scaling is not None and not isinstance (rope_scaling , (DefaultRopeScalingConfig , YarnRopeScalingConfig )):
205+ raise TypeError (
206+ "rope_scaling must be an instance of DefaultRopeScalingConfig, YarnRopeScalingConfig, or None"
207+ )
208+
209+ self .rope_scaling = rope_scaling
210+ self .attention_scaling = 1.0
139211
140212 self .reset_parameters ()
141213
142214 def reset_parameters (self ):
143215 # If previously initialized on or moved to a device, reuse that device.
144216 # Otherwise, use the default device of the current environment.
145- device = self .inv_freq .device if hasattr (self , "inv_freq" ) else None
146- inv_freq = 1.0 / (
147- self .base_freq ** (torch .arange (0 , self .dim_model , 2 , device = device ).float () / self .dim_model )
148- )
217+ device = self .inv_freq .device if hasattr (self , "inv_freq" ) and isinstance (self .inv_freq , torch .Tensor ) else None
218+
219+ rope_type = self .rope_scaling .rope_type if self .rope_scaling is not None else "default"
220+
221+ if rope_type == "yarn" :
222+ inv_freq , self .attention_scaling = self ._compute_yarn_parameters (device = device )
223+ else :
224+ inv_freq = 1.0 / (
225+ self .base_freq ** (torch .arange (0 , self .dim_model , 2 , device = device ).float () / self .dim_model )
226+ )
227+ self .attention_scaling = 1.0
228+
149229 self .register_buffer ("inv_freq" , inv_freq )
150230
151231 self ._seq_len_cached = None
@@ -166,24 +246,6 @@ def rotate_half(self, x: torch.Tensor):
166246 x1 , x2 = x .chunk (2 , dim = - 1 )
167247 return torch .cat ((- x2 , x1 ), dim = - 1 )
168248
169- def _update_cos_sin_tables (self , x ):
170- # Update the cosine and sine tables.
171- seq_len = x .shape [self .seq_length_dim ]
172-
173- # Reset the tables if the sequence length has changed,
174- # or if we're on a new device (possibly due to tracing for instance)
175- if seq_len != self ._seq_len_cached or self ._cos_cached .device != x .device or self ._cos_cached .dtype != x .dtype :
176- self ._seq_len_cached = seq_len
177- t = torch .arange (x .shape [self .seq_length_dim ], device = x .device , dtype = torch .float32 )
178- freqs = torch .einsum ("i,j->ij" , t , self .inv_freq .to (x .dtype ))
179- emb = torch .cat ((freqs , freqs ), dim = - 1 ).to (
180- x .device
181- ) # here, we combine the two matrices (not zipping them).
182- self ._cos_cached = emb .cos ()[None , None , :, :].to (x .dtype )
183- self ._sin_cached = emb .sin ()[None , None , :, :].to (x .dtype )
184-
185- return self ._cos_cached , self ._sin_cached
186-
187249 def apply_rotary_pos_emb (self , x , cos , sin ):
188250 """
189251 Applies rotary positional embedding to the input tensor.
@@ -228,6 +290,118 @@ def forward(
228290
229291 return q , k , v
230292
def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
    """Build the YaRN inverse-frequency tensor and the attention scaling factor.

    Args:
        device: Device on which the frequency tensors are created (``None`` uses the default).

    Returns:
        A tuple ``(inv_freq, attention_factor)``. ``inv_freq`` has ``dim_model // 2`` entries and
        blends interpolated and extrapolated inverse frequencies; ``attention_factor`` is a float.

    Raises:
        ValueError: If ``self.rope_scaling`` is not a ``YarnRopeScalingConfig`` or
            ``self.max_position_embeddings`` is ``None``.
    """
    cfg = self.rope_scaling
    if not isinstance(cfg, YarnRopeScalingConfig):
        raise ValueError("YaRN requires a rope_scaling config.")
    if self.max_position_embeddings is None:
        raise ValueError("YaRN requires max_position_embeddings to be set.")

    original_len = cfg.original_max_position_embeddings

    # Without an explicit factor, use the ratio of target to original context length.
    raw_factor = cfg.factor
    if raw_factor is None:
        raw_factor = self.max_position_embeddings / original_len
    scale = float(raw_factor)

    def mscale_coefficient(scale_factor: float, mscale: float = 1.0) -> float:
        """Return the YaRN mscale coefficient for a given scaling factor."""
        return 1.0 if scale_factor <= 1 else 0.1 * mscale * math.log(scale_factor) + 1.0

    attention_factor = cfg.attention_factor
    if attention_factor is None:
        if cfg.mscale is not None and cfg.mscale_all_dim is not None:
            numerator = mscale_coefficient(scale, float(cfg.mscale))
            denominator = mscale_coefficient(scale, float(cfg.mscale_all_dim))
            attention_factor = float(numerator / denominator)
        else:
            attention_factor = mscale_coefficient(scale)

    head_dim = self.dim_model
    base = self.base_freq

    def correction_dim(num_rotations: float) -> float:
        """Map a target number of rotations to a rotary dimension index."""
        return (head_dim * math.log(original_len / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    pos_freqs = base ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float) / head_dim)
    extrapolated = 1.0 / pos_freqs
    interpolated = 1.0 / (scale * pos_freqs)

    # Lower / upper rotary-dimension bounds between which the two schemes are blended.
    low = correction_dim(cfg.beta_fast)
    high = correction_dim(cfg.beta_slow)
    if bool(cfg.truncate):
        low = math.floor(low)
        high = math.ceil(high)
    low = max(low, 0)
    high = min(high, head_dim - 1)

    # Clamped linear ramp over the dimension indices; nudge the upper bound to avoid a zero-width ramp.
    if low == high:
        high += 0.001
    indices = torch.arange(head_dim // 2, dtype=torch.float32, device=device)
    ramp = torch.clamp((indices - low) / (high - low), 0, 1)

    extrapolation_weight = 1 - ramp.to(device=device, dtype=torch.float)
    inv_freq = interpolated * (1 - extrapolation_weight) + extrapolated * extrapolation_weight

    return inv_freq, float(attention_factor)
380+
381+ def _update_cos_sin_tables (self , x ):
382+ # Update the cosine and sine tables.
383+ seq_len = x .shape [self .seq_length_dim ]
384+
385+ # Reset the tables if the sequence length has changed,
386+ # or if we're on a new device (possibly due to tracing for instance)
387+ if (
388+ seq_len != self ._seq_len_cached
389+ or self ._cos_cached is None
390+ or self ._sin_cached is None
391+ or self ._cos_cached .device != x .device
392+ or self ._cos_cached .dtype != x .dtype
393+ ):
394+ self ._seq_len_cached = seq_len
395+ t = torch .arange (x .shape [self .seq_length_dim ], device = x .device , dtype = torch .float32 )
396+ freqs = torch .einsum ("i,j->ij" , t , self .inv_freq .to (x .dtype ))
397+ emb = torch .cat ((freqs , freqs ), dim = - 1 ).to (
398+ x .device
399+ ) # here, we combine the two matrices (not zipping them).
400+ self ._cos_cached = (emb .cos () * self .attention_scaling )[None , None , :, :].to (x .dtype )
401+ self ._sin_cached = (emb .sin () * self .attention_scaling )[None , None , :, :].to (x .dtype )
402+
403+ return self ._cos_cached , self ._sin_cached
404+
231405
232406class QueryKeyValueTransformType (Enum ):
233407 """
@@ -295,6 +469,15 @@ class RotaryTransformConfig(BaseModel):
295469 n_head : Annotated [int , Field (strict = True , ge = 0 )]
296470 seq_length_dim : Annotated [int , Field (strict = True )]
297471 base_freq : Annotated [int , Field (strict = True , ge = 10000 )]
472+ max_position_embeddings : Optional [Annotated [int , Field (strict = True , ge = 1 )]] = None
473+ rope_scaling : Optional [RopeScalingConfig ] = None
474+
@model_validator(mode="after")
def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
    """Check the cross-field constraint between ``rope_scaling`` and ``max_position_embeddings``.

    Raises:
        ValueError: If ``rope_scaling`` is a ``YarnRopeScalingConfig`` while
            ``max_position_embeddings`` is ``None``.
    """
    uses_yarn = isinstance(self.rope_scaling, YarnRopeScalingConfig)
    if uses_yarn and self.max_position_embeddings is None:
        raise ValueError("YaRN requires max_position_embeddings to be set")
    return self
298481
299482 @validator ("type_hint" , pre = True , always = True )
300483 def parse_sharding_strategy_by_name (cls , name ):
0 commit comments