
[AutoDiff] Mark rsqrt as non-linear for adstack promotion#499

Merged
duburcqa merged 1 commit into duburcqa/split_adjoint_alloca_placement from duburcqa/split_autodiff_mark_rsqrt_nonlinear
Apr 17, 2026

Conversation

@duburcqa (Contributor) commented on Apr 17, 2026

Summary

`AdStackAllocaJudger` walks `NonLinearOps::unary_collections` to decide whether the operand alloca of a unary op needs AdStack backing. rsqrt's reverse formula is `-0.5 * rsqrt(x)^3` and already recomputes on the operand, so correctness inside a dynamic loop hinges entirely on the operand's alloca being promoted to AdStack. Without membership in the set, a loop-variant alloca feeding only rsqrt stayed plain and the reversed loop read the last-forward value, producing wrong gradients.
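The reverse formula quoted here can be sanity-checked numerically. A minimal sketch in plain Python (not project code), comparing `-0.5 * rsqrt(x)^3` against a central finite difference of `rsqrt`:

```python
import math

def rsqrt(x):
    return 1.0 / math.sqrt(x)

def rsqrt_grad(x):
    # The PR's reverse formula: d/dx x^(-1/2) = -0.5 * x^(-3/2) = -0.5 * rsqrt(x)^3
    return -0.5 * rsqrt(x) ** 3

x = 2.0
eps = 1e-6
# Central finite difference as an independent reference
fd = (rsqrt(x + eps) - rsqrt(x - eps)) / (2 * eps)
print(abs(fd - rsqrt_grad(x)) < 1e-7)  # True
```

Because the gradient is expressed purely in terms of the operand `x`, replaying the backward pass inside a loop requires the per-iteration value of `x`, which is exactly what AdStack promotion provides.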

Test plan

@duburcqa force-pushed the duburcqa/split_autodiff_tanh_exp_recompute branch from 83f6e6c to 521db21 on April 17, 2026 11:43
@duburcqa force-pushed the duburcqa/split_autodiff_mark_rsqrt_nonlinear branch from d0c43e2 to c060858 on April 17, 2026 11:43

@claude (bot) left a comment


LGTM — adding rsqrt to unary_collections is correct and follows the established pattern.

Extended reasoning...

Overview

This PR modifies a single static set, `NonLinearOps::unary_collections` (plus minor reformatting), in `quadrants/transforms/auto_diff.cpp`. No logic, control flow, or data structures change beyond the membership of this set.

Security risks

None. This is a pure compiler IR transformation with no authentication, I/O, or external interfaces.

Level of scrutiny

Low. The change is mechanical: rsqrt's reverse formula (`-0.5 * rsqrt(operand)^3`) recomputes entirely on the operand, identical in structure to the tan/tanh/exp cases added previously, all of which carry the same adstack-promotion requirement. `AdStackAllocaJudger::visit(UnaryOpStmt*)` already handles the collection membership check; this PR simply registers rsqrt with it.

Other factors

The one bug report flagged (min/max absent from binary_collections) is explicitly pre-existing and unrelated to this diff. No prior reviews exist on this PR. Test coverage is planned via test_adstack_unary_loop_carried[rsqrt] landing with #491.

Comment on lines 47 to 48

inline static const std::set<BinaryOpType> binary_collections{
    BinaryOpType::mul, BinaryOpType::div, BinaryOpType::atan2, BinaryOpType::pow};

🟣 BinaryOpType::min and BinaryOpType::max are absent from NonLinearOps::binary_collections (line 48), so AdStackAllocaJudger never promotes their operand allocas to adstack. The reverse-mode gradient for min/max uses bin->lhs and bin->rhs directly via cmp_lt to route gradients, requiring per-iteration correct values in the reversed loop; without adstack promotion the reversed loop reads the last-forward operand value and silently routes all non-final-iteration gradients to the wrong branch. Add BinaryOpType::min and BinaryOpType::max to binary_collections analogously to how this PR adds rsqrt to unary_collections.
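The select-based routing described above can be sketched outside the compiler. A hypothetical Python helper (names are illustrative, not the project's API) showing how the adjoint follows whichever operand of `min` won the comparison:

```python
def min_adjoints(lhs, rhs, adj_out):
    """Reverse rule for min: the adjoint flows to the winning operand.

    Mirrors the cmp_lt/sel pattern quoted above: cmp = lhs < rhs selects
    the branch, and the losing operand receives zero.
    (Hypothetical helper, not project code.)
    """
    cmp = lhs < rhs
    d_lhs = adj_out if cmp else 0.0
    d_rhs = 0.0 if cmp else adj_out
    return d_lhs, d_rhs

print(min_adjoints(1.0, 3.0, 1.0))  # (1.0, 0.0): lhs won
print(min_adjoints(5.0, 3.0, 1.0))  # (0.0, 1.0): rhs won
```

The routing is only correct if `lhs` and `rhs` hold their per-iteration forward values when the comparison is replayed in the reversed loop, which is precisely what adstack promotion guarantees.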

Extended reasoning...

What the bug is and how it manifests

BinaryOpType::min and BinaryOpType::max are absent from NonLinearOps::binary_collections (line 48). AdStackAllocaJudger::visit(BinaryOpStmt*) consults this set to decide whether an operand alloca needs adstack promotion. Because min and max are not in the set, operand allocas that feed only a min or max operation in a dynamic loop are never promoted — they stay as plain allocas — and the reversed loop silently reads stale per-iteration values, producing wrong gradients.

The specific code path that triggers it

MakeAdjoint::visit(BinaryOpStmt*) handles min/max at the lines:
auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs) : cmp_lt(bin->rhs, bin->lhs);
accumulate(bin->lhs, sel(cmp, adjoint(bin), zero));
accumulate(bin->rhs, sel(cmp, zero, adjoint(bin)));

Both bin->lhs and bin->rhs are used directly from the forward IR in the comparison. In a reversed loop this comparison must see the per-iteration forward operand values. That requires those operands' backing allocas to ride the adstack. The promotion decision happens in AdStackAllocaJudger::run(AllocaStmt*), which returns true only if (!load_only_) && is_stack_needed_. is_stack_needed_ is set to true in visit(BinaryOpStmt*) only when the op type is in binary_collections — and min/max are absent.
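The promotion predicate described above can be sketched in a few lines; the set contents and names mimic the prose, not the actual quadrants sources. With min absent from the set, the judger declines promotion:

```python
# Hypothetical sketch of the AdStackAllocaJudger decision described above.
# Names mirror the prose (binary_collections, load_only_, is_stack_needed_),
# not the real C++ sources.
NONLINEAR_BINARY = {"mul", "div", "atan2", "pow"}  # min/max missing -> the bug

def needs_adstack(op_type, load_only):
    # visit(BinaryOpStmt*): membership in the set is the only thing that
    # flips is_stack_needed_ for this scenario
    is_stack_needed = op_type in NONLINEAR_BINARY
    # run(AllocaStmt*): promote only when used non-load-only AND flagged
    return (not load_only) and is_stack_needed

print(needs_adstack("pow", False))  # True: alloca promoted to adstack
print(needs_adstack("min", False))  # False: alloca stays plain
```

Adding `"min"` and `"max"` to the set is the entire fix; the surrounding decision logic is already in place.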

Why existing code doesn't prevent it

The only other mechanism that could set is_stack_needed_ for this scenario is the LocalStore cycle check: if (local_loaded_ && dest == target_alloca_backup_) is_stack_needed_ = true. This fires when a load of the alloca is followed by a store back to it in the same loop body (an accumulator pattern). However, for the simple pattern store(alloca, x[i]) → load(alloca) → min(loaded, constant) with no second store, local_loaded_ is true after the load but no subsequent store ever fires the check, leaving is_stack_needed_ = false. The alloca stays plain.

What the impact would be

BackupSSA::load() creates a plain AllocaStmt and emits a single LocalStoreStmt immediately after the forward definition of the operand, outside the reversed loop. Each forward iteration overwrites this single flat alloca. In the reversed loop, the LocalLoadStmt inserted by generic_visit reads the single flat alloca and obtains the last-forward-iteration value for every backward step. For a loop of length N, only the final reversed iteration (corresponding to the last forward iteration) routes the gradient correctly; all N-1 preceding backward steps compare operands from iteration N and accumulate gradient into the wrong branch — silent gradient corruption with no error or warning.
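This failure mode can be modeled concretely. A toy sketch, assuming a hypothetical sum-of-min kernel: the flat cell retains only the last forward value, so every reversed comparison sees it and the gradient mask is wrong for all but the final iteration:

```python
def grads_for_min_sum(arr, threshold, stale):
    """Gradient of sum(min(a_i, t)) w.r.t. each a_i.

    stale=True models the plain-alloca bug described above: the reversed
    loop sees only the last forward value of the operand for every step.
    (Toy model, not the actual compiler machinery.)
    """
    cell = None
    for a in arr:               # forward: each iteration overwrites the flat cell
        cell = a
    grads = []
    for a in reversed(arr):     # reverse: should compare against per-iteration a
        operand = cell if stale else a
        grads.append(1.0 if operand < threshold else 0.0)
    return list(reversed(grads))

arr = [1.0, 9.0, 2.0]
print(grads_for_min_sum(arr, 5.0, stale=False))  # correct: [1.0, 0.0, 1.0]
print(grads_for_min_sum(arr, 5.0, stale=True))   # stale:   [1.0, 1.0, 1.0]
```

In the stale case the gradient for `arr[1] = 9.0` is silently routed to the wrong branch, with no error or warning, matching the corruption described above.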

How to fix it

Add BinaryOpType::min and BinaryOpType::max to NonLinearOps::binary_collections, analogously to how this PR adds UnaryOpType::rsqrt to unary_collections:

inline static const std::set<BinaryOpType> binary_collections{
    BinaryOpType::mul, BinaryOpType::div, BinaryOpType::atan2, BinaryOpType::pow,
    BinaryOpType::min, BinaryOpType::max};

Step-by-step proof

Consider a simple reverse-mode AD kernel over a loop:
for i in range(N):
t = alloca; store(t, arr[i]) // step 1: init t each iteration
loaded = load(t) // step 2: local_loaded_ = true
result = min(loaded, threshold) // step 3: op_type not in binary_collections -> is_stack_needed_ stays false

After AdStackAllocaJudger returns false, ReplaceLocalVarWithStacks leaves t as a plain AllocaStmt. BackupSSA then emits:
t_backup = alloca (at top of IB)
store(t_backup, arr[i]) // immediately after the load/store, executed each forward iteration

In the reversed loop for i from N-1 downto 0:
loaded_reversed = load(t_backup) // always reads arr[N-1], not arr[i]
cmp = cmp_lt(loaded_reversed, threshold) // uses arr[N-1] for every i
// gradient routed incorrectly for i != N-1

The fix (adding min/max to binary_collections) causes AdStackAllocaJudger to return true, promoting t to an AdStackAllocaStmt. Each forward iteration pushes the current arr[i] onto the stack; each reversed iteration pops and reads the correct per-iteration value.
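The effect of the fix can be modeled with a toy stack-backed cell: forward iterations push, reversed iterations pop, so each backward step sees its own iteration's value. (The class name below is illustrative, not the actual AdStackAllocaStmt machinery.)

```python
class AdStackCell:
    """Toy stand-in for an adstack-promoted alloca: push per iteration
    in the forward pass, pop in the reversed pass."""
    def __init__(self):
        self.stack = []
    def store(self, v):
        self.stack.append(v)
    def load_reversed(self):
        return self.stack.pop()

arr = [1.0, 9.0, 2.0]
t = AdStackCell()
for a in arr:                               # forward: push arr[i] each iteration
    t.store(a)
seen = [t.load_reversed() for _ in arr]     # reverse: pops 2.0, 9.0, 1.0
print(seen == list(reversed(arr)))          # True: per-iteration values restored
```

This is exactly the contrast with the flat alloca, where every reversed read returned the last-forward value.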

Covered by extending test_adstack_unary_loop_carried with qd.rsqrt.
@duburcqa force-pushed the duburcqa/split_autodiff_tanh_exp_recompute branch from 521db21 to 274e8da on April 17, 2026 12:12
@duburcqa force-pushed the duburcqa/split_autodiff_mark_rsqrt_nonlinear branch from c060858 to 0f7ff3b on April 17, 2026 12:12
Base automatically changed from duburcqa/split_autodiff_tanh_exp_recompute to duburcqa/split_adjoint_alloca_placement April 17, 2026 12:12
@duburcqa merged commit 0f7ff3b into duburcqa/split_adjoint_alloca_placement on Apr 17, 2026
2 of 10 checks passed
@duburcqa deleted the duburcqa/split_autodiff_mark_rsqrt_nonlinear branch on April 17, 2026 12:12
@duburcqa restored the duburcqa/split_autodiff_mark_rsqrt_nonlinear branch on April 17, 2026 12:14