Skip to content

Commit 2126b0b

Browse files
rrutmann and Copilot committed
docs: Add docstrings
Co-authored-by: Copilot <copilot@github.com>
1 parent 779e7c1 commit 2126b0b

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,7 @@ def __init__(
152152
self.reset_parameters()
153153

154154
def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
155+
"""Compute YaRN inverse frequencies and the attention scaling factor."""
155156
if self.rope_scaling is None:
156157
raise ValueError("YaRN requires a rope_scaling config.")
157158
if self.max_position_embeddings is None:
@@ -182,6 +183,7 @@ def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.T
182183
truncate = self.rope_scaling.get("truncate", True)
183184

184185
def get_mscale(scale: float, mscale: float = 1.0) -> float:
186+
"""Return the YaRN mscale coefficient for a given scaling factor."""
185187
if scale <= 1:
186188
return 1.0
187189
return 0.1 * mscale * math.log(scale) + 1.0
@@ -197,6 +199,7 @@ def get_mscale(scale: float, mscale: float = 1.0) -> float:
197199
raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0")
198200

199201
def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float:
202+
"""Map a target number of rotations to a rotary dimension index."""
200203
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
201204

202205
def find_correction_range(
@@ -207,6 +210,7 @@ def find_correction_range(
207210
max_position_embeddings: int,
208211
truncate: bool,
209212
) -> tuple[float, float]:
213+
"""Compute the lower and upper rotary-dimension correction bounds for YaRN."""
210214
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
211215
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
212216
if truncate:
@@ -215,6 +219,7 @@ def find_correction_range(
215219
return max(low, 0), min(high, dim - 1)
216220

217221
def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor:
222+
"""Create a clamped linear ramp used to blend interpolation and extrapolation."""
218223
if min_value == max_value:
219224
max_value += 0.001
220225
linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value)
@@ -421,6 +426,7 @@ class RotaryTransformConfig(BaseModel):
421426

422427
@model_validator(mode="after")
423428
def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
429+
"""Validate and normalize rope_scaling, including YaRN-specific constraints."""
424430
if self.rope_scaling is None:
425431
return self
426432

0 commit comments

Comments
 (0)