Commit 49c6160

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos
Pull Request resolved: #20086

Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.

Tests live in the stacked test-suite diff above (clean op diff here).

Authored with assistance from Claude.

ghstack-source-id: 392609088
@exported-using-ghexport

Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
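For readers unfamiliar with the three sub-kernels named above, the following is a minimal CPU sketch of the same math: QK attention weights, softmax, and the attention-output (AV) pass over a KV cache, with GQA head mapping and a causal mask driven by `input_pos`. It is not the WebGPU kernel from this commit. The `SdpaShape` struct, the `sdpa_reference` function, and the tensor layouts (`[seq_len, num_heads, head_dim]` for Q, `[max_context, num_kv_heads, head_dim]` for the caches) are assumptions made for illustration only.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical CPU reference, NOT the kernel added by this commit.
// Assumed flat fp32 layouts:
//   q:       [seq_len, num_heads, head_dim]
//   k_cache: [max_context, num_kv_heads, head_dim]
//   v_cache: [max_context, num_kv_heads, head_dim]
struct SdpaShape {
  int seq_len;      // number of new query tokens
  int num_heads;    // query heads
  int num_kv_heads; // KV heads (GQA requires num_heads % num_kv_heads == 0)
  int head_dim;     // per-head feature size
  int max_context;  // capacity of the KV cache
};

std::vector<float> sdpa_reference(
    const std::vector<float>& q,
    const std::vector<float>& k_cache,
    const std::vector<float>& v_cache,
    const SdpaShape& s,
    int input_pos) {
  assert(s.num_heads % s.num_kv_heads == 0);
  assert(input_pos >= 0 && input_pos + s.seq_len <= s.max_context);

  const int group = s.num_heads / s.num_kv_heads; // query heads per KV head
  const float scale = 1.0f / std::sqrt(static_cast<float>(s.head_dim));

  std::vector<float> out(q.size(), 0.0f);
  std::vector<float> weights(input_pos + s.seq_len);

  for (int t = 0; t < s.seq_len; ++t) {
    // Causal mask: query token t sees cache positions [0, input_pos + t].
    const int visible = input_pos + t + 1;
    for (int h = 0; h < s.num_heads; ++h) {
      const int kv_h = h / group; // GQA: several query heads share one KV head
      const float* q_row = &q[(t * s.num_heads + h) * s.head_dim];

      // Pass 1 (QK): scaled dot products against the visible keys.
      float max_w = -std::numeric_limits<float>::infinity();
      for (int c = 0; c < visible; ++c) {
        const float* k_row = &k_cache[(c * s.num_kv_heads + kv_h) * s.head_dim];
        float dot = 0.0f;
        for (int d = 0; d < s.head_dim; ++d) {
          dot += q_row[d] * k_row[d];
        }
        weights[c] = dot * scale;
        max_w = std::max(max_w, weights[c]);
      }

      // Pass 2 (softmax): subtract the max for numerical stability.
      float sum = 0.0f;
      for (int c = 0; c < visible; ++c) {
        weights[c] = std::exp(weights[c] - max_w);
        sum += weights[c];
      }

      // Pass 3 (AV): weighted sum of the visible values.
      float* o_row = &out[(t * s.num_heads + h) * s.head_dim];
      for (int c = 0; c < visible; ++c) {
        const float w = weights[c] / sum;
        const float* v_row = &v_cache[(c * s.num_kv_heads + kv_h) * s.head_dim];
        for (int d = 0; d < s.head_dim; ++d) {
          o_row[d] += w * v_row[d];
        }
      }
    }
  }
  return out;
}
```

In a decode loop under these assumptions, the caller would pass `seq_len = 1` with the new token's query and an `input_pos` that advances by one each step, so cache slots beyond `input_pos` are never read even though the cache is sized for `max_context`.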
1 parent b7d4d31 commit 49c6160

9 files changed

Lines changed: 1096 additions & 0 deletions

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ set(WEBGPU_SRCS
   runtime/ops/add/BinaryOp.cpp
   runtime/ops/rms_norm/RmsNorm.cpp
   runtime/ops/update_cache/UpdateCache.cpp
+  runtime/ops/sdpa/Sdpa.cpp
   runtime/ops/select_as_symint/SelectAsSymint.cpp
 )

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 3 additions & 0 deletions
@@ -106,6 +106,9 @@ class WebGPUGraph {
   int64_t get_int(int id) const {
     return ints_[id];
   }
+  bool get_bool(int id) const {
+    return bools_[id];
+  }

   // Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
   // set_symint writes the buffer + marks dirty only if the value changed.
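
The "write + mark dirty only if the value changed" behaviour described in the comment above can be sketched in isolation as follows. This is a hypothetical stand-in, not the `WebGPUGraph` implementation: the `LiveScalar` class, its method names, and the choice to start in the dirty state (so the first value is always uploaded) are all assumptions for illustration, and the actual GPU uniform-buffer write is omitted.

```cpp
#include <cstdint>

// Hypothetical stand-in for a live scalar backed by a uniform buffer.
// The real GPU buffer write is omitted; only the dirty tracking is shown.
class LiveScalar {
 public:
  explicit LiveScalar(int32_t initial) : value_(initial) {}

  // Stores v and marks the scalar dirty only when the value actually changes.
  // Returns true if a change was recorded.
  bool set(int32_t v) {
    if (v == value_) {
      return false; // unchanged: no write, no dirty flag
    }
    value_ = v;
    dirty_ = true;
    return true;
  }

  // Reports whether an upload is pending and clears the flag.
  bool consume_dirty() {
    const bool was_dirty = dirty_;
    dirty_ = false;
    return was_dirty;
  }

  int32_t value() const {
    return value_;
  }

 private:
  int32_t value_;
  bool dirty_ = true; // assumed: the initial value still needs one upload
};
```

Skipping the write when the value is unchanged means that repeated calls with the same `input_pos` would not trigger redundant uploads or resize work in a design like this.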
