Skip to content

Commit 7337fe4

Browse files
authored
Merge pull request #445 from Modalities/yarn_hf
YaRN
2 parents 4705675 + e12db1a commit 7337fe4

4 files changed

Lines changed: 388 additions & 32 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 208 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import math
33
from abc import abstractmethod
44
from enum import Enum
5-
from typing import Annotated, Optional, overload
5+
from numbers import Real
6+
from typing import Annotated, Literal, Optional, overload
67

78
import torch
89
import torch.nn as nn
9-
from pydantic import BaseModel, Field, model_validator, validator
10+
from pydantic import BaseModel, Field, field_validator, model_validator, validator
1011

1112
from modalities.config.lookup_enum import LookupEnum
1213
from modalities.config.utils import convert_base_model_config_to_dict
@@ -31,6 +32,60 @@
3132
# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT
3233

3334

35+
def _validate_numeric_field(field_name: str, value: object) -> float:
36+
"""Validate that a value is a real number (excluding bool) and cast to float."""
37+
if isinstance(value, bool) or not isinstance(value, Real):
38+
raise ValueError(f"rope_scaling.{field_name} must be a float")
39+
return float(value)
40+
41+
42+
class DefaultRopeScalingConfig(BaseModel):
43+
"""Configuration for default RoPE behavior."""
44+
45+
rope_type: Literal["default"] = "default"
46+
47+
48+
class YarnRopeScalingConfig(BaseModel):
49+
"""Configuration for YaRN RoPE scaling."""
50+
51+
rope_type: Literal["yarn"] = "yarn"
52+
original_max_position_embeddings: Annotated[int, Field(strict=True, ge=1)]
53+
factor: Optional[Annotated[float, Field(ge=1.0)]] = None
54+
attention_factor: Optional[Annotated[float, Field(gt=0.0)]] = None
55+
mscale: Optional[Annotated[float, Field(ge=0.0)]] = None
56+
mscale_all_dim: Optional[Annotated[float, Field(ge=0.0)]] = None
57+
beta_fast: Annotated[float, Field(ge=0.0)] = 32.0
58+
beta_slow: Annotated[float, Field(ge=0.0)] = 1.0
59+
truncate: bool = True
60+
61+
@field_validator(
62+
"factor",
63+
"attention_factor",
64+
"mscale",
65+
"mscale_all_dim",
66+
"beta_fast",
67+
"beta_slow",
68+
mode="before",
69+
)
70+
@classmethod
71+
def validate_numeric_fields(cls, value: object, info):
72+
if value is None:
73+
return value
74+
return _validate_numeric_field(info.field_name, value)
75+
76+
@model_validator(mode="after")
77+
def validate_mscale_pair(self) -> "YarnRopeScalingConfig":
78+
if (self.mscale is None) != (self.mscale_all_dim is None):
79+
raise ValueError("rope_scaling.mscale and rope_scaling.mscale_all_dim must be provided together")
80+
return self
81+
82+
83+
RopeScalingConfig = Annotated[
84+
DefaultRopeScalingConfig | YarnRopeScalingConfig,
85+
Field(discriminator="rope_type"),
86+
]
87+
88+
3489
class LayerNorms(LookupEnum):
3590
"""
3691
Enum lookup class for LayerNorms.
@@ -120,7 +175,15 @@ class RotaryTransform(QueryKeyValueTransform):
120175
XFormers implementation and removed in this implementation.#
121176
"""
122177

123-
def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000):
178+
def __init__(
179+
self,
180+
n_embd: int,
181+
n_head: int,
182+
seq_length_dim: int = -2,
183+
base_freq: int = 10000,
184+
max_position_embeddings: int | None = None,
185+
rope_scaling: RopeScalingConfig | None = None,
186+
):
124187
"""
125188
Initializes the RotaryTransform object.
126189
@@ -136,16 +199,33 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
136199
self.dim_model = n_embd // n_head
137200
self.seq_length_dim = seq_length_dim
138201
self.base_freq = base_freq
202+
self.max_position_embeddings = max_position_embeddings
203+
204+
if rope_scaling is not None and not isinstance(rope_scaling, (DefaultRopeScalingConfig, YarnRopeScalingConfig)):
205+
raise TypeError(
206+
"rope_scaling must be an instance of DefaultRopeScalingConfig, YarnRopeScalingConfig, or None"
207+
)
208+
209+
self.rope_scaling = rope_scaling
210+
self.attention_scaling = 1.0
139211

140212
self.reset_parameters()
141213

142214
def reset_parameters(self):
143215
# If previously initialized on or moved to a device, reuse that device.
144216
# Otherwise, use the default device of the current environment.
145-
device = self.inv_freq.device if hasattr(self, "inv_freq") else None
146-
inv_freq = 1.0 / (
147-
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
148-
)
217+
device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None
218+
219+
rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default"
220+
221+
if rope_type == "yarn":
222+
inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
223+
else:
224+
inv_freq = 1.0 / (
225+
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
226+
)
227+
self.attention_scaling = 1.0
228+
149229
self.register_buffer("inv_freq", inv_freq)
150230

151231
self._seq_len_cached = None
@@ -166,24 +246,6 @@ def rotate_half(self, x: torch.Tensor):
166246
x1, x2 = x.chunk(2, dim=-1)
167247
return torch.cat((-x2, x1), dim=-1)
168248

169-
def _update_cos_sin_tables(self, x):
170-
# Update the cosine and sine tables.
171-
seq_len = x.shape[self.seq_length_dim]
172-
173-
# Reset the tables if the sequence length has changed,
174-
# or if we're on a new device (possibly due to tracing for instance)
175-
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
176-
self._seq_len_cached = seq_len
177-
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
178-
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
179-
emb = torch.cat((freqs, freqs), dim=-1).to(
180-
x.device
181-
) # here, we combine the two matrices (not zipping them).
182-
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
183-
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
184-
185-
return self._cos_cached, self._sin_cached
186-
187249
def apply_rotary_pos_emb(self, x, cos, sin):
188250
"""
189251
Applies rotary positional embedding to the input tensor.
@@ -228,6 +290,118 @@ def forward(
228290

229291
return q, k, v
230292

293+
def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
294+
"""Compute YaRN inverse frequencies and the attention scaling factor."""
295+
if not isinstance(self.rope_scaling, YarnRopeScalingConfig):
296+
raise ValueError("YaRN requires a rope_scaling config.")
297+
if self.max_position_embeddings is None:
298+
raise ValueError("YaRN requires max_position_embeddings to be set.")
299+
300+
original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings
301+
factor = self.rope_scaling.factor
302+
if factor is None:
303+
factor = self.max_position_embeddings / original_max_position_embeddings
304+
factor_float = float(factor)
305+
306+
attention_factor = self.rope_scaling.attention_factor
307+
mscale_pair = None
308+
if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None:
309+
mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim)
310+
311+
beta_fast = self.rope_scaling.beta_fast
312+
beta_slow = self.rope_scaling.beta_slow
313+
truncate = self.rope_scaling.truncate
314+
315+
def get_mscale(scale: float, mscale: float = 1.0) -> float:
316+
"""Return the YaRN mscale coefficient for a given scaling factor."""
317+
if scale <= 1:
318+
return 1.0
319+
return 0.1 * mscale * math.log(scale) + 1.0
320+
321+
if attention_factor is None:
322+
if mscale_pair is not None:
323+
mscale, mscale_all_dim = mscale_pair
324+
attention_factor = float(
325+
get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim))
326+
)
327+
else:
328+
attention_factor = get_mscale(factor_float)
329+
330+
def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float:
331+
"""Map a target number of rotations to a rotary dimension index."""
332+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
333+
334+
def find_correction_range(
335+
low_rot: float,
336+
high_rot: float,
337+
dim: int,
338+
base: int,
339+
max_position_embeddings: int,
340+
truncate: bool,
341+
) -> tuple[float, float]:
342+
"""Compute the lower and upper rotary-dimension correction bounds for YaRN."""
343+
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
344+
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
345+
if truncate:
346+
low = math.floor(low)
347+
high = math.ceil(high)
348+
return max(low, 0), min(high, dim - 1)
349+
350+
def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor:
351+
"""Create a clamped linear ramp used to blend interpolation and extrapolation."""
352+
if min_value == max_value:
353+
max_value += 0.001
354+
linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value)
355+
ramp_func = torch.clamp(linear_func, 0, 1)
356+
return ramp_func
357+
358+
dim = self.dim_model
359+
base = self.base_freq
360+
361+
pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
362+
inv_freq_extrapolation = 1.0 / pos_freqs
363+
inv_freq_interpolation = 1.0 / (factor_float * pos_freqs)
364+
365+
low, high = find_correction_range(
366+
beta_fast,
367+
beta_slow,
368+
dim,
369+
base,
370+
original_max_position_embeddings,
371+
bool(truncate),
372+
)
373+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
374+
inv_freq = (
375+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
376+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
377+
)
378+
379+
return inv_freq, float(attention_factor)
380+
381+
def _update_cos_sin_tables(self, x):
382+
# Update the cosine and sine tables.
383+
seq_len = x.shape[self.seq_length_dim]
384+
385+
# Reset the tables if the sequence length has changed,
386+
# or if we're on a new device (possibly due to tracing for instance)
387+
if (
388+
seq_len != self._seq_len_cached
389+
or self._cos_cached is None
390+
or self._sin_cached is None
391+
or self._cos_cached.device != x.device
392+
or self._cos_cached.dtype != x.dtype
393+
):
394+
self._seq_len_cached = seq_len
395+
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
396+
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
397+
emb = torch.cat((freqs, freqs), dim=-1).to(
398+
x.device
399+
) # here, we combine the two matrices (not zipping them).
400+
self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype)
401+
self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype)
402+
403+
return self._cos_cached, self._sin_cached
404+
231405

232406
class QueryKeyValueTransformType(Enum):
233407
"""
@@ -295,6 +469,15 @@ class RotaryTransformConfig(BaseModel):
295469
n_head: Annotated[int, Field(strict=True, ge=0)]
296470
seq_length_dim: Annotated[int, Field(strict=True)]
297471
base_freq: Annotated[int, Field(strict=True, ge=10000)]
472+
max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
473+
rope_scaling: Optional[RopeScalingConfig] = None
474+
475+
@model_validator(mode="after")
476+
def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
477+
"""Validate rope_scaling cross-field constraints."""
478+
if isinstance(self.rope_scaling, YarnRopeScalingConfig) and self.max_position_embeddings is None:
479+
raise ValueError("YaRN requires max_position_embeddings to be set")
480+
return self
298481

299482
@validator("type_hint", pre=True, always=True)
300483
def parse_sharding_strategy_by_name(cls, name):

src/modalities/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import gc
12
from datetime import datetime
23
from enum import Enum
3-
import gc
44
from typing import Callable, Optional
55

66
import torch
@@ -388,7 +388,7 @@ def train(
388388
self.gc.run(step_count=training_progress.num_seen_steps_total)
389389
evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
390390
checkpointing_callback(training_progress=training_progress)
391-
391+
392392
profiler_cm.step()
393393

394394
@staticmethod

tests/fsdp2_parallelization/test_tensor_parallelism.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
from tests.utility import find_free_port
2121

2222

23-
def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:
23+
def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path:
2424
"""Patches the original configuration file to set a custom activation type."""
2525
with original_config_path.open("r", encoding="utf-8") as f:
2626
config_dict = yaml.safe_load(f)
2727

2828
config_dict["model_raw"]["config"]["activation_type"] = activation_type
2929

30-
tmp_file_path = tmp_dir / original_config_path.name
30+
file_suffix = f"_{file_tag}" if file_tag else ""
31+
tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}"
3132
with tmp_file_path.open("w", encoding="utf-8") as f:
3233
yaml.safe_dump(config_dict, f)
3334

@@ -103,12 +104,16 @@ def _test_tp_sharding_impl(
103104
):
104105
# Seed before FSDP2 instantiation
105106
torch.manual_seed(42)
106-
fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir)
107+
fsdp2_path = patch_config_file(
108+
fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2"
109+
)
107110
fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path)
108111

109112
# Seed again before TP instantiation to match
110113
torch.manual_seed(42)
111-
tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir)
114+
tp_path = patch_config_file(
115+
tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp"
116+
)
112117
tp_model, tp_mesh = self._get_components(tp_path, tmp_path)
113118

114119
# Ensure models use the correct MLP

0 commit comments

Comments
 (0)