@@ -152,6 +152,7 @@ def __init__(
152152 self .reset_parameters ()
153153
154154 def _compute_yarn_parameters (self , device : torch .device | None ) -> tuple [torch .Tensor , float ]:
155+ """Compute YaRN inverse frequencies and the attention scaling factor."""
155156 if self .rope_scaling is None :
156157 raise ValueError ("YaRN requires a rope_scaling config." )
157158 if self .max_position_embeddings is None :
@@ -182,6 +183,7 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T
182183 truncate = self .rope_scaling .get ("truncate" , True )
183184
184185 def get_mscale (scale : float , mscale : float = 1.0 ) -> float :
186+ """Return the YaRN mscale coefficient for a given scaling factor."""
185187 if scale <= 1 :
186188 return 1.0
187189 return 0.1 * mscale * math .log (scale ) + 1.0
@@ -197,6 +199,7 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float:
197199 raise ValueError ("YaRN requires rope_scaling.attention_factor to be a float > 0" )
198200
199201 def find_correction_dim (num_rotations : float , dim : int , base : int , max_position_embeddings : int ) -> float :
202+ """Map a target number of rotations to a rotary dimension index."""
200203 return (dim * math .log (max_position_embeddings / (num_rotations * 2 * math .pi ))) / (2 * math .log (base ))
201204
202205 def find_correction_range (
@@ -207,6 +210,7 @@ def find_correction_range(
207210 max_position_embeddings : int ,
208211 truncate : bool ,
209212 ) -> tuple [float , float ]:
213+ """Compute the lower and upper rotary-dimension correction bounds for YaRN."""
210214 low = find_correction_dim (low_rot , dim , base , max_position_embeddings )
211215 high = find_correction_dim (high_rot , dim , base , max_position_embeddings )
212216 if truncate :
@@ -215,6 +219,7 @@ def find_correction_range(
215219 return max (low , 0 ), min (high , dim - 1 )
216220
217221 def linear_ramp_factor (min_value : float , max_value : float , dim : int ) -> torch .Tensor :
222+ """Create a clamped linear ramp used to blend interpolation and extrapolation."""
218223 if min_value == max_value :
219224 max_value += 0.001
220225 linear_func = (torch .arange (dim , dtype = torch .float32 , device = device ) - min_value ) / (max_value - min_value )
@@ -421,6 +426,7 @@ class RotaryTransformConfig(BaseModel):
421426
422427 @model_validator (mode = "after" )
423428 def validate_rope_scaling (self ) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig" :
429+ """Validate and normalize rope_scaling, including YaRN-specific constraints."""
424430 if self .rope_scaling is None :
425431 return self
426432
0 commit comments