Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions credit/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,23 @@ def distributed_model_wrapper_gen2(conf: dict, model, device):

# ── 2. Tensor parallelism ─────────────────────────────────────────────
if tp_size > 1:
from credit.parallel import apply_tensor_parallel
from credit.parallel import apply_native_tensor_parallel, apply_tensor_parallel, supports_native_tp

tp_mesh = submeshes.get("tp")
if tp_mesh is None:
raise ValueError("TP mesh not found — check tensor > 1 and world_size")
model = apply_tensor_parallel(model, tp_mesh)
logging.info(f"[V2] Tensor parallelism: degree={tp_size}")
if supports_native_tp(model):
# TP must convert params to DTensors on the target device before
# fully_shard; move first so distribute_tensor doesn't shard CPU
# params onto a cuda mesh.
model.to(device)
model = apply_native_tensor_parallel(model, tp_mesh)
logging.info(f"[V2] Native tensor parallelism (DTensor): degree={tp_size}")
else:
# Legacy hand-rolled TP — disabled, raises NotImplementedError
# citing issue #415 (models without a _tp_plan are unsupported).
model = apply_tensor_parallel(model, tp_mesh)
logging.info(f"[V2] Tensor parallelism: degree={tp_size}")

# Ensure all parameters/buffers are on the target device after domain/TP
# transforms (which may create new modules via wrapping). Must happen before
Expand Down
5 changes: 5 additions & 0 deletions credit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
"DownscalingSegmentationModel",
"Loading downscaling U-net",
),
"nextgen_wxformer": (
"credit.models.wxformer.wxformer_next",
"NextGenWXFormer",
"Loading NextGen WXFormer (CrossFormer U-Net + spectral GNN bottleneck + column attention) ...",
),
}

# Backward-compatible name -> (module_path, class_name) for direct attribute access
Expand Down
Loading
Loading