Commit 49c6160
committed
[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos
Pull Request resolved: #20086
Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).
`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.
Tests live in the stacked test-suite diff above (clean op diff here).
Authored with assistance from Claude.
ghstack-source-id: 392609088
@exported-using-ghexport
Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)1 parent b7d4d31 commit 49c6160
9 files changed
Lines changed: 1096 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
106 | 106 | | |
107 | 107 | | |
108 | 108 | | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
109 | 112 | | |
110 | 113 | | |
111 | 114 | | |
| |||
0 commit comments