Skip to content

Native tensor parallelism for wxformer_next (closes #415)#418

Open
jsschreck wants to merge 11 commits into
v2/fsdp2-parallelfrom
v2/native-tp-wxformer
Open

Native tensor parallelism for wxformer_next (closes #415)#418
jsschreck wants to merge 11 commits into
v2/fsdp2-parallelfrom
v2/native-tp-wxformer

Conversation

@jsschreck

Copy link
Copy Markdown
Collaborator

Native DTensor tensor parallelism for the wxformer_next family, replacing the disabled hand-rolled TpCol/TpRow path. The acceptance gate passed on Derecho: tp=1 and tp=2 produce digit-for-digit identical loss trajectories from the same seed (4 GPUs, fsdp2 dp=2 × tp=2), and a tp=2 run resumes from a DCP checkpoint and trains a second epoch. Closes #415.

Stacked on #407 (v2/fsdp2-parallel); retarget to main after it merges.

What's here

  • wxformer_next on nn.Linear projections. The transformer blocks' 1×1 Conv2d projections are now nn.Linear, and the fused to_qkv is split into separate to_q/to_k/to_v (torchtitan pattern). ColwiseParallel sharding is per-head-correct by construction, so the q/k/v boundary scrambling that sank the legacy path is structurally impossible. Attention derives its head count from the local channel width (heads/tp per rank).
  • Checkpoint remap. remap_conv_state_dict converts old checkpoints on load: qkv split + (O,I,1,1) → (O,I), spectral-norm weight_orig/u/v handled, idempotent. Conv-vs-Linear numerical equivalence is tested to 7e-7.
  • Opt-in native TP. apply_native_tensor_parallel builds Colwise/Rowwise plans from a _tp_plan class attribute and composes with fully_shard over the dp×tp mesh. Models without the opt-in still hit the legacy NotImplementedError. Checkpointing goes through the DCP full-state APIs, so native TP + fsdp2 saves/resumes like plain fsdp2.
  • Reduced TP under spectral norm. SN-wrapped blocks are skipped as whole colwise/rowwise pairs (warn, not raise). Note apply_spectral_norm wraps every layer in the transformer stages, so with use_spectral_norm: true nothing shards and a loud warning says so; full TP needs use_spectral_norm: false for now. A DTensor-compatible spectral norm is the long-term fix.
  • Mixed-mesh gradient clipping. TP'd params live on the (dp, tp) mesh and the rest on the dp mesh; torch's clip_grad_norm_ cannot stack norms across meshes. New clip_grad_norm_/total_grad_norm in credit/parallel/collectives.py group by mesh and combine; the plain-tensor path is bit-identical to torch's (tested).
  • Init seeding fix. Per-dp-rank seeding before model construction made fsdp2's global init a mesh-layout-dependent mixture (fully_shard never broadcasts). All ranks now seed identically for construction, then re-seed per dp rank for runtime diversity. This commit is also cherry-picked onto Updates to CREDIT parallelism following pytorch's latest schema #407 (35686f8) since it affects every fsdp2 run.

Testing

  • 843 tests pass (47 new), including two real 2-process gloo DTensor tests: tp=2 matches serial to 7e-7 with gradients reaching all params, and the mixed-mesh clip reproduces and fixes the exact GPU crash.
  • GPU acceptance (tests/manual/gen2_parallelism/run_tp_parity.pbs, Derecho job 6431309): tp1 PASS, tp2 PASS, tp2 DCP resume PASS, parity tp1 == tp2 EXACT.

Copied credit/models/wxformer/wxformer_next.py verbatim from
origin/cube-sphere (content of 99d255f, the newest committed version) and
registered it as 'nextgen_wxformer'. This is the base for the native
tensor-parallel refactor (issue #415); CubeSphereWxFormer stays on the
cube-sphere branch and inherits the TP support when that branch merges.
… remap

Replace the conv transformer imported from crossformer with local
Attention/FeedForward/Transformer classes that use nn.Linear for the
qkv/attention-out and FFN up/down projections (a 1x1 conv is a Linear over
the channel dim, so the math is unchanged; verified to 1e-6). The fused
to_qkv is split into separate to_q/to_k/to_v Linears, the torchtitan
pattern, so native ColwiseParallel sharding keeps whole heads per rank.
The attention forward derives the head count from the local channel width,
which makes it correct under tensor parallelism (heads // tp per rank).

Old conv-format checkpoints load via remap_conv_state_dict: to_qkv tensors
split into q/k/v, 1x1 kernels viewed as (O, I), spectral-norm
weight_orig/u/v handled. load_model/load_model_name remap automatically
(idempotent on new-format checkpoints). Note: with spectral norm, q/k/v
are now normalized per-projection instead of jointly over the fused
matrix; weights are identical, regularization differs slightly.

crossformer.py (V1 wxformer) is untouched. Blocks declare _tp_plan for
the native parallelize_module path landing next (issue #415).
apply_native_tensor_parallel applies torch parallelize_module with
Colwise/RowwiseParallel to every block that declares a _tp_plan dict
(wxformer_next's Attention and FeedForward). Params become DTensors that
keep their FQNs and logical shapes, so checkpointing/EMA/optimizer state
go through the same DCP full-state APIs as FSDP2, and fully_shard over
the dp submesh composes on top (2D dp x tp mesh, the torchtitan
configuration). The backward all-reduce at the column-parallel input
(Megatron's 'f' operator, missing from the legacy path) comes free from
DTensor autograd.

Validated at parallelize time with clear errors: planned layers must be
nn.Linear, colwise out_features / rowwise in_features must divide by the
TP degree, heads % tp == 0 via the block's _tp_constraints, and spectral
norm is rejected (its power-iteration hook mixes plain u/v buffers with
DTensor weights; set use_spectral_norm: false for tensor > 1).

distributed_model_wrapper_gen2 routes to the native path when
supports_native_tp(model); models without a _tp_plan still hit the
legacy apply_tensor_parallel NotImplementedError. The trainer's
replicated-grad sync hook is unchanged: apply_native_tensor_parallel
stashes _tp_group, and sync_replicated_gradients now also skips DTensors
with a Shard placement on the 'tp' mesh dim (wrapper-agnostic, works on
the 1D tp mesh and the 2D dp x tp composition).
run_tp_parity.pbs trains nextgen_wxformer twice on a 4-GPU Derecho node
with the same dp layout and seed: tp=1 (nproc=2, dp=2) and tp=2 (nproc=4,
dp=2 x tp=2). check_parity.py gates on identical per-step loss
trajectories, exact first and then rtol=1e-4 to allow float32
reassociation in the rowwise all_reduce. This is the acceptance test the
legacy hand-rolled TP could never pass. Configs use
use_spectral_norm: false (DTensor TP rejects spectral norm) and a tiny
4-stage model whose head counts (2,4,8,16) all divide tp=2.
Native DTensor TP keeps param FQNs and logical shapes, so with data:
fsdp2 the DCP full-state APIs (get/set_model_state_dict) save and resume
TP runs like any FSDP2 model. The legacy warn-on-save and
raise-on-resume guards now apply only to the legacy module-swapping
path: apply_native_tensor_parallel marks the model with _tp_native, and
base_trainer / load_model_states_and_optimizer skip the guards for
native TP under fsdp2. The legacy apply_tensor_parallel error now points
users at the supported route (nextgen_wxformer).

Adds a real 2-process gloo CPU test: parallelize a wxformer_next
Transformer over a (tp=2) device mesh and require the forward to match
the serial model (max diff observed 7e-7) with every parameter receiving
a gradient — the wiring check the mocked plan tests cannot cover. The
manual run_tp_parity.pbs gains a tp2 resume phase to prove the DCP
round-trip on GPU.
…ising

apply_native_tensor_parallel no longer raises when a _tp_plan layer is
wrapped in spectral norm (the power-iteration hook is incompatible with
DTensor-sharded weights). Instead the WHOLE colwise/rowwise group is
skipped — a colwise layer's sharded output is only valid feeding its
rowwise partner — and the block stays fully replicated on every TP rank,
which is correct because TP peers see identical inputs (dp contract) and
sync_replicated_gradients pins the replicas. A summary warning reports
sharded vs skipped block counts; if nothing sharded, a loud warning says
TP had no effect (tp ranks repeat identical compute) but training still
runs. _tp_group/_tp_native are stashed either way so the trainer sync
hook and DCP checkpoint paths behave the same.

Detection covers both spectral-norm styles (hook-based weight_orig and
parametrize-based parametrizations.weight). Skipped blocks also skip
_tp_constraints, since an unsharded block has no head-divisibility
requirement.

Note: with the default use_spectral_norm: true, apply_spectral_norm
wraps every Linear in the encoder Transformer stages, i.e. ALL _tp_plan
layers — reduced TP then shards nothing. Full TP still requires
use_spectral_norm: false (README updated).

Tests: SN-on-some-blocks (skip only those, warn), SN-on-all (no-op warn,
no raise), parametrize-style detection, constraint skip, and a second
2-process gloo parity test with a partially-SN model (21 of 28 params
sharded, tp=2 matches serial, grads reach every param incl. weight_orig).
With native TP composed under fully_shard, the TP'd projections carry
grads on the 2D (dp, tp) mesh while every other param lives on the 1D dp
mesh. torch.nn.utils.clip_grad_norm_ stacks all per-grad norms with
foreach ops, and DTensor dispatch rejects operands on different meshes,
so the first optimizer step crashed at the clip (Derecho job 6430525).

Add clip_grad_norm_/total_grad_norm to credit.parallel.collectives:
grads are grouped by device mesh, each homogeneous group goes through
torch.nn.utils.get_total_norm unchanged, DTensor group norms are made
global with full_tensor() (torchtitan's approach), and the group norms
combine as (sum n_i**p)**(1/p) via vector_norm (max for p=inf). Scaling
reuses torch.nn.utils.clip_grads_with_norm_ per mesh group, since the
foreach multiply rejects mixed lists too. With a single group both steps
reduce to exactly torch's computation, so the homogeneous modes already
green in the gen2 smoke matrix (plain DDP tensors, all-DTensor FSDP2,
domain) are numerically unchanged.

trainer_gen2 now uses the helper at both clip sites. The dynamic
grad-norm path had the same mixed-mesh stack problem and also mishandled
DTensor norms (a DTensor's .norm() is itself a DTensor); DTensor grads
now go through the mesh-grouped total norm (already global, no extra
all_reduce) while plain grads keep the local sq-sum + SUM all_reduce
semantics, including the documented over-count of replicated copies.

Tests: gloo 2-process mixed-group clip (DTensor group + plain group from
a partial-spectral-norm TP model) must match torch's clip on the
materialized full grads and confirm torch itself rejects the mixed
collection; CPU parity tests pin the plain-tensor path bit-identical to
torch.nn.utils.clip_grad_norm_ (L2/inf, active/inactive clip, tensor
max_norm, empty grads).
FSDP2's fully_shard does not broadcast params from rank 0 (unlike DDP),
so seeding with seed + data_rank before load_model made every dp rank
construct a different model; the resulting global model was a mixture
of per-rank inits that varied with the mesh layout, breaking tp=1 vs
tp=2 parity from the first batch and making fsdp2 init non-reproducible
across world sizes.

train_gen2.py now seeds in two stages: the base config seed on all
ranks before model construction (identical weights everywhere, as
ring-CRPS also requires), then seed + data_rank after the distributed
wrapper so runtime RNG (dropout, stochastic preblocks, ensemble
perturbations) keeps per-dp-rank diversity while TP/domain peers still
share masks.

Adds CPU unit tests pinning the pattern: identical state_dicts when the
rank offset is applied only after construction, divergence under the
old order, and per-rank runtime RNG diversity after the re-seed.
@github-actions

Copy link
Copy Markdown

PR Review Checklist

Required

  • There is a clear description of what issue or feature the pull request is addressing.
  • Issues covered by the pull request are tagged in the description.
  • All CI checks and tests pass.
  • Tests have been updated to cover code changes.
  • Documentation has been updated to cover code changes and renders properly.
  • The dependency lists in pyproject.toml and requirements.txt have been updated.
  • The reviewer has provided both positive and constructive feedback in their review response.
  • Changes affecting GPU-related functionality have been run by the submitter on Casper and/or Derecho to verify that the code runs as expected.

Recommended

  • Updated public facing methods have full docstrings.
  • Variable names balance being unambiguous and concise.
  • Comments have been added to describe more complex operations.
  • Code minimizes redundancy with the use of loops, function/method calls, and robust data structures.
  • If dependencies are added, they should not burden the user with additional installation steps or break other parts of the code.
  • If dependencies are removed, changes to existing code should be tested and any issues addressed.
  • Tests cover expected outcomes of a function and known edge cases.
  • Type annotation is added to inputs and outputs of functions/methods.

Mirrors the #407 review change for crossformer's opt-in attributes:
wxformer_next's FeedForward and Attention now take tp_plan in __init__
(None means the standard plan), and supports_native_tp /
apply_native_tensor_parallel read instance attributes, which still
recognizes class-attribute declarations.
@djgagne djgagne self-requested a review June 23, 2026 00:01
@djgagne djgagne added the enhancement New feature or request label Jun 23, 2026
@djgagne djgagne added this to the v2026.2.0 milestone Jun 23, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants