Skip to content

Commit 99192d2

Browse files
committed
feat: Compute layer norms for a model checkpoint
1 parent 8f84b2d commit 99192d2

1 file changed

Lines changed: 158 additions & 0 deletions

File tree

scripts/compute_layer_norms.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import json
5+
import os
6+
from pathlib import Path
7+
from typing import cast
8+
9+
import torch
10+
import torch.distributed as dist
11+
from pydantic import BaseModel
12+
from torch.distributed.device_mesh import DeviceMesh
13+
from torch.distributed.tensor import DTensor
14+
15+
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
16+
from modalities.checkpointing.stateful.app_state import AppState
17+
from modalities.config.config import ProcessGroupBackendType
18+
from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
19+
from modalities.main import Main
20+
from modalities.running_env.cuda_env import CudaEnv
21+
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method
22+
23+
24+
class ComponentsInstantiationModel(BaseModel):
    """Components this script requests from ``Main.build_components``."""

    # Application state the checkpoint gets loaded into; its model parts are
    # the source of the parameters whose norms are printed.
    app_state: PydanticAppStateType
    # Optional device mesh; when present it is used to look up the DP-shard
    # process group for the cross-rank reduction.
    device_mesh: PydanticDeviceMeshIFType | None = None
27+
28+
29+
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        A namespace with ``config_file_path`` and ``experiments_root_path``
        (both required) plus ``checkpoint_dir_path`` and
        ``last_checkpoint_info_path`` (``None`` when omitted). Every supplied
        value is converted to a ``Path``.
    """
    cli = argparse.ArgumentParser(description="Load a Modalities DCP checkpoint into an app state.")

    # All options are path-valued, so only the per-option keyword arguments
    # differ. Registration order is kept because it determines --help output.
    option_specs = (
        (
            "--config-file-path",
            {"required": True, "help": "Path to the YAML config file."},
        ),
        (
            "--experiments-root-path",
            {"required": True, "help": "Path passed to Main for resolver/context setup."},
        ),
        (
            "--checkpoint-dir-path",
            {"default": None, "help": "Path to a checkpoint directory containing *.distcp files."},
        ),
        (
            "--last-checkpoint-info-path",
            {
                "default": None,
                "help": "Path to last_checkpoint_info.json. Used when checkpoint-dir-path is omitted.",
            },
        ),
    )
    for flag, extra_kwargs in option_specs:
        cli.add_argument(flag, type=Path, **extra_kwargs)

    return cli.parse_args()
51+
52+
53+
def _resolve_checkpoint_dir_path(args: argparse.Namespace) -> Path:
    """Work out which checkpoint directory to load from the parsed arguments.

    Args:
        args: Namespace carrying ``checkpoint_dir_path`` and
            ``last_checkpoint_info_path``; exactly one must be set.

    Returns:
        ``args.checkpoint_dir_path`` when given, otherwise the
        ``"checkpoint_folder_path"`` entry of the JSON info file.

    Raises:
        ValueError: If both options are set, or if neither is.
    """
    explicit_dir = args.checkpoint_dir_path
    info_file = args.last_checkpoint_info_path

    if explicit_dir is not None:
        if info_file is not None:
            raise ValueError("Pass either --checkpoint-dir-path or --last-checkpoint-info-path, not both.")
        return explicit_dir

    if info_file is None:
        raise ValueError("Pass one of --checkpoint-dir-path or --last-checkpoint-info-path.")

    # Only the info file was given: read the directory it points at.
    with open(info_file, "r", encoding="utf-8") as handle:
        info = json.load(handle)
    return Path(info["checkpoint_folder_path"])
67+
68+
69+
def _get_layer_key(parameter_name: str) -> str:
    """Map a parameter name to the key of the layer it is grouped under.

    Leading wrapper prefixes (``module.``, ``_orig_mod.``,
    ``_fsdp_wrapped_module.``) are removed first. They are stripped repeatedly
    until none is left, so stacked wrappers in any order (for example
    ``_orig_mod.module.``) do not leak into the key.

    Args:
        parameter_name: Dotted parameter name as yielded by ``named_parameters``.

    Returns:
        For names containing ``h``/``layers``/``blocks`` followed by a numeric
        index: that pair, preceded by its parent token when there is one
        (e.g. ``"transformer.h.3"``). Otherwise the parent module path, or the
        stripped name itself when it has no dot.
    """
    wrapper_prefixes = ("module.", "_orig_mod.", "_fsdp_wrapped_module.")
    layer_containers = {"h", "layers", "blocks"}

    name = parameter_name
    # A single fixed-order pass misses prefixes that only become leading after
    # another one was removed, so keep going until a full pass strips nothing.
    stripped_any = True
    while stripped_any:
        stripped_any = False
        for prefix in wrapper_prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
                stripped_any = True

    tokens = name.split(".")
    for i, (token, following) in enumerate(zip(tokens, tokens[1:])):
        if token in layer_containers and following.isdigit():
            # Include the parent token (if any) so that e.g. encoder/decoder
            # stacks with the same indices stay distinguishable.
            start = i - 1 if i > 0 else i
            return ".".join(tokens[start : i + 2])

    # Fallback: group by parent module path if no canonical layer index token exists.
    return ".".join(tokens[:-1]) if len(tokens) > 1 else name
85+
86+
87+
def _get_dp_shard_group(device_mesh: DeviceMesh | None):
88+
if device_mesh is None:
89+
return None
90+
try:
91+
return get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.DP_SHARD).get_group()
92+
except Exception:
93+
# Fallback to the default process group if a dedicated DP-shard group is unavailable.
94+
return None
95+
96+
97+
def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> None:
    """Print the L2 norm of the trainable parameters of every layer group.

    Squared sums are accumulated per key returned by ``_get_layer_key``,
    summed across ``dp_shard_group`` with ``dist.all_reduce``, and finally
    printed as square roots by rank 0 in sorted key order.

    Args:
        app_state: Provides ``model_parts``, whose ``named_parameters()`` are
            scanned. Parameters with ``requires_grad == False`` are skipped.
        dp_shard_group: Passed as ``group`` to ``dist.all_reduce``.
    """
    squared_sums: dict[str, torch.Tensor] = {}
    tag_with_part = len(app_state.model_parts) > 1

    for part_idx, part in enumerate(app_state.model_parts):
        for param_name, param in part.named_parameters():
            if param.requires_grad:
                qualified_name = f"model_part_{part_idx}.{param_name}" if tag_with_part else param_name
                key = _get_layer_key(qualified_name)

                # FSDP2 parameters can be DTensors; reduce over the local shard so the
                # c10d all_reduce below works on plain tensors.
                # NOTE(review): non-DTensor parameters go through the same SUM reduction.
                # If such a parameter is replicated rather than sharded, it would be
                # counted once per rank. Confirm against the wrapping in use.
                shard = param.to_local() if isinstance(param, DTensor) else param
                contribution = shard.detach().float().pow(2).sum()

                running = squared_sums[key] if key in squared_sums else torch.zeros_like(contribution)
                squared_sums[key] = running + contribution

    # all_reduce works in place, so the dict entries hold the group-wide sums afterwards.
    for total in squared_sums.values():
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=dp_shard_group)

    if dist.get_rank() != 0:
        return

    print("Per-layer parameter L2 norms (global across DP-shards):")
    for key in sorted(squared_sums):
        norm = torch.sqrt(squared_sums[key]).item()
        print(f"{key}: {norm:.6f}")
123+
124+
125+
def main() -> None:
    """Load a checkpoint into the configured app state and print per-layer norms.

    Steps: parse the CLI, resolve the checkpoint directory, then inside a
    ``CudaEnv`` with the NCCL backend build the components from the config,
    load the checkpoint, print the layer norms and a summary line on rank 0.
    """
    cli_args = _parse_args()
    ckpt_dir = _resolve_checkpoint_dir_path(cli_args)

    with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
        global_rank = dist.get_rank()

        runner = Main(
            config_path=cli_args.config_file_path,
            experiments_root_path=cli_args.experiments_root_path,
        )
        built = runner.build_components(components_model_type=ComponentsInstantiationModel)
        components = cast(ComponentsInstantiationModel, built)

        app_state = cast(AppState, components.app_state)
        device_mesh = cast(DeviceMesh | None, getattr(components, "device_mesh", None))

        # The trailing underscore follows the in-place naming convention:
        # presumably app_state is populated in place, since nothing is returned here.
        DCPCheckpointLoading(global_rank=global_rank).load_checkpoint_(
            app_state=app_state,
            checkpoint_dir_path=ckpt_dir,
        )

        _compute_and_print_layer_norms(app_state, _get_dp_shard_group(device_mesh))

        if global_rank == 0:
            print(
                f"Loaded checkpoint from {ckpt_dir} on world size {dist.get_world_size()} "
                f"(pid={os.getpid()})."
            )
155+
156+
157+
# Run the script only when executed directly, not when imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)