Skip to content
Open
121 changes: 121 additions & 0 deletions tests/python/test_adstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,124 @@ def compute():
compute.grad()

assert math.isnan(x.grad[None])


def _run_basic_gradient(n_iter, shall_not_pass):
    # Shared driver for the positive and negative adstack gradient tests.
    #
    # Builds a kernel that iterates ``v = v * 0.95 + 0.01`` over a dynamic
    # ``range(n_iter)``, runs forward + backward, and then either checks the
    # analytical gradient (``shall_not_pass=False``, adstack enabled) or expects
    # a compile-time rejection (``shall_not_pass=True``, adstack disabled).
    # The adstack is structurally required here so the backward compiler can
    # reverse the dynamic ``range(n_iter)`` at all; without it the compiler
    # refuses to build the backward kernel ("Cannot use non static range in
    # Backwards mode"), which is exactly what the negative test pins.
    # Value-correctness of the per-iteration ``v`` spilled on the adstack is
    # NOT exercised by the linear body ``v = v * 0.95 + 0.01`` - see
    # ``test_adstack_basic_gradient``'s docstring for the details.
    num_elems = 4
    x = qd.field(qd.f32, shape=num_elems, needs_grad=True)
    y = qd.field(qd.f32, shape=(), needs_grad=True)

    @qd.kernel
    def compute():
        for i in x:
            v = x[i]
            for _ in range(n_iter):
                v = v * 0.95 + 0.01
            y[None] += v

    # Seed the inputs, run forward, then prime the adjoints for backward.
    for idx, val in enumerate([0.1, 0.3, 0.5, 0.8]):
        x[idx] = val
    y[None] = 0.0
    compute()
    y.grad[None] = 1.0
    for idx in range(num_elems):
        x.grad[idx] = 0.0

    if shall_not_pass:
        # Adstack disabled: backward compilation must be rejected outright.
        with pytest.raises(qd.QuadrantsCompilationError, match=r"non static range"):
            compute.grad()
        return

    compute.grad()

    # Iterating v = v * 0.95 + 0.01 n_iter times gives
    # v_final = 0.95**n_iter * x[i] + const, so dv_final/dx[i] == 0.95**n_iter
    # independent of x[i], and dy/dx[i] equals the same quantity.
    expected_grad = 0.95**n_iter
    for idx in range(num_elems):
        assert x.grad[idx] == test_utils.approx(expected_grad, rel=1e-4)


@pytest.mark.parametrize("n_iter", [1, 3, 10])
@test_utils.test(require=qd.extension.adstack)
def test_adstack_basic_gradient(n_iter):
    """Smallest possible "does reverse-mode AD through a for-loop work" check.

    The kernel runs `n_iter` iterations of `v = v * 0.95 + 0.01` per element
    and asserts that `dy/dx[i]` matches the analytical gradient
    `0.95 ** n_iter` for every element.

    Internal details: the adstack is structurally required here so the backward
    compiler can reverse the dynamic `range(n_iter)` at all - the companion
    `test_adstack_basic_gradient_negative` pins that disabling the adstack
    raises `QuadrantsCompilationError` in exactly this kernel shape.
    Value-correctness of the stored v, on the other hand, is NOT exercised: the
    loop body `v = v * 0.95 + 0.01` is linear, so the backward chain
    `adj(v_prev) = 0.95 * adj(v_next)` only uses the compile-time constant 0.95
    and never reads v from the adstack. A broken push/load/pop that returned
    garbage for v would still produce the same exact gradient. For
    push/load/pop value-correctness coverage, see
    `test_adstack_unary_loop_carried` (non-linear unary ops in the loop body).
    `n_iter = 1` exercises the single-push adstack code path; `n_iter = 10`
    exercises repeated push/pop under one forward invocation; multi-element
    coverage (n = 4) guards against per-element accumulation bugs that a
    single-element variant would miss.
    """
    _run_basic_gradient(n_iter, shall_not_pass=False)
Comment thread
duburcqa marked this conversation as resolved.


@pytest.mark.parametrize("n_iter", [1, 3, 10])
@test_utils.test(ad_stack_experimental_enabled=False)
def test_adstack_basic_gradient_negative(n_iter):
    """Negative counterpart of `test_adstack_basic_gradient`.

    With the adstack disabled the backward compiler cannot reverse a dynamic
    `range(n_iter)`, so `compute.grad()` raises
    `QuadrantsCompilationError("Cannot use non static range in Backwards
    mode")` deterministically for every `n_iter`. This pins the compile-time
    rejection rather than any gradient value - the helper catches the
    exception and returns before asserting.
    """
    _run_basic_gradient(n_iter, shall_not_pass=True)


@pytest.mark.parametrize("n_iter", [1, 3, 10])
@pytest.mark.parametrize("use_static_loop", [True, False])
@pytest.mark.parametrize("use_varying_coeff", [True, False])
@test_utils.test(require=qd.extension.adstack)
def test_adstack_sum_linear(use_static_loop, use_varying_coeff, n_iter):
    """Linear accumulation `y = sum_j v * coeff_j` across a 2x2 truth table.

    Covers all four combinations of (static-unrolled vs dynamic loop) x
    (constant coefficient vs loop-index-varying coefficient), at three loop
    lengths. Replaces the earlier three separate tests
    (`test_adstack_sum_fixed_coeff`, `test_adstack_sum_constant_coeffs`,
    `test_adstack_sum_static_loop_correct`) with a single parametrized version
    so every branch of that truth table is covered at each trip count.

    Internal details: this test deliberately does not mutate `v` inside the
    loop, so the reverse pass does not require adstack replay of `v` to
    compute the right gradient - `v`'s per-iteration value is the same `x[i]`.
    The point of this test is therefore not to stress the adstack (that is
    `test_adstack_basic_gradient`'s job) but to prove that enabling the
    adstack extension does not silently regress linear reverse-mode AD for
    either unrolled or dynamic loop shapes. No negative counterpart is
    included: for `use_static_loop=True` the inner loop is unrolled and the
    backward kernel contains no dynamic range, so the adstack option does not
    change the gradient; for `use_static_loop=False` disabling the adstack
    would raise `QuadrantsCompilationError` (same compile-time rejection
    covered by `test_adstack_basic_gradient_negative`), which is out of scope
    here.
    """
    num_elems = 4
    x = qd.field(qd.f32, shape=num_elems, needs_grad=True)
    y = qd.field(qd.f32, shape=(), needs_grad=True)

    @qd.kernel
    def compute():
        for i in x:
            v = x[i]
            for a in qd.static(range(n_iter)) if qd.static(use_static_loop) else range(n_iter):
                if qd.static(use_varying_coeff):
                    y[None] += v * qd.cast(a + 1, qd.f32)
                else:
                    y[None] += v

    # Seed inputs, run forward, then prime adjoints for the backward pass.
    for idx, val in enumerate([0.1, 0.3, 0.5, 0.8]):
        x[idx] = val
    y[None] = 0.0
    compute()
    y.grad[None] = 1.0
    for idx in range(num_elems):
        x.grad[idx] = 0.0
    compute.grad()

    # dy/dx[i] is the sum of the per-iteration coefficients: 1+2+...+n_iter
    # when they vary with the loop index, otherwise n_iter copies of 1.
    if use_varying_coeff:
        expected = sum(coeff + 1 for coeff in range(n_iter))
    else:
        expected = float(n_iter)
    for idx in range(num_elems):
        assert x.grad[idx] == test_utils.approx(expected, rel=1e-4)
Loading