[ExecuTorch][WebGPU] Add et_vk.apply_rotary_emb (interleaved RoPE) + ValueList multi-output #20264
JulianCloudNTH wants to merge 1 commit
Stack from ghstack (oldest at bottom):
Adds the WebGPU backend handler for `et_vk.apply_rotary_emb.default` (interleaved Llama rotary positional embedding) plus the `ValueList` graph-value support its multi-output signature requires.

The op rotates the query and key tensors by a shared `freqs_cos`/`freqs_sin` pair and is composed of two dispatches of one WGSL kernel: each thread handles one (even, odd) element pair of a head row (`out[2i] = x[2i]*cos - x[2i+1]*sin`, `out[2i+1] = x[2i]*sin + x[2i+1]*cos`), one dispatch writing `xq_out` and one writing `xk_out`, mirroring the Vulkan `apply_rotary_emb` reference (buffer-only, fp32, the interleaved `.default` variant). Each dispatch owns a distinct compute pipeline (the graph destructor releases per dispatch, so a shared handle would double-free); the workgroup size is a `wg_size` pipeline-override constant clamped to the device limit, both 1D dispatch counts go through `WebGPUUtils::compute_1d_workgroup_count` and are validated before any GPU-object allocation, and the embedded WGSL header is generated by `gen_wgsl_headers.py`.

The two outputs (`xq_out`, `xk_out`) are serialized by the Vulkan exporter as a single `ValueList` graph value, which the runtime did not previously model. This adds the `ValueType::ValueList` value kind, a `value_lists_` table populated during `build()`, and a `get_value_list` accessor the handler uses to resolve the output ids. While in that code path it also closes a latent gap: a constant tensor whose `constant_id` is set but whose constants table is missing or out of range now throws (fail-loud) rather than silently leaving the buffer uninitialized.

Differential Revision: D108428756
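For reviewers who want to sanity-check the kernel math, here is a minimal CPU sketch of the per-pair rotation described above. It is not code from this PR: the function name `apply_rotary_emb_interleaved_ref` is hypothetical, and the layouts are assumptions (row-major `x` of shape `[seq_len, n_heads, head_dim]`, `freqs_cos`/`freqs_sin` of shape `[seq_len, head_dim / 2]`). Each loop iteration stands in for one GPU thread handling one (even, odd) pair.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference for interleaved RoPE on a single tensor (xq or xk).
// Assumed layouts (not confirmed by the PR):
//   x         : [seq_len, n_heads, head_dim], row-major, fp32
//   freqs_cos : [seq_len, head_dim / 2]
//   freqs_sin : [seq_len, head_dim / 2]
std::vector<float> apply_rotary_emb_interleaved_ref(
    const std::vector<float>& x,
    const std::vector<float>& freqs_cos,
    const std::vector<float>& freqs_sin,
    std::size_t seq_len,
    std::size_t n_heads,
    std::size_t head_dim) {
  // Validate shapes up front, before producing any output.
  if (head_dim % 2 != 0) {
    throw std::invalid_argument("head_dim must be even");
  }
  const std::size_t half = head_dim / 2;
  if (x.size() != seq_len * n_heads * head_dim ||
      freqs_cos.size() != seq_len * half ||
      freqs_sin.size() != seq_len * half) {
    throw std::invalid_argument("tensor sizes do not match the given shape");
  }

  std::vector<float> out(x.size());
  const std::size_t num_pairs = seq_len * n_heads * half;

  // One iteration per (even, odd) pair, mirroring one thread per pair.
  for (std::size_t p = 0; p < num_pairs; ++p) {
    const std::size_t i = p % half;        // pair index within the head row
    const std::size_t row = p / half;      // flattened (position, head) row
    const std::size_t s = row / n_heads;   // sequence position of this row
    const std::size_t base = row * head_dim + 2 * i;

    const float c = freqs_cos[s * half + i];
    const float sn = freqs_sin[s * half + i];
    const float x0 = x[base];
    const float x1 = x[base + 1];

    out[base] = x0 * c - x1 * sn;      // out[2i]   = x[2i]*cos - x[2i+1]*sin
    out[base + 1] = x0 * sn + x1 * c;  // out[2i+1] = x[2i]*sin + x[2i+1]*cos
  }
  return out;
}
```

Calling it once on the query tensor and once on the key tensor with the same `freqs_cos`/`freqs_sin` corresponds to the two dispatches (`xq_out`, `xk_out`) of the single kernel.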