
Add tensor parallelism support for GDN layers + fix UQFF artifact count#2054

Open
ormandj wants to merge 2 commits into EricLBuehler:master from ormandj:fix/gdn-tp-support

Conversation


@ormandj ormandj commented Apr 4, 2026

Fixes #2052

Two issues found while running Qwen3.5-27B with NCCL TP on 2x RTX 3090:

1. GDN layers not TP-aware

In models/gdn.rs, in_proj_qkvz and in_proj_ba were plain Linear layers while out_proj was a RowParallelLayer. The forward pass ran at full width and then hit the sharded out_proj, causing a matmul dimension mismatch: [5120, 3072] vs [1, 1, 6144].

Fix: shard GDN weights by key-head groups during load() when world_size > 1. Also adjusted the hybrid cache pool dimensions in qwen3_5/text.rs, qwen3_5_moe/text.rs, and qwen3_next.rs to use local (per-rank) head counts.
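The sharding idea can be sketched as follows. This is an illustrative standalone function, not the actual mistral.rs load() code: it column-shards a flattened row-major projection weight by key-head groups so each rank keeps only the output rows for its local heads, matching the row-parallel out_proj on the other side.

```rust
/// Illustrative sketch (hypothetical helper, not the mistral.rs API):
/// take the rows of `weight` ([out_features, in_features], row-major)
/// that belong to this rank's head group.
fn shard_by_head_groups(
    weight: &[f32],
    in_features: usize,
    num_heads: usize,
    head_dim: usize,
    rank: usize,
    world_size: usize,
) -> Vec<f32> {
    // Heads must divide evenly across ranks for a clean shard.
    assert_eq!(num_heads % world_size, 0);
    let local_heads = num_heads / world_size;
    // Each head owns `head_dim` consecutive output rows.
    let start_row = rank * local_heads * head_dim;
    let end_row = start_row + local_heads * head_dim;
    weight[start_row * in_features..end_row * in_features].to_vec()
}
```

With world_size == 1 the slice is the whole weight, so the unsharded path falls out of the same code.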

2. UQFF artifact count mismatch

Serialization filters tensors by isq_serde_supported() (isq.rs:718), but deserialization checked against the unfiltered count, producing the error "Number of artifacts (304) does not match the number of ISQ layers (305)". The loading code already handles missing artifacts gracefully via if let Some(artifact), so the check was relaxed to reject only the case where there are more artifacts than layers.
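The relaxed check amounts to a one-sided comparison. A minimal sketch (illustrative names, not the exact isq.rs code):

```rust
/// Sketch of the relaxed artifact-count check: fewer artifacts than ISQ
/// layers is fine (loading skips absent ones via `if let Some(artifact)`),
/// but more artifacts than layers still indicates a corrupt UQFF file.
fn check_artifact_count(num_artifacts: usize, num_isq_layers: usize) -> Result<(), String> {
    if num_artifacts > num_isq_layers {
        return Err(format!(
            "Number of artifacts ({num_artifacts}) exceeds the number of ISQ layers ({num_isq_layers})"
        ));
    }
    Ok(())
}
```

Under this check the 304-vs-305 case from the bug report loads cleanly instead of erroring.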

Tested with Qwen3.5-27B on 2x RTX 3090 with ISQ Q5K + NCCL TP=2 — inference works correctly. (UQFF + TP has a separate VRAM issue tracked in #2053.)

ormandj added 2 commits April 3, 2026 22:56
The GDN linear attention layers used in Qwen3.5/Qwen3Next were not
TP-aware: in_proj_qkvz and in_proj_ba were loaded as plain Linear
while out_proj was a RowParallelLayer, causing a matmul dimension
mismatch at inference time with NCCL TP.

Fix: shard GDN weights by key-head groups in load(), matching how
the full attention layers handle TP. Also adjust the hybrid cache
pool dimensions to use local (per-rank) head counts.
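The cache-pool change boils down to allocating per-rank rather than global head counts. A minimal sketch, assuming an even split of heads across ranks (illustrative, not the actual cache-pool code):

```rust
/// Illustrative: per-rank head count for sizing the hybrid cache pool
/// under TP, so each rank allocates cache only for the heads it owns.
fn local_head_count(total_heads: usize, world_size: usize) -> usize {
    assert_eq!(total_heads % world_size, 0, "heads must divide evenly");
    total_heads / world_size
}
```

With TP=2 a model with 16 key heads allocates cache for 8 heads per rank instead of 16, which is what keeps the cache shapes consistent with the sharded projections.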

Fixes EricLBuehler#2052

Serialization filters by isq_serde_supported() (isq.rs:718) but
deserialization checked against the unfiltered count, causing a
mismatch whenever any tensor doesn't support ISQ serialization.

The loading code already handles missing artifacts gracefully via
the if-let at line 1064, so relax the check to only reject cases
where there are more artifacts than layers.

Development

Successfully merging this pull request may close these issues.

Qwen3.5-27B: matmul dimension mismatch with NCCL tensor parallelism (2 GPUs)

1 participant