
Commit 2e2d743

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos
Pull Request resolved: #20086

Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.

Tests live in the stacked test-suite diff above (clean op diff here).

Authored with assistance from Claude.

ghstack-source-id: 391626188
@exported-using-ghexport

Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
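For readers who want the math behind the three sub-kernels, here is a minimal CPU sketch of QK scores, softmax, and the attention-weighted sum over a KV cache, with GQA head sharing and a causal mask driven by `input_pos`. It is not the WebGPU kernel from this commit: the function name `sdpa_with_kv_cache_ref`, the `[token, head, dim]` flat layouts, and the `1/sqrt(head_dim)` scale are assumptions chosen for illustration, and the cache is assumed to already contain the new tokens.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference (not the backend's code). Assumed layouts:
//   q:       [seq_len, n_heads,    head_dim]  new tokens only
//   k_cache: [max_ctx, n_kv_heads, head_dim]  already updated with new tokens
//   v_cache: [max_ctx, n_kv_heads, head_dim]
// Returns out: [seq_len, n_heads, head_dim].
std::vector<float> sdpa_with_kv_cache_ref(
    const std::vector<float>& q,
    const std::vector<float>& k_cache,
    const std::vector<float>& v_cache,
    std::size_t seq_len, std::size_t n_heads, std::size_t n_kv_heads,
    std::size_t head_dim, std::size_t input_pos) {
  assert(n_kv_heads > 0 && n_heads % n_kv_heads == 0);
  const std::size_t group = n_heads / n_kv_heads;  // GQA: query heads per KV head
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::vector<float> out(seq_len * n_heads * head_dim, 0.0f);
  std::vector<float> w;  // scratch: attention weights for one (token, head)

  for (std::size_t s = 0; s < seq_len; ++s) {
    // Causal mask: token s sees cache positions [0, input_pos + s].
    const std::size_t ctx = input_pos + s + 1;
    for (std::size_t h = 0; h < n_heads; ++h) {
      const std::size_t kvh = h / group;
      const float* qv = &q[(s * n_heads + h) * head_dim];
      w.assign(ctx, 0.0f);

      // Pass 1: scaled QK scores, tracking the max for a stable softmax.
      float max_w = -std::numeric_limits<float>::infinity();
      for (std::size_t t = 0; t < ctx; ++t) {
        const float* kv = &k_cache[(t * n_kv_heads + kvh) * head_dim];
        float dot = 0.0f;
        for (std::size_t d = 0; d < head_dim; ++d) dot += qv[d] * kv[d];
        w[t] = dot * scale;
        max_w = std::max(max_w, w[t]);
      }

      // Pass 2: softmax numerators and their sum.
      float sum = 0.0f;
      for (std::size_t t = 0; t < ctx; ++t) {
        w[t] = std::exp(w[t] - max_w);
        sum += w[t];
      }

      // Pass 3: attention output as the weighted sum of V rows.
      float* o = &out[(s * n_heads + h) * head_dim];
      for (std::size_t t = 0; t < ctx; ++t) {
        const float p = w[t] / sum;
        const float* vv = &v_cache[(t * n_kv_heads + kvh) * head_dim];
        for (std::size_t d = 0; d < head_dim; ++d) o[d] += p * vv[d];
      }
    }
  }
  return out;
}
```

In a decode loop under these assumptions, each step would call the function with `seq_len = 1` and an `input_pos` one larger than the previous step, which is the role the live `input_pos` uniform plays in the op described above.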
1 parent eed5a5f commit 2e2d743

9 files changed

Lines changed: 1072 additions & 0 deletions

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ set(WEBGPU_SRCS
   runtime/ops/add/BinaryOp.cpp
   runtime/ops/rms_norm/RmsNorm.cpp
   runtime/ops/update_cache/UpdateCache.cpp
+  runtime/ops/sdpa/Sdpa.cpp
   runtime/ops/select_as_symint/SelectAsSymint.cpp
 )

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 3 additions & 0 deletions
@@ -105,6 +105,9 @@ class WebGPUGraph {
   int64_t get_int(int id) const {
     return ints_[id];
   }
+  bool get_bool(int id) const {
+    return bools_[id];
+  }

   // Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
   // set_symint writes the buffer + marks dirty only if the value changed.
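
The "write + mark dirty only if the value changed" behavior noted in the comment above can be sketched as follows. This is an illustrative stand-in, not the `WebGPUGraph` implementation: the `LiveScalar` type, its members, and `consume_dirty` are hypothetical names, and the GPU uniform-buffer upload is replaced by a counter.

```cpp
#include <cstdint>

// Hypothetical stand-in for a live scalar backed by a uniform buffer.
struct LiveScalar {
  int32_t value = 0;
  bool dirty = false;
  int buffer_writes = 0;  // counts simulated uploads

  // Write and mark dirty only when the value actually changes.
  void set(int32_t v) {
    if (v == value) return;  // unchanged: skip the upload and the dirty mark
    value = v;
    ++buffer_writes;         // a real backend would upload to the GPU buffer here
    dirty = true;
  }

  // Returns whether dependents (e.g. resize hooks) need to rerun, then clears.
  bool consume_dirty() {
    const bool was_dirty = dirty;
    dirty = false;
    return was_dirty;
  }
};
```

Under this sketch, a decode loop that advances the position every step pays one upload per step, while repeated calls with the same value cost nothing.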
