Skip to content

fix(gemma4): cast RoPE offset to int before mx.arange()#4901

Open
eauchs wants to merge 2 commits intounslothai:fix/ui-fixfrom
eauchs:fix/gemma4-rope-offset-int-cast
Open

fix(gemma4): cast RoPE offset to int before mx.arange()#4901
eauchs wants to merge 2 commits intounslothai:fix/ui-fixfrom
eauchs:fix/gemma4-rope-offset-int-cast

Conversation

@eauchs
Copy link
Copy Markdown

@eauchs eauchs commented Apr 7, 2026

Problem

mx.arange() receives an mlx.core.array for offset instead of a
Python native int, causing a TypeError at inference time with Gemma 4 models.

Fix

Cast offset to int before passing to mx.arange().

Tested on

M3 Max 128GB — unsloth/gemma-4-31b-it-UD-MLX-4bit

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the position index generation in the Gemma4 text model by casting the offset to an integer. A review comment identifies that using int(offset) in MLX is inefficient because it triggers a CPU-GPU synchronization point and breaks compatibility with mx.compile. A suggestion was provided to use a zero-based range added to the offset to maintain performance and compilation support.

# x shape: (B, n_heads, L, head_dim)
seq_len = x.shape[-2]
positions = mx.arange(offset, offset + seq_len, dtype = mx.float32)
positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using int(offset) is discouraged in MLX because it forces a synchronization point between the GPU and CPU to retrieve the value, which can significantly degrade performance during inference. Furthermore, if this code is executed within an mx.compile block, int(offset) will fail if offset is a tracer array.

A more efficient and compilation-friendly approach is to generate a zero-based range and then add the offset. This avoids the TypeError with mx.arange while supporting both integer and array-based offsets without performance penalties.

Suggested change
positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
positions = mx.arange(seq_len, dtype = mx.float32) + offset

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — updated the fix to use mx.arange(seq_len) + offset to avoid the CPU-GPU sync point and maintain mx.compile compatibility

@eauchs
Copy link
Copy Markdown
Author

eauchs commented Apr 7, 2026

@danielhanchen this fixes a TypeError crashing Gemma 4 inference for all users on the current fix/ui-fix branch — would appreciate a quick review 🙏

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant