11import warnings
22
3- import torch
43import torch .distributed as dist
54import torch .nn as nn
6- from torch .distributed ._composable .fsdp import MixedPrecisionPolicy
75from torch .distributed .device_mesh import DeviceMesh
6+ from torch .distributed .tensor import DTensor
87
98from modalities .models .parallelism .expert_parallelism import ExpertParallel
9+ from modalities .running_env .env_utils import FSDP2MixedPrecisionSettings
1010from modalities .running_env .fsdp .device_mesh import ParallelismDegrees , get_mesh_for_parallelism_method
1111from modalities .util import get_module_class_from_name
1212
1313
14- def _validate_moe_block_for_ep (module ) -> None :
15- if not hasattr (module , "experts" ):
16- raise ValueError (f"Module { type (module ).__name__ } has no 'experts' attribute" )
17-
18- experts = module .experts
19- required_attrs = ["w1" , "w2" ]
20- missing = [attr for attr in required_attrs if not hasattr (experts , attr )]
21- if missing :
22- raise ValueError (
23- f"Module { type (module ).__name__ } .experts is not grouped-experts compatible. Missing: { missing } "
24- )
25-
26- if experts .w1 .ndim != 3 or experts .w2 .ndim != 3 :
27- raise ValueError (
28- f"Expected grouped expert parameters with ndim=3. Got w1.ndim={ experts .w1 .ndim } , "
29- f"w2.ndim={ experts .w2 .ndim } "
30- )
31-
32-
33- def _get_ep_target_module (module ):
34- if hasattr (module , "experts" ):
35- return module
36-
37- ffn = getattr (module , "ffn" , None )
38- if ffn is not None and hasattr (ffn , "experts" ):
39- return ffn
40-
41- return None
42-
43-
44- def _attach_ep_metadata (module , ep_mesh ) -> None :
45- setattr (module , "_ep_mesh" , ep_mesh )
46- setattr (module , "_ep_group" , ep_mesh .get_group ())
47- setattr (module , "_ep_size" , ep_mesh .size ())
48- setattr (module , "_ep_rank" , ep_mesh .get_local_rank ())
49-
50-
5114def get_ep_wrapped_model (
5215 model ,
5316 block_names : list [str ],
5417 device_mesh : DeviceMesh ,
55- mp_param_dtype = torch .bfloat16 ,
56- mp_reduce_dtype = torch .bfloat16 ,
18+ mixed_precision_settings : FSDP2MixedPrecisionSettings ,
5719) -> nn .Module :
5820 block_types = []
5921 missing_block_names = []
@@ -76,34 +38,59 @@ def get_ep_wrapped_model(
7638 raise ValueError (f"None of the requested MoE block names were found: { block_names } " )
7739
7840 ep_mesh = get_mesh_for_parallelism_method (device_mesh , ParallelismDegrees .EP )
79- MixedPrecisionPolicy ( param_dtype = mp_param_dtype , reduce_dtype = mp_reduce_dtype )
41+ target_dtype = mixed_precision_settings . param_dtype . value
8042
8143 wrapped_blocks = 0
8244 for module in model .modules ():
8345 if isinstance (module , block_types ):
84- ep_target_module = _get_ep_target_module (module )
85- if ep_target_module is None :
46+ if hasattr (module , "experts" ):
47+ ep_target = module
48+ elif (ffn := getattr (module , "ffn" , None )) is not None and hasattr (ffn , "experts" ):
49+ ep_target = ffn
50+ else :
8651 raise ValueError (
8752 f"Module { type (module ).__name__ } has no EP-compatible experts location. "
8853 "Expected `experts` or `ffn.experts`."
8954 )
9055
91- if getattr (ep_target_module , "_ep_enabled" , False ):
56+ if getattr (ep_target , "_ep_enabled" , False ):
9257 continue
9358
94- _validate_moe_block_for_ep (ep_target_module )
95- _attach_ep_metadata (ep_target_module , ep_mesh )
59+ experts = ep_target .experts
60+ missing = [a for a in ("w1" , "w2" ) if not hasattr (experts , a )]
61+ if missing :
62+ raise ValueError (
63+ f"Module { type (ep_target ).__name__ } .experts is not grouped-experts compatible. Missing: { missing } "
64+ )
65+ if experts .w1 .ndim != 3 or experts .w2 .ndim != 3 :
66+ raise ValueError (
67+ f"Expected grouped expert parameters with ndim=3. Got w1.ndim={ experts .w1 .ndim } , "
68+ f"w2.ndim={ experts .w2 .ndim } "
69+ )
70+
71+ ep_target ._ep_mesh = ep_mesh
72+ ep_target ._ep_group = ep_mesh .get_group ()
73+ ep_target ._ep_size = ep_mesh .size ()
74+ ep_target ._ep_rank = ep_mesh .get_local_rank ()
75+
76+ ep_target .experts = ExpertParallel ()._apply (ep_target .experts , ep_mesh )
77+ ep_target .experts ._ep_enabled = True
9678
97- ep_target_module .experts = ExpertParallel ()._apply (ep_target_module .experts , ep_mesh )
98- setattr (ep_target_module .experts , "_ep_enabled" , True )
79+ for pname , p in list (ep_target .experts ._parameters .items ()):
80+ if isinstance (p , DTensor ) and p .dtype != target_dtype :
81+ local = p .to_local ().to (target_dtype )
82+ ep_target .experts ._parameters [pname ] = nn .Parameter (
83+ DTensor .from_local (local , p .device_mesh , p .placements , run_check = False ),
84+ requires_grad = p .requires_grad ,
85+ )
9986
10087 wrapped_blocks += 1
10188
10289 if wrapped_blocks == 0 :
10390 raise ValueError (f"No blocks matched the requested types: { [t .__name__ for t in block_types ]} " )
10491
105- setattr ( model , " _ep_wrapped" , True )
106- setattr ( model , " _ep_mesh" , ep_mesh )
107- setattr ( model , " _ep_num_wrapped_blocks" , wrapped_blocks )
92+ model . _ep_wrapped = True
93+ model . _ep_mesh = ep_mesh
94+ model . _ep_num_wrapped_blocks = wrapped_blocks
10895
10996 return model
0 commit comments