[AutoDiff] Mark rsqrt as non-linear for adstack promotion #499
Conversation
LGTM — adding rsqrt to unary_collections is correct and follows the established pattern.
Overview
This PR modifies a single static set, `NonLinearOps::unary_collections`, in `quadrants/transforms/auto_diff.cpp` (plus minor reformatting). No logic, control flow, or data structures change beyond the membership of this set.
Security risks
None. This is a pure compiler IR transformation with no authentication, I/O, or external interfaces.
Level of scrutiny
Low. The change is mechanical: rsqrt's reverse formula (-0.5 * rsqrt(operand)^3) recomputes entirely on the operand — identical in structure to the tan/tanh/exp cases added previously, all of which carry the same adstack-promotion requirement. The AdStackAllocaJudger::visit(UnaryOpStmt*) visitor already handles the collection membership check; this PR simply registers rsqrt with it.
Other factors
The one bug report flagged (min/max absent from binary_collections) is explicitly pre-existing and unrelated to this diff. No prior reviews exist on this PR. Test coverage is planned via test_adstack_unary_loop_carried[rsqrt] landing with #491.
inline static const std::set<BinaryOpType> binary_collections{BinaryOpType::mul, BinaryOpType::div,
                                                              BinaryOpType::atan2, BinaryOpType::pow};
🟣 BinaryOpType::min and BinaryOpType::max are absent from NonLinearOps::binary_collections (line 48), so AdStackAllocaJudger never promotes their operand allocas to adstack. The reverse-mode gradient for min/max uses bin->lhs and bin->rhs directly via cmp_lt to route gradients, requiring per-iteration correct values in the reversed loop; without adstack promotion the reversed loop reads the last-forward operand value and silently routes all non-final-iteration gradients to the wrong branch. Add BinaryOpType::min and BinaryOpType::max to binary_collections analogously to how this PR adds rsqrt to unary_collections.
What the bug is and how it manifests
BinaryOpType::min and BinaryOpType::max are absent from NonLinearOps::binary_collections (line 48). AdStackAllocaJudger::visit(BinaryOpStmt*) consults this set to decide whether an operand alloca needs adstack promotion. Because min and max are not in the set, operand allocas that feed only a min or max operation in a dynamic loop are never promoted — they stay as plain allocas — and the reversed loop silently reads stale per-iteration values, producing wrong gradients.
The specific code path that triggers it
MakeAdjoint::visit(BinaryOpStmt*) handles min/max at the lines:
auto cmp = bin->op_type == BinaryOpType::min ? cmp_lt(bin->lhs, bin->rhs) : cmp_lt(bin->rhs, bin->lhs);
accumulate(bin->lhs, sel(cmp, adjoint(bin), zero));
accumulate(bin->rhs, sel(cmp, zero, adjoint(bin)));
Both bin->lhs and bin->rhs are used directly from the forward IR in the comparison. In a reversed loop this comparison must see the per-iteration forward operand values. That requires those operands' backing allocas to ride the adstack. The promotion decision happens in AdStackAllocaJudger::run(AllocaStmt*), which returns true only if (!load_only_) && is_stack_needed_. is_stack_needed_ is set to true in visit(BinaryOpStmt*) only when the op type is in binary_collections — and min/max are absent.
Why existing code doesn't prevent it
The only other mechanism that could set is_stack_needed_ for this scenario is the LocalStore cycle check: if (local_loaded_ && dest == target_alloca_backup_) is_stack_needed_ = true. This fires when a load of the alloca is followed by a store back to it in the same loop body (an accumulator pattern). However, for the simple pattern store(alloca, x[i]) → load(alloca) → min(loaded, constant) with no second store, local_loaded_ is true after the load but no subsequent store ever fires the check, leaving is_stack_needed_ = false. The alloca stays plain.
What the impact would be
BackupSSA::load() creates a plain AllocaStmt and emits a single LocalStoreStmt immediately after the forward definition of the operand, outside the reversed loop. Each forward iteration overwrites this single flat alloca. In the reversed loop, the LocalLoadStmt inserted by generic_visit reads the single flat alloca and obtains the last-forward-iteration value for every backward step. For a loop of length N, only the final reversed iteration (corresponding to the last forward iteration) routes the gradient correctly; all N-1 preceding backward steps compare operands from iteration N and accumulate gradient into the wrong branch — silent gradient corruption with no error or warning.
How to fix it
Add BinaryOpType::min and BinaryOpType::max to NonLinearOps::binary_collections, analogously to how this PR adds UnaryOpType::rsqrt to unary_collections:
inline static const std::set<BinaryOpType> binary_collections{
    BinaryOpType::mul, BinaryOpType::div, BinaryOpType::atan2, BinaryOpType::pow,
    BinaryOpType::min, BinaryOpType::max};
Step-by-step proof
Consider a simple reverse-mode AD kernel over a loop:
for i in range(N):
t = alloca; store(t, arr[i]) // step 1: init t each iteration
loaded = load(t) // step 2: local_loaded_ = true
result = min(loaded, threshold) // step 3: op_type not in binary_collections -> is_stack_needed_ stays false
After AdStackAllocaJudger returns false, ReplaceLocalVarWithStacks leaves t as a plain AllocaStmt. BackupSSA then emits:
t_backup = alloca (at top of IB)
store(t_backup, arr[i]) // immediately after the load/store, executed each forward iteration
In the reversed loop for i from N-1 downto 0:
loaded_reversed = load(t_backup) // always reads arr[N-1], not arr[i]
cmp = cmp_lt(loaded_reversed, threshold) // uses arr[N-1] for every i
// gradient routed incorrectly for i != N-1
The fix (adding min/max to binary_collections) causes AdStackAllocaJudger to return true, promoting t to an AdStackAllocaStmt. Each forward iteration pushes the current arr[i] onto the stack; each reversed iteration pops and reads the correct per-iteration value.
`AdStackAllocaJudger` walks `NonLinearOps::unary_collections` to decide whether the operand alloca of a unary op needs AdStack backing. rsqrt's reverse formula is `-0.5 * rsqrt(x)^3` and already recomputes on the operand, so correctness inside a dynamic loop hinges entirely on the operand's alloca being promoted to AdStack. Without membership in the set, a loop-variant alloca feeding only rsqrt stayed plain and the reversed loop read the last-forward value, producing wrong gradients. Covered by extending test_adstack_unary_loop_carried with qd.rsqrt.
Merged 0f7ff3b into duburcqa/split_adjoint_alloca_placement
Summary
`AdStackAllocaJudger` walks `NonLinearOps::unary_collections` to decide whether the operand alloca of a unary op needs AdStack backing. rsqrt's reverse formula is `-0.5 * rsqrt(x)^3` and already recomputes on the operand, so correctness inside a dynamic loop hinges entirely on the operand's alloca being promoted to AdStack. Without membership in the set, a loop-variant alloca feeding only rsqrt stayed plain and the reversed loop read the last-forward value, producing wrong gradients.
Test plan