webgpu: bypass manual mRoPE for text-only Qwen3.5 when GQA fuses RoPE #2245

Draft
qjia7 wants to merge 1 commit into main from qwen35-text-only-fused-rope-bypass

@qjia7 qjia7 commented Jun 26, 2026

Contributor

Summary

For text-only Qwen3.5, multimodal RoPE (mRoPE) collapses to standard 1D RoPE:
Qwen3_5TextRotaryEmbedding expands a 2D position_ids into 3 identical axes
and apply_interleaved_mrope returns freqs[0] unchanged. The manual mRoPE
subgraph (Shape → Expand → interleaved cos/sin caches → custom kernel) is
therefore equivalent to a plain fused-RoPE pass inside GQA.
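The collapse can be checked numerically with a small, dependency-free sketch. It is not the Qwen3.5 implementation: `rope_angles` and `interleave_mrope` are hypothetical helpers, and `head_dim`, `mrope_section`, and the RoPE base are made-up values chosen only for illustration. The interleaving starts from the first axis and overwrites every third slot with the second and third axes, which is the behavior the summary attributes to apply_interleaved_mrope. Only the rotation angles are compared; the cos/sin step is omitted.

```python
def rope_angles(positions, head_dim, base=10000.0):
    """Rotation angles for standard 1D RoPE: one row per position."""
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    return [[p * f for f in inv_freq] for p in positions]


def interleave_mrope(freqs, mrope_section):
    """Simplified interleaved mRoPE merge (hypothetical helper).

    freqs is a list of three angle tables, one per position axis.
    Start from axis 0 and overwrite slots 1, 4, 7, ... with axis 1
    and slots 2, 5, 8, ... with axis 2.
    """
    merged = [row[:] for row in freqs[0]]
    for axis, offset in ((1, 1), (2, 2)):
        length = mrope_section[axis] * 3
        for t in range(len(merged)):
            for j in range(offset, length, 3):
                merged[t][j] = freqs[axis][t][j]
    return merged


# Illustrative values only (assumptions, not taken from the model config).
head_dim = 16
mrope_section = [3, 3, 2]  # sums to head_dim // 2
positions = list(range(6))

# Text-only: the same positions are used on all three axes.
text_only = [rope_angles(positions, head_dim) for _ in range(3)]
text_only_collapses = (
    interleave_mrope(text_only, mrope_section) == rope_angles(positions, head_dim)
)

# VL-like: the second and third axes carry different positions.
vl_like = [
    rope_angles(positions, head_dim),
    rope_angles([0, 0, 1, 1, 2, 2], head_dim),
    rope_angles([0, 1, 0, 1, 0, 1], head_dim),
]
vl_collapses = (
    interleave_mrope(vl_like, mrope_section) == rope_angles(positions, head_dim)
)
```

With identical axes, every overwritten slot receives the value it already held, so `text_only_collapses` is true. As soon as the axes differ, as in the VL-like case, the merged table no longer matches plain 1D RoPE, which is why the bypass is limited to text-only models.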

When the GQA operator supports fused RoPE (use_rope_in_attn=True, e.g. on
WebGPU), this PR detects the text-only case and routes through the fused path,
bypassing the manual mRoPE subgraph entirely. This removes the Shape → Memcpy node that reads a dynamic tensor shape at runtime — the path that
prevents WebGPU graph capture on Qwen3.5 text-only models.

Changes (src/python/py/models/builders/qwen.py only):

  • Add use_text_only_fused_rope flag: true when is_text_only and use_rope_in_attn.
  • When flag is set: call make_rotary_embedding_caches() (standard 2D cos/sin for GQA), skip mRoPE config, leave use_rope_in_attn=True.
  • When flag is not set: keep existing mRoPE path unchanged (VL mode and non-fused-RoPE EPs).
  • make_position_ids_reformatting: early-return None when fused RoPE is active, since no position_ids tensor remains in the data flow.
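The branching described in the list above can be summarized as a small decision function. This is a hedged sketch, not the code in qwen.py: `RopePlan` and `plan_rope` are hypothetical names introduced here, and only `is_text_only`, `use_rope_in_attn`, and `use_text_only_fused_rope` come from the PR description.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RopePlan:
    """Which RoPE pieces the builder would emit (hypothetical structure)."""
    use_text_only_fused_rope: bool
    build_standard_caches: bool      # standard 2D cos/sin caches for GQA
    build_mrope_subgraph: bool       # manual mRoPE path
    reformat_position_ids: bool      # False models the early-return None


def plan_rope(is_text_only: bool, use_rope_in_attn: bool) -> RopePlan:
    # The flag is set only when both conditions hold, per the PR summary.
    fused = is_text_only and use_rope_in_attn
    return RopePlan(
        use_text_only_fused_rope=fused,
        build_standard_caches=fused,
        build_mrope_subgraph=not fused,
        reformat_position_ids=not fused,
    )


# Enumerate all four combinations of the two inputs.
plans = {
    (text_only, fused_ok): plan_rope(text_only, fused_ok)
    for text_only in (True, False)
    for fused_ok in (True, False)
}
```

Only the combination of a text-only model and a fused-RoPE-capable GQA takes the new path; VL mode and EPs without fused RoPE fall through to the existing mRoPE subgraph.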

Test plan

  • Verify text-only Qwen3.5-0.8B generates correct output on WebGPU with graph capture enabled
  • Verify multimodal Qwen3.5 (VL mode) is unaffected — still uses manual mRoPE path
  • Run existing Qwen integration tests: qwen3-0.6b, qwen2.5-0.5b-instruct

Text-only mRoPE collapses to standard 1D RoPE because
Qwen3_5TextRotaryEmbedding expands a 2D position_ids to 3 identical
axes and apply_interleaved_mrope returns freqs[0] unchanged. When
GQA can perform fused RoPE we therefore bypass the manual mRoPE
subgraph entirely, which removes the Shape -> Memcpy path that
blocks WebGPU graph capture.
@qjia7 qjia7 force-pushed the qwen35-text-only-fused-rope-bypass branch from dd1fcfb to 86440e3 Compare June 26, 2026 08:47
self.attention_attrs["q_norm"] = True
self.attention_attrs["k_norm"] = True
super().make_attention_init(config)
super().make_attention_init()
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)

def make_attention_init(self, config):
def make_attention_init(self):

2 participants