Skip to content

Commit ae02b38

Browse files
mergennachin and claude committed
Fix MLX RoPE for proportional partial rotary (Gemma 4 full-attention layers)
- custom_ops.py: support 1D freqs in the Python fake op. When freqs is 1D, compute inv_freq = 1/freqs and build angles from positions, matching the C++ runtime behavior. The 2D freqs path is unchanged.
- MLXInterpreter.h: pass base=nullopt when freqs is provided. MLX's fast::rope requires exactly one of base or freqs.
- mlx_source_transformations.py: pass dims=rotary_dim (not head_dim) with 1D freqs containing only the non-zero rotary frequencies. The old code passed 2D precomputed angles, which was incorrect at the C++ level.
- test_ops.py: add RopeCustomFreqsTest (3 configs) verifying export and MLX delegation with 1D custom frequencies.

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 6423b4b commit ae02b38

4 files changed

Lines changed: 103 additions & 11 deletions

File tree

backends/mlx/custom_ops.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -228,8 +228,16 @@ def rope(
228228
# final angles: [1, 1, T, half]
229229
angles = (pos_range * inv_freq) * float(scale)
230230
else:
231-
# assume freqs is already per-position, just reshape to [1,1,T,half]
232-
angles = freqs.to(torch.float32).view(1, 1, T, half)
231+
if freqs.ndim == 1:
232+
# 1D raw frequencies: compute angles = positions * (1/freqs)
233+
inv_freq = (1.0 / freqs.to(torch.float32)).view(1, 1, 1, half)
234+
pos_range = torch.arange(
235+
pos, pos + T, device=x.device, dtype=torch.float32
236+
).view(1, 1, T, 1)
237+
angles = (pos_range * inv_freq) * float(scale)
238+
else:
239+
# 2D per-position angles: reshape to [1,1,T,half]
240+
angles = freqs.to(torch.float32).view(1, 1, T, half)
233241

234242
cos = angles.cos().to(x.dtype) # [1,1,T,half]
235243
sin = angles.sin().to(x.dtype) # [1,1,T,half]

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,11 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
242242
freqs_arr = st.const_tensor_ref(*n.freqs);
243243
}
244244

245+
// MLX requires exactly one of base or freqs — when freqs is provided,
246+
// base must be nullopt.
247+
std::optional<float> base =
248+
freqs_arr ? std::nullopt : std::optional<float>(n.base);
249+
245250
// MLX has two overloads: rope(..., int offset, ...) and rope(..., const
246251
// array& offset, ...) Call the appropriate one based on is_vid
247252
if (n.offset.is_vid) {
@@ -250,14 +255,14 @@ inline void exec_rope(const RopeNode& n, ExecutionState& st, StreamOrDevice s) {
250255
st.set_tensor(
251256
n.out,
252257
fast::rope(
253-
x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
258+
x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
254259
} else {
255260
// Tensor offset from Tid
256261
const array& offset = st.const_tensor_ref(n.offset.tid);
257262
st.set_tensor(
258263
n.out,
259264
fast::rope(
260-
x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s));
265+
x, n.dims, n.traditional, base, n.scale, offset, freqs_arr, s));
261266
}
262267
}
263268

backends/mlx/test/test_ops.py

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1803,6 +1803,82 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
18031803
return (q, k, pos_tensor)
18041804

18051805

1806+
class RopeCustomFreqsModel(nn.Module):
1807+
"""Model that applies RoPE with custom 1D frequencies (partial rotary)."""
1808+
1809+
def __init__(self, dims: int = 32, head_dim: int = 64):
1810+
super().__init__()
1811+
self.dims = dims
1812+
self.head_dim = head_dim
1813+
# Simulate proportional RoPE: compute freqs for rotary dims only
1814+
inv_freq = 1.0 / (
1815+
500000.0 ** (torch.arange(0, dims, 2, dtype=torch.float32) / head_dim)
1816+
)
1817+
self.register_buffer("freqs", 1.0 / inv_freq, persistent=False)
1818+
1819+
def forward(
1820+
self,
1821+
q: torch.Tensor,
1822+
k: torch.Tensor,
1823+
pos_tensor: torch.Tensor,
1824+
) -> Tuple[torch.Tensor, torch.Tensor]:
1825+
pos = pos_tensor.item()
1826+
q_rot = torch.ops.mlx.rope(q, self.dims, pos, False, 0.0, 1.0, self.freqs)
1827+
k_rot = torch.ops.mlx.rope(k, self.dims, pos, False, 0.0, 1.0, self.freqs)
1828+
return q_rot, k_rot
1829+
1830+
1831+
@register_test
1832+
class RopeCustomFreqsTest(OpTestCase):
1833+
"""Test RoPE with custom 1D frequencies (partial rotary, like Gemma 4)."""
1834+
1835+
name = "rope_custom_freqs"
1836+
rtol = 1e-4
1837+
atol = 1e-4
1838+
1839+
def __init__(
1840+
self,
1841+
batch_size: int = 1,
1842+
num_heads: int = 8,
1843+
seq_len: int = 4,
1844+
head_dim: int = 64,
1845+
dims: int = 32,
1846+
pos: int = 0,
1847+
):
1848+
self.batch_size = batch_size
1849+
self.num_heads = num_heads
1850+
self.seq_len = seq_len
1851+
self.head_dim = head_dim
1852+
self.dims = dims
1853+
self.pos = pos
1854+
self.name = "rope_custom_freqs"
1855+
1856+
@classmethod
1857+
def get_test_configs(cls) -> List["RopeCustomFreqsTest"]:
1858+
configs = [
1859+
cls(),
1860+
cls(pos=10),
1861+
cls(head_dim=128, dims=64),
1862+
]
1863+
for cfg in configs:
1864+
parts = ["rope_custom_freqs"]
1865+
if cfg.pos > 0:
1866+
parts.append(f"pos{cfg.pos}")
1867+
if cfg.head_dim != 64:
1868+
parts.append(f"hd{cfg.head_dim}")
1869+
cfg.name = "_".join(parts)
1870+
return configs
1871+
1872+
def create_model(self) -> nn.Module:
1873+
return RopeCustomFreqsModel(dims=self.dims, head_dim=self.head_dim)
1874+
1875+
def create_inputs(self) -> Tuple[torch.Tensor, ...]:
1876+
q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
1877+
k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim)
1878+
pos_tensor = torch.tensor(self.pos, dtype=torch.int64)
1879+
return (q, k, pos_tensor)
1880+
1881+
18061882
from executorch.backends.mlx.llm.cache import KVCache
18071883

18081884

examples/models/gemma4_31b/mlx_source_transformations.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -51,10 +51,7 @@ def _mlx_forward(
5151
k = k.transpose(1, 2)
5252
v = v.transpose(1, 2)
5353

54-
# RoPE via mlx::rope. For proportional partial RoPE (full-attention
55-
# layers), pass precomputed frequencies since mlx.rope's built-in
56-
# frequency computation uses dims as the denominator, but Gemma 4
57-
# uses head_dim.
54+
# RoPE via mlx::rope.
5855
if self.is_sliding:
5956
q = torch.ops.mlx.rope(
6057
q, self.head_dim, start_pos, False, self.rope_theta, 1.0, None
@@ -63,9 +60,15 @@ def _mlx_forward(
6360
k, self.head_dim, start_pos, False, self.rope_theta, 1.0, None
6461
)
6562
else:
66-
freqs = torch.outer(input_pos.float(), self.inv_freq)
67-
q = torch.ops.mlx.rope(q, self.head_dim, start_pos, False, 0.0, 0.0, freqs)
68-
k = torch.ops.mlx.rope(k, self.head_dim, start_pos, False, 0.0, 0.0, freqs)
63+
# Full-attention layers use proportional partial RoPE: only
64+
# rotary_dim out of head_dim dimensions are rotated. Pass
65+
# dims=rotary_dim and the non-zero frequencies as 1D freqs.
66+
# MLX computes inv_freq = 1/freqs internally.
67+
rotary_dim = int(self.head_dim * self.partial_rotary)
68+
rotary_inv_freq = self.inv_freq[: rotary_dim // 2]
69+
mlx_freqs = 1.0 / rotary_inv_freq
70+
q = torch.ops.mlx.rope(q, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs)
71+
k = torch.ops.mlx.rope(k, rotary_dim, start_pos, False, 0.0, 1.0, mlx_freqs)
6972

7073
k_cache, v_cache = self.kv_cache.update(start_pos, k, v)
7174

0 commit comments

Comments
 (0)