Add backward pass to fused rope kernels #3612

Open
NahButch wants to merge 1 commit into huggingface:main from NahButch:rope-backward

@NahButch

rope, rope_i, and rope_thd used apply_op3_no_bwd, so loss.backward() silently returned no gradient for any Var upstream of a rotary embedding, while the rope_*_slow compositions are differentiable. Same naming footgun as rms_norm (#3526) and softmax_last_dim (#3591).

The rotation is linear in xs, so its Jacobian is the rotation matrix itself and the backward pass applies its transpose: the same fused rope applied to the incoming gradient with sin negated. This reuses the fast kernels on every backend; the cos/sin tables get no gradient. Adds a gradient test comparing the fused path against slow-path autograd for all three variants; it fails on the previous behavior with 'no gradient for rope input'.
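
To illustrate the "same rope with sin negated" claim, here is a minimal std-only sketch, not the code in this PR. It assumes the interleaved layout (pairs `(x[2i], x[2i+1])`, as in rope_i) on a single row. The names `rope_interleaved`, `rope_interleaved_bwd`, `numeric_grad`, and `max_grad_error` are hypothetical and exist only for this example. It compares the analytic backward against a central finite difference of a linear loss.

```rust
/// Interleaved rotary embedding on one row (illustrative, not the candle kernel):
/// each pair (xs[2i], xs[2i+1]) is rotated by the angle with cosine cos[i], sine sin[i].
fn rope_interleaved(xs: &[f64], cos: &[f64], sin: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0; xs.len()];
    for i in 0..cos.len() {
        let (a, b) = (xs[2 * i], xs[2 * i + 1]);
        out[2 * i] = a * cos[i] - b * sin[i];
        out[2 * i + 1] = a * sin[i] + b * cos[i];
    }
    out
}

/// Gradient w.r.t. xs: the same rotation applied to the upstream gradient
/// with sin negated (the transpose of a rotation is the inverse rotation).
fn rope_interleaved_bwd(grad_out: &[f64], cos: &[f64], sin: &[f64]) -> Vec<f64> {
    let neg_sin: Vec<f64> = sin.iter().map(|s| -s).collect();
    rope_interleaved(grad_out, cos, &neg_sin)
}

/// Central finite-difference gradient of loss(xs) = sum_j w[j] * rope(xs)[j].
/// For this loss the upstream gradient d loss / d rope(xs) is exactly w.
fn numeric_grad(xs: &[f64], w: &[f64], cos: &[f64], sin: &[f64]) -> Vec<f64> {
    let eps = 1e-6;
    let loss = |v: &[f64]| -> f64 {
        rope_interleaved(v, cos, sin)
            .iter()
            .zip(w)
            .map(|(y, wj)| y * wj)
            .sum()
    };
    (0..xs.len())
        .map(|j| {
            let mut hi = xs.to_vec();
            hi[j] += eps;
            let mut lo = xs.to_vec();
            lo[j] -= eps;
            (loss(&hi) - loss(&lo)) / (2.0 * eps)
        })
        .collect()
}

/// Largest absolute gap between the analytic and numeric gradients
/// on a small hand-picked input (values are arbitrary).
fn max_grad_error() -> f64 {
    let xs = [0.3, -1.2, 0.7, 2.0];
    let w = [1.0, -0.5, 0.25, 2.0];
    let angles = [0.4_f64, 1.3];
    let cos: Vec<f64> = angles.iter().map(|a| a.cos()).collect();
    let sin: Vec<f64> = angles.iter().map(|a| a.sin()).collect();
    let analytic = rope_interleaved_bwd(&w, &cos, &sin);
    let numeric = numeric_grad(&xs, &w, &cos, &sin);
    analytic
        .iter()
        .zip(&numeric)
        .map(|(a, n)| (a - n).abs())
        .fold(0.0, f64::max)
}
```

Calling `max_grad_error()` should return a value near floating-point rounding error. Because the loss is linear in xs, the central difference has no truncation error. The contiguous-halves layout used by rope would follow the same argument with different pair indexing.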

Fixes #3568

🤖 Generated with Claude Code

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

candle-nn: rotary_emb::rope is apply_op3_no_bwd; severs autograd (same pattern as #2168 / PR #3526)
