fix(gemma4): cast RoPE offset to int before mx.arange() #4901
eauchs wants to merge 2 commits into unslothai:fix/ui-fix
Code Review
This pull request modifies the position index generation in the Gemma4 text model by casting the offset to an integer. A review comment identifies that using int(offset) in MLX is inefficient because it triggers a CPU-GPU synchronization point and breaks compatibility with mx.compile. A suggestion was provided to use a zero-based range added to the offset to maintain performance and compilation support.
unsloth/models/gemma4_text.py
Outdated
  # x shape: (B, n_heads, L, head_dim)
  seq_len = x.shape[-2]
- positions = mx.arange(offset, offset + seq_len, dtype = mx.float32)
+ positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
Using int(offset) is discouraged in MLX because it forces a synchronization point between the GPU and CPU to retrieve the value, which can significantly degrade performance during inference. Furthermore, if this code is executed within an mx.compile block, int(offset) will fail if offset is a tracer array.
A more efficient and compilation-friendly approach is to generate a zero-based range and then add the offset. This avoids the TypeError with mx.arange while supporting both integer and array-based offsets without performance penalties.
- positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
+ positions = mx.arange(seq_len, dtype = mx.float32) + offset
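The suggested form can be sketched as a small helper. This is only an illustration: the name rope_positions and the xp parameter are hypothetical and not part of the PR, and NumPy is used as a stand-in array module so the snippet runs without MLX. The intended use would be to pass mlx.core as xp, which assumes mx.arange and mx.float32 behave as in the suggestion above.

```python
import numpy as np  # stand-in array module so the sketch runs without MLX


def rope_positions(seq_len, offset, xp=np):
    """Return positions offset, offset + 1, ..., offset + seq_len - 1.

    Hypothetical helper: the range itself is built from a plain Python int,
    so `offset` may be a Python int or a 0-d array. It only enters through
    the broadcasting addition and is never converted with int().
    """
    return xp.arange(seq_len, dtype=xp.float32) + offset


# The values are the same whether the offset is an int or a 0-d array.
from_int = rope_positions(4, 5)
from_array = rope_positions(4, np.array(5))
print(from_int.tolist())    # [5.0, 6.0, 7.0, 8.0]
print(from_array.tolist())  # [5.0, 6.0, 7.0, 8.0]
```

Because the offset is never pulled back into a Python int, the same expression should stay usable when the offset is an array produced on the device, which is the case the review comment is concerned with.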
Good catch — updated the fix to use mx.arange(seq_len) + offset to avoid the CPU-GPU sync point and maintain mx.compile compatibility
@danielhanchen this fixes a TypeError crashing Gemma 4 inference for all users on the current fix/ui-fix branch; would appreciate a quick review 🙏
Problem
mx.arange() receives an mlx.core.array for offset instead of a native Python int, causing a TypeError at inference time with Gemma 4 models.
Fix
Cast offset to int before passing to mx.arange().
Tested on
M3 Max 128GB — unsloth/gemma-4-31b-it-UD-MLX-4bit