Skip to content

Commit f638aec

Browse files
committed
fix: mixed-precision bug in EP layers
1 parent e0217fb commit f638aec

3 files changed

Lines changed: 43 additions & 52 deletions

File tree

config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ ep_model:
238238
device_mesh:
239239
instance_key: device_mesh
240240
pass_type: BY_REFERENCE
241+
mixed_precision_settings:
242+
param_dtype: BF_16
243+
reduce_dtype: BF_16
241244
block_names: [TransformerBlock]
242245

243246
ac_model:

src/modalities/config/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ class EPWrappedModelConfig(BaseModel):
338338
model: PydanticPytorchModuleOrListType
339339
block_names: list[str]
340340
device_mesh: PydanticDeviceMeshIFType
341+
mixed_precision_settings: FSDP2MixedPrecisionSettings
341342

342343

343344
class DebuggingEnrichedModelConfig(BaseModel):
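(file path not captured in this extract — third changed file, the Python module defining `get_ep_wrapped_model`)

Lines changed: 39 additions & 52 deletions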
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,21 @@
11
import warnings
22

3-
import torch
43
import torch.distributed as dist
54
import torch.nn as nn
6-
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
75
from torch.distributed.device_mesh import DeviceMesh
6+
from torch.distributed.tensor import DTensor
87

98
from modalities.models.parallelism.expert_parallelism import ExpertParallel
9+
from modalities.running_env.env_utils import FSDP2MixedPrecisionSettings
1010
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method
1111
from modalities.util import get_module_class_from_name
1212

1313

14-
def _validate_moe_block_for_ep(module) -> None:
15-
if not hasattr(module, "experts"):
16-
raise ValueError(f"Module {type(module).__name__} has no 'experts' attribute")
17-
18-
experts = module.experts
19-
required_attrs = ["w1", "w2"]
20-
missing = [attr for attr in required_attrs if not hasattr(experts, attr)]
21-
if missing:
22-
raise ValueError(
23-
f"Module {type(module).__name__}.experts is not grouped-experts compatible. Missing: {missing}"
24-
)
25-
26-
if experts.w1.ndim != 3 or experts.w2.ndim != 3:
27-
raise ValueError(
28-
f"Expected grouped expert parameters with ndim=3. Got w1.ndim={experts.w1.ndim}, "
29-
f"w2.ndim={experts.w2.ndim}"
30-
)
31-
32-
33-
def _get_ep_target_module(module):
34-
if hasattr(module, "experts"):
35-
return module
36-
37-
ffn = getattr(module, "ffn", None)
38-
if ffn is not None and hasattr(ffn, "experts"):
39-
return ffn
40-
41-
return None
42-
43-
44-
def _attach_ep_metadata(module, ep_mesh) -> None:
45-
setattr(module, "_ep_mesh", ep_mesh)
46-
setattr(module, "_ep_group", ep_mesh.get_group())
47-
setattr(module, "_ep_size", ep_mesh.size())
48-
setattr(module, "_ep_rank", ep_mesh.get_local_rank())
49-
50-
5114
def get_ep_wrapped_model(
5215
model,
5316
block_names: list[str],
5417
device_mesh: DeviceMesh,
55-
mp_param_dtype=torch.bfloat16,
56-
mp_reduce_dtype=torch.bfloat16,
18+
mixed_precision_settings: FSDP2MixedPrecisionSettings,
5719
) -> nn.Module:
5820
block_types = []
5921
missing_block_names = []
@@ -76,34 +38,59 @@ def get_ep_wrapped_model(
7638
raise ValueError(f"None of the requested MoE block names were found: {block_names}")
7739

7840
ep_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.EP)
79-
MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype)
41+
target_dtype = mixed_precision_settings.param_dtype.value
8042

8143
wrapped_blocks = 0
8244
for module in model.modules():
8345
if isinstance(module, block_types):
84-
ep_target_module = _get_ep_target_module(module)
85-
if ep_target_module is None:
46+
if hasattr(module, "experts"):
47+
ep_target = module
48+
elif (ffn := getattr(module, "ffn", None)) is not None and hasattr(ffn, "experts"):
49+
ep_target = ffn
50+
else:
8651
raise ValueError(
8752
f"Module {type(module).__name__} has no EP-compatible experts location. "
8853
"Expected `experts` or `ffn.experts`."
8954
)
9055

91-
if getattr(ep_target_module, "_ep_enabled", False):
56+
if getattr(ep_target, "_ep_enabled", False):
9257
continue
9358

94-
_validate_moe_block_for_ep(ep_target_module)
95-
_attach_ep_metadata(ep_target_module, ep_mesh)
59+
experts = ep_target.experts
60+
missing = [a for a in ("w1", "w2") if not hasattr(experts, a)]
61+
if missing:
62+
raise ValueError(
63+
f"Module {type(ep_target).__name__}.experts is not grouped-experts compatible. Missing: {missing}"
64+
)
65+
if experts.w1.ndim != 3 or experts.w2.ndim != 3:
66+
raise ValueError(
67+
f"Expected grouped expert parameters with ndim=3. Got w1.ndim={experts.w1.ndim}, "
68+
f"w2.ndim={experts.w2.ndim}"
69+
)
70+
71+
ep_target._ep_mesh = ep_mesh
72+
ep_target._ep_group = ep_mesh.get_group()
73+
ep_target._ep_size = ep_mesh.size()
74+
ep_target._ep_rank = ep_mesh.get_local_rank()
75+
76+
ep_target.experts = ExpertParallel()._apply(ep_target.experts, ep_mesh)
77+
ep_target.experts._ep_enabled = True
9678

97-
ep_target_module.experts = ExpertParallel()._apply(ep_target_module.experts, ep_mesh)
98-
setattr(ep_target_module.experts, "_ep_enabled", True)
79+
for pname, p in list(ep_target.experts._parameters.items()):
80+
if isinstance(p, DTensor) and p.dtype != target_dtype:
81+
local = p.to_local().to(target_dtype)
82+
ep_target.experts._parameters[pname] = nn.Parameter(
83+
DTensor.from_local(local, p.device_mesh, p.placements, run_check=False),
84+
requires_grad=p.requires_grad,
85+
)
9986

10087
wrapped_blocks += 1
10188

10289
if wrapped_blocks == 0:
10390
raise ValueError(f"No blocks matched the requested types: {[t.__name__ for t in block_types]}")
10491

105-
setattr(model, "_ep_wrapped", True)
106-
setattr(model, "_ep_mesh", ep_mesh)
107-
setattr(model, "_ep_num_wrapped_blocks", wrapped_blocks)
92+
model._ep_wrapped = True
93+
model._ep_mesh = ep_mesh
94+
model._ep_num_wrapped_blocks = wrapped_blocks
10895

10996
return model

0 commit comments

Comments
 (0)