Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions quadrants/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@
class NonLinearOps {
public:
inline static const std::set<TernaryOpType> ternary_collections{TernaryOpType::select};
inline static const std::set<UnaryOpType> unary_collections{
UnaryOpType::abs, UnaryOpType::sin, UnaryOpType::cos, UnaryOpType::tan, UnaryOpType::tanh,
UnaryOpType::asin, UnaryOpType::acos, UnaryOpType::exp, UnaryOpType::log, UnaryOpType::sqrt};
UnaryOpType::abs, UnaryOpType::sin, UnaryOpType::cos, UnaryOpType::tan, UnaryOpType::tanh, UnaryOpType::asin,
UnaryOpType::acos, UnaryOpType::exp, UnaryOpType::log, UnaryOpType::sqrt, UnaryOpType::rsqrt};
inline static const std::set<BinaryOpType> binary_collections{BinaryOpType::mul, BinaryOpType::div,
BinaryOpType::atan2, BinaryOpType::pow};
};

Check notice on line 49 in quadrants/transforms/auto_diff.cpp

View check run for this annotation

Claude / Claude Code Review

AdStackAllocaJudger only traces first LocalLoad: multi-load alloca not promoted for rsqrt

The `AdStackAllocaJudger::visit(LocalLoadStmt*)` method updates `target_alloca_` to the FIRST `LocalLoadStmt` it encounters, so any subsequent `LocalLoad` from the same alloca (where `stmt->src == original_alloca \!= target_alloca_`) is silently ignored; if only the second load feeds a nonlinear op like `rsqrt`, `is_stack_needed_` stays false and the alloca is not promoted to AdStack. This is a pre-existing limitation that affects all ops in `NonLinearOps::unary_collections` equally — the PR cor

class IndependentBlocksJudger : public BasicStmtVisitor {
public:
Expand Down
1 change: 1 addition & 0 deletions tests/python/test_adstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
(qd.sqrt, "sqrt"),
(qd.tanh, "tanh"),
(qd.exp, "exp"),
(qd.rsqrt, "rsqrt"),
],
)
@test_utils.test(require=qd.extension.adstack, ad_stack_experimental_enabled=True)
Expand Down
Loading