Add tensor parallelism support for GDN layers + fix UQFF artifact count #2054
Open
ormandj wants to merge 2 commits into EricLBuehler:master from
Conversation
The GDN linear attention layers used in Qwen3.5/Qwen3Next were not TP-aware: in_proj_qkvz and in_proj_ba were loaded as plain Linear while out_proj was a RowParallelLayer, causing a matmul dimension mismatch at inference time with NCCL TP. Fix: shard GDN weights by key-head groups in load(), matching how the full attention layers handle TP. Also adjust the hybrid cache pool dimensions to use local (per-rank) head counts. Fixes EricLBuehler#2052
Serialization filters by isq_serde_supported() (isq.rs:718) but deserialization checked against the unfiltered count, causing a mismatch when any tensor doesn't support ISQ serialization. The loading code already handles missing artifacts gracefully via the if-let at line 1064, so relax the check to only reject cases where there are more artifacts than layers.
Fixes #2052
Two issues found while running Qwen3.5-27B with NCCL TP on 2x RTX 3090:
1. GDN layers not TP-aware

In `models/gdn.rs`, `in_proj_qkvz` and `in_proj_ba` were plain `Linear` while `out_proj` was a `RowParallelLayer`. The forward pass ran at full width and then hit the sharded `out_proj`, causing `mismatch on matmul dim [5120, 3072] [1, 1, 6144]`.

Fix: shard GDN weights by key-head groups during `load()` when `world_size > 1`, matching how the full attention layers handle TP (a minimal sketch is below). Also adjusted the hybrid cache pool dimensions in `qwen3_5/text.rs`, `qwen3_5_moe/text.rs`, and `qwen3_next.rs` to use local (per-rank) head counts.
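A minimal sketch of the sharding idea, with hypothetical names (`shard_by_head_groups`, `group_width`); the actual change lives in `models/gdn.rs` and uses the crate's own loading types:

```rust
/// Hypothetical sketch: take this rank's slice of a packed GDN projection
/// weight. The rows of a projection like `in_proj_qkvz` are assumed to be laid
/// out as `num_k_head_groups` contiguous blocks, each packing the q/k/v/z rows
/// for one key-head group, so sharding by whole groups keeps the per-rank
/// layout that the sharded `out_proj` expects.
fn shard_by_head_groups(
    weight: &[f32],          // row-major [out_features, in_features]
    in_features: usize,
    num_k_head_groups: usize,
    group_width: usize,      // rows per key-head group (assumed name)
    rank: usize,
    world_size: usize,
) -> Vec<f32> {
    assert_eq!(
        num_k_head_groups % world_size,
        0,
        "key-head groups must divide evenly across ranks"
    );
    let groups_per_rank = num_k_head_groups / world_size;
    let rows_per_rank = groups_per_rank * group_width;
    // Copy this rank's contiguous slice of head-group rows.
    let start = rank * rows_per_rank * in_features;
    let end = start + rows_per_rank * in_features;
    weight[start..end].to_vec()
}
```

Under the same assumption, the hybrid cache pool then sizes its per-rank buffers with the local head count (`num_k_head_groups / world_size`) rather than the global one, which is what the `qwen3_5`/`qwen3_next` changes adjust.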
2. UQFF artifact count mismatch

Serialization filters tensors by `isq_serde_supported()` (isq.rs:718) but deserialization checked against the unfiltered count, causing `Number of artifacts (304) does not match the number of ISQ layers (305)`. The loading code already handles missing artifacts gracefully via `if let Some(artifact)`, so the check was relaxed to only reject cases where there are more artifacts than layers (sketched below).

Tested with Qwen3.5-27B on 2x RTX 3090 with ISQ Q5K + NCCL TP=2; inference works correctly. (UQFF + TP has a separate VRAM issue tracked in #2053.)
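For reference, a minimal sketch of the relaxed count check; the function name and error type here are assumptions, not the crate's actual API:

```rust
// Hypothetical sketch of the relaxed deserialization check (names assumed).
// Fewer artifacts than ISQ layers is fine: tensors that fail
// `isq_serde_supported()` were filtered out at save time, and the load path
// skips them via `if let Some(artifact)`. More artifacts than layers still
// indicates a corrupt or mismatched UQFF file.
fn check_artifact_count(num_artifacts: usize, num_isq_layers: usize) -> Result<(), String> {
    if num_artifacts > num_isq_layers {
        return Err(format!(
            "Number of artifacts ({num_artifacts}) exceeds the number of ISQ layers ({num_isq_layers})"
        ));
    }
    Ok(())
}
```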