Native tensor parallelism for wxformer_next (closes #415)#418
Open
jsschreck wants to merge 11 commits into
Open
Native tensor parallelism for wxformer_next (closes #415)#418jsschreck wants to merge 11 commits into
jsschreck wants to merge 11 commits into
Conversation
Copied credit/models/wxformer/wxformer_next.py verbatim from origin/cube-sphere (content of 99d255f, the newest committed version) and registered it as 'nextgen_wxformer'. This is the base for the native tensor-parallel refactor (issue #415); CubeSphereWxFormer stays on the cube-sphere branch and inherits the TP support when that branch merges.
… remap Replace the conv transformer imported from crossformer with local Attention/FeedForward/Transformer classes that use nn.Linear for the qkv/attention-out and FFN up/down projections (a 1x1 conv is a Linear over the channel dim, so the math is unchanged; verified to 1e-6). The fused to_qkv is split into separate to_q/to_k/to_v Linears, the torchtitan pattern, so native ColwiseParallel sharding keeps whole heads per rank. The attention forward derives the head count from the local channel width, which makes it correct under tensor parallelism (heads // tp per rank). Old conv-format checkpoints load via remap_conv_state_dict: to_qkv tensors split into q/k/v, 1x1 kernels viewed as (O, I), spectral-norm weight_orig/u/v handled. load_model/load_model_name remap automatically (idempotent on new-format checkpoints). Note: with spectral norm, q/k/v are now normalized per-projection instead of jointly over the fused matrix; weights are identical, regularization differs slightly. crossformer.py (V1 wxformer) is untouched. Blocks declare _tp_plan for the native parallelize_module path landing next (issue #415).
apply_native_tensor_parallel applies torch parallelize_module with Colwise/RowwiseParallel to every block that declares a _tp_plan dict (wxformer_next's Attention and FeedForward). Params become DTensors that keep their FQNs and logical shapes, so checkpointing/EMA/optimizer state go through the same DCP full-state APIs as FSDP2, and fully_shard over the dp submesh composes on top (2D dp x tp mesh, the torchtitan configuration). The backward all-reduce at the column-parallel input (Megatron's 'f' operator, missing from the legacy path) comes free from DTensor autograd. Validated at parallelize time with clear errors: planned layers must be nn.Linear, colwise out_features / rowwise in_features must divide by the TP degree, heads % tp == 0 via the block's _tp_constraints, and spectral norm is rejected (its power-iteration hook mixes plain u/v buffers with DTensor weights; set use_spectral_norm: false for tensor > 1). distributed_model_wrapper_gen2 routes to the native path when supports_native_tp(model); models without a _tp_plan still hit the legacy apply_tensor_parallel NotImplementedError. The trainer's replicated-grad sync hook is unchanged: apply_native_tensor_parallel stashes _tp_group, and sync_replicated_gradients now also skips DTensors with a Shard placement on the 'tp' mesh dim (wrapper-agnostic, works on the 1D tp mesh and the 2D dp x tp composition).
run_tp_parity.pbs trains nextgen_wxformer twice on a 4-GPU Derecho node with the same dp layout and seed: tp=1 (nproc=2, dp=2) and tp=2 (nproc=4, dp=2 x tp=2). check_parity.py gates on identical per-step loss trajectories, exact first and then rtol=1e-4 to allow float32 reassociation in the rowwise all_reduce. This is the acceptance test the legacy hand-rolled TP could never pass. Configs use use_spectral_norm: false (DTensor TP rejects spectral norm) and a tiny 4-stage model whose head counts (2,4,8,16) all divide tp=2.
Native DTensor TP keeps param FQNs and logical shapes, so with data: fsdp2 the DCP full-state APIs (get/set_model_state_dict) save and resume TP runs like any FSDP2 model. The legacy warn-on-save and raise-on-resume guards now apply only to the legacy module-swapping path: apply_native_tensor_parallel marks the model with _tp_native, and base_trainer / load_model_states_and_optimizer skip the guards for native TP under fsdp2. The legacy apply_tensor_parallel error now points users at the supported route (nextgen_wxformer). Adds a real 2-process gloo CPU test: parallelize a wxformer_next Transformer over a (tp=2) device mesh and require the forward to match the serial model (max diff observed 7e-7) with every parameter receiving a gradient — the wiring check the mocked plan tests cannot cover. The manual run_tp_parity.pbs gains a tp2 resume phase to prove the DCP round-trip on GPU.
…ising apply_native_tensor_parallel no longer raises when a _tp_plan layer is wrapped in spectral norm (the power-iteration hook is incompatible with DTensor-sharded weights). Instead the WHOLE colwise/rowwise group is skipped — a colwise layer's sharded output is only valid feeding its rowwise partner — and the block stays fully replicated on every TP rank, which is correct because TP peers see identical inputs (dp contract) and sync_replicated_gradients pins the replicas. A summary warning reports sharded vs skipped block counts; if nothing sharded, a loud warning says TP had no effect (tp ranks repeat identical compute) but training still runs. _tp_group/_tp_native are stashed either way so the trainer sync hook and DCP checkpoint paths behave the same. Detection covers both spectral-norm styles (hook-based weight_orig and parametrize-based parametrizations.weight). Skipped blocks also skip _tp_constraints, since an unsharded block has no head-divisibility requirement. Note: with the default use_spectral_norm: true, apply_spectral_norm wraps every Linear in the encoder Transformer stages, i.e. ALL _tp_plan layers — reduced TP then shards nothing. Full TP still requires use_spectral_norm: false (README updated). Tests: SN-on-some-blocks (skip only those, warn), SN-on-all (no-op warn, no raise), parametrize-style detection, constraint skip, and a second 2-process gloo parity test with a partially-SN model (21 of 28 params sharded, tp=2 matches serial, grads reach every param incl. weight_orig).
With native TP composed under fully_shard, the TP'd projections carry grads on the 2D (dp, tp) mesh while every other param lives on the 1D dp mesh. torch.nn.utils.clip_grad_norm_ stacks all per-grad norms with foreach ops, and DTensor dispatch rejects operands on different meshes, so the first optimizer step crashed at the clip (Derecho job 6430525). Add clip_grad_norm_/total_grad_norm to credit.parallel.collectives: grads are grouped by device mesh, each homogeneous group goes through torch.nn.utils.get_total_norm unchanged, DTensor group norms are made global with full_tensor() (torchtitan's approach), and the group norms combine as (sum n_i**p)**(1/p) via vector_norm (max for p=inf). Scaling reuses torch.nn.utils.clip_grads_with_norm_ per mesh group, since the foreach multiply rejects mixed lists too. With a single group both steps reduce to exactly torch's computation, so the homogeneous modes already green in the gen2 smoke matrix (plain DDP tensors, all-DTensor FSDP2, domain) are numerically unchanged. trainer_gen2 now uses the helper at both clip sites. The dynamic grad-norm path had the same mixed-mesh stack problem and also mishandled DTensor norms (a DTensor's .norm() is itself a DTensor); DTensor grads now go through the mesh-grouped total norm (already global, no extra all_reduce) while plain grads keep the local sq-sum + SUM all_reduce semantics, including the documented over-count of replicated copies. Tests: gloo 2-process mixed-group clip (DTensor group + plain group from a partial-spectral-norm TP model) must match torch's clip on the materialized full grads and confirm torch itself rejects the mixed collection; CPU parity tests pin the plain-tensor path bit-identical to torch.nn.utils.clip_grad_norm_ (L2/inf, active/inactive clip, tensor max_norm, empty grads).
FSDP2's fully_shard does not broadcast params from rank 0 (unlike DDP), so seeding with seed + data_rank before load_model made every dp rank construct a different model; the resulting global model was a mixture of per-rank inits that varied with the mesh layout, breaking tp=1 vs tp=2 parity from the first batch and making fsdp2 init non-reproducible across world sizes. train_gen2.py now seeds in two stages: the base config seed on all ranks before model construction (identical weights everywhere, as ring-CRPS also requires), then seed + data_rank after the distributed wrapper so runtime RNG (dropout, stochastic preblocks, ensemble perturbations) keeps per-dp-rank diversity while TP/domain peers still share masks. Adds CPU unit tests pinning the pattern: identical state_dicts when the rank offset is applied only after construction, divergence under the old order, and per-rank runtime RNG diversity after the re-seed.
PR Review ChecklistRequired
Recommended
|
Mirrors the #407 review change for crossformer's opt-in attributes: wxformer_next's FeedForward and Attention now take tp_plan in __init__ (None means the standard plan), and supports_native_tp / apply_native_tensor_parallel read instance attributes, which still recognizes class-attribute declarations.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Native DTensor tensor parallelism for the wxformer_next family, replacing the disabled hand-rolled TpCol/TpRow path. The acceptance gate passed on Derecho: tp=1 and tp=2 produce digit-for-digit identical loss trajectories from the same seed (4 GPUs, fsdp2 dp=2 × tp=2), and a tp=2 run resumes from a DCP checkpoint and trains a second epoch. Closes #415.
Stacked on #407 (
v2/fsdp2-parallel); retarget to main after it merges.What's here
to_qkvis split into separateto_q/to_k/to_v(torchtitan pattern). ColwiseParallel sharding is per-head-correct by construction, so the q/k/v boundary scrambling that sank the legacy path is structurally impossible. Attention derives its head count from the local channel width (heads/tp per rank).remap_conv_state_dictconverts old checkpoints on load: qkv split + (O,I,1,1) → (O,I), spectral-norm weight_orig/u/v handled, idempotent. Conv-vs-Linear numerical equivalence is tested to 7e-7.apply_native_tensor_parallelbuilds Colwise/Rowwise plans from a_tp_planclass attribute and composes withfully_shardover the dp×tp mesh. Models without the opt-in still hit the legacy NotImplementedError. Checkpointing goes through the DCP full-state APIs, so native TP + fsdp2 saves/resumes like plain fsdp2.apply_spectral_normwraps every layer in the transformer stages, so withuse_spectral_norm: truenothing shards and a loud warning says so; full TP needsuse_spectral_norm: falsefor now. A DTensor-compatible spectral norm is the long-term fix.clip_grad_norm_cannot stack norms across meshes. Newclip_grad_norm_/total_grad_normincredit/parallel/collectives.pygroup by mesh and combine; the plain-tensor path is bit-identical to torch's (tested).Testing
tests/manual/gen2_parallelism/run_tp_parity.pbs, Derecho job 6431309): tp1 PASS, tp2 PASS, tp2 DCP resume PASS, parity tp1 == tp2 EXACT.