Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions quadrants/codegen/spirv/detail/spirv_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,24 +194,56 @@ class TaskCodegen : public IRVisitor {

bool use_volatile_buffer_access_{false};

// Where the primal/adjoint storage for an AdStack lives. `heap_float` backs f32 adstacks and `heap_int` backs
// i32 and u1 adstacks (u1 stored as i32 to match the historical Function-scope path's bool->int remap in
// `get_array_type`); other primitive types are hard-errored by `visit(AdStackAllocaStmt)`, so no Function-scope
// fallback exists. Each kind maps to its own per-dispatch StorageBuffer (`BufferType::AdStackHeapFloat` /
// `BufferType::AdStackHeapInt`).
enum class AdStackHeapKind { heap_float, heap_int };
struct AdStackSpirv {
spirv::Value count_var; // u32, Function scope - current number of entries
spirv::Value primal_arr; // Array<storage_type, max_size>, Function scope
spirv::Value adjoint_arr; // Array<storage_type, max_size>, Function scope
// `elem_type` is the logical loop-carried value's SPIR-V type (e.g. bool for a u1 adstack). `storage_type`
// is what the backing array is actually declared as: identical to `elem_type` except for u1, where the
// array is declared as i32 because `IRBuilder::get_array_type` silently promotes OpTypeBool (which has no
// defined storage layout under LogicalAddressing) to i32. Push/LoadTop/AccAdjoint must use `storage_type`
// for the OpAccessChain / load-store pair, and cast between `elem_type` and `storage_type` around the
// caller-visible value - otherwise SPIR-V codegen emits `OpAccessChain %_ptr_Function_bool %arr_of_int_N`,
// which spirv-val rejects with "result type OpTypeBool does not match the type that results from
// indexing into OpTypeInt" and AMD's native Vulkan driver runs anyway and segfaults the dispatch.
spirv::SType elem_type;
spirv::SType storage_type;
spirv::Value count_var; // u32, Function scope - current number of entries
AdStackHeapKind heap_kind;
// Offsets are in elements of the heap's element type (f32 or i32).
uint32_t heap_primal_offset{0};
uint32_t heap_adjoint_offset{0};
uint32_t max_size{0};
spirv::SType elem_type;
};
std::unordered_map<const Stmt *, AdStackSpirv> ad_stacks_;
spirv::Value ad_stack_access(spirv::Value arr, spirv::Value index, const spirv::SType &elem_type);
// Total per-thread heap strides, pre-computed from the IR before any visitor runs so that
// `invoc_id * stride` captures the final value. Exposed via `task_attribs.ad_stack_heap_per_thread_stride_*` so
// the runtime can size the heaps. The float stride is counted in f32 elements, the int stride in i32 elements.
uint32_t ad_stack_heap_per_thread_stride_float_{0};
uint32_t ad_stack_heap_per_thread_stride_int_{0};
// Running offsets into the per-thread slice assigned to the next AdStackAllocaStmt visitor. Each ends equal to
// the corresponding stride once every alloca has been visited.
uint32_t ad_stack_heap_next_offset_float_{0};
uint32_t ad_stack_heap_next_offset_int_{0};
// Buffers are cached for reuse across push/pop/load-top visitors and (re)computed lazily on first use inside a
// task so the `OpLoad` falls inside the dispatch body rather than the function header.
spirv::Value ad_stack_heap_buffer_float_;
spirv::Value ad_stack_heap_buffer_int_;
// `invoc_id * stride` thread-base values. Despite being cached like the buffers, these are NOT lazy: they are
// emitted eagerly from `visit(AdStackAllocaStmt)` so the `OpIMul` lives in the alloca's enclosing block, which
// strictly dominates every sibling inner loop that later references the cached SSA id. Emitting them lazily
// from the first `AdStackPush/LoadTop` visitor would place the multiply in the first loop's body, and the
// second sibling loop would reuse an SSA id defined in a non-dominating block (SPIR-V spec section 2.16).
// Do NOT move these to a lazy path; the corresponding getters enforce eager emission.
spirv::Value ad_stack_heap_thread_base_float_;
spirv::Value ad_stack_heap_thread_base_int_;
// Return (lazily) the StorageBuffer of `Array<f32>` that backs f32 adstacks for this dispatch, and the
// per-thread base index inside it.
spirv::Value get_ad_stack_heap_buffer_float();
spirv::Value get_ad_stack_heap_thread_base_float();
spirv::Value ad_stack_heap_float_ptr(uint32_t offset, spirv::Value count);
// Same accessors for the int-typed heap buffer (backs i32 and u1 adstacks).
spirv::Value get_ad_stack_heap_buffer_int();
spirv::Value get_ad_stack_heap_thread_base_int();
spirv::Value ad_stack_heap_int_ptr(uint32_t offset, spirv::Value count);
// Routes to the correct backing-typed pointer (`*f32` for `heap_float`, `*i32` for `heap_int`) based on
// `info.heap_kind`. See comment on the implementation for the bool<->i32 conversion contract.
spirv::Value ad_stack_slot_ptr(AdStackSpirv &info, spirv::Value idx, bool primal);
spirv::SType ad_stack_backing_type(const AdStackSpirv &info) const;
};
} // namespace detail
} // namespace spirv
Expand Down
15 changes: 15 additions & 0 deletions quadrants/codegen/spirv/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ std::string TaskAttributes::buffers_name(BufferInfo b) {
if (b.type == BufferType::Root) {
return std::string("Root: ") + std::to_string(b.root_id);
}
if (b.type == BufferType::ListGen) {
return "ListGen";
}
if (b.type == BufferType::ExtArr) {
return "ExtArr";
}
if (b.type == BufferType::AdStackOverflow) {
return "AdStackOverflow";
}
if (b.type == BufferType::AdStackHeapFloat) {
return "AdStackHeapFloat";
}
if (b.type == BufferType::AdStackHeapInt) {
return "AdStackHeapInt";
}
QD_ERROR("unrecognized buffer type");
}

Expand Down
42 changes: 39 additions & 3 deletions quadrants/codegen/spirv/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@ namespace spirv {
* Per offloaded task attributes.
*/
struct TaskAttributes {
enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr, AdStackOverflow };
enum class BufferType {
Root,
GlobalTmps,
Args,
Rets,
ListGen,
ExtArr,
AdStackOverflow,
AdStackHeapFloat,
AdStackHeapInt,
};

struct BufferInfo {
BufferType type;
Expand Down Expand Up @@ -111,12 +121,36 @@ struct TaskAttributes {
return (const_begin && const_end);
}

QD_IO_DEF(begin, end, const_begin, const_end);
// When the range end is non-const and the IR encodes it as a product of one or more ndarray-shape lookups (a
// common `qd.ndrange(arr.shape[...], ...)` pattern), the codegen extracts each `ExternalTensorShapeAlongAxisStmt`
// into this list. At launch time, host-side `LaunchContextBuilder` already has every ndarray's shape in
// `array_ptrs` / struct args, so the actual iteration bound is `product over refs of arg[arg_id].shape[axis]`.
// The runtime uses that as a tight cap on `advisory_total_num_threads` to avoid oversizing the per-thread
// adstack heap (otherwise `kMaxNumThreadsGridStrideLoop` defaults to 131072 for a B=1 workload and the heap
// allocation requests multi-GB that exceeds Metal's `maxBufferLength`). Empty means the end expression could
// not be simplified to a pure product of shape lookups; fall back to the advisory thread count in that case.
struct ArgShapeRef {
std::vector<int> arg_id;
int axis{0};
QD_IO_DEF(arg_id, axis);
};
std::vector<ArgShapeRef> end_shape_product;

QD_IO_DEF(begin, end, const_begin, const_end, end_shape_product);
};
std::vector<BufferBind> buffer_binds;
// Only valid when |task_type| is range_for.
std::optional<RangeForAttributes> range_for_attribs;

// Per-thread stride, in f32 elements, of the f32-typed heap-backed adstack slice used by this task, bound as
// BufferType::AdStackHeapFloat. Zero when the task has no f32 adstack. The runtime multiplies this by the
// dispatched invocation count to size the shared adstack buffer.
uint32_t ad_stack_heap_per_thread_stride_float{0};
// Per-thread stride, in i32 elements, of the int-typed heap-backed adstack slice used by this task, bound as
// BufferType::AdStackHeapInt. Backs both i32 and u1 adstacks (u1 is stored as i32, matching the existing
// Function-scope path). Zero when the task has no non-f32 adstack.
uint32_t ad_stack_heap_per_thread_stride_int{0};

static std::string buffers_name(BufferInfo b);

std::string debug_string() const;
Expand All @@ -126,7 +160,9 @@ struct TaskAttributes {
advisory_num_threads_per_group,
task_type,
buffer_binds,
range_for_attribs);
range_for_attribs,
ad_stack_heap_per_thread_stride_float,
ad_stack_heap_per_thread_stride_int);
};

/**
Expand Down
Loading
Loading