Skip to content

Commit e0217fb

Browse files
committed
feat: add dedicated EP dimension to device mesh
1 parent 3eb373e commit e0217fb

4 files changed

Lines changed: 18 additions & 29 deletions

File tree

config_files/training/config_lorem_ipsum_long_moe_ep_fsdp2.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ device_mesh:
184184
device_type: cuda
185185
data_parallel_replicate_degree: 1
186186
data_parallel_shard_degree: -1
187-
tensor_parallel_degree: 4
187+
expert_parallel_degree: 4
188188
world_size: ${settings.cuda_env.world_size}
189189

190190
dp_degree:
@@ -238,7 +238,6 @@ ep_model:
238238
device_mesh:
239239
instance_key: device_mesh
240240
pass_type: BY_REFERENCE
241-
ep_mesh_dim_name: tp
242241
block_names: [TransformerBlock]
243242

244243
ac_model:

src/modalities/config/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ class EPWrappedModelConfig(BaseModel):
338338
model: PydanticPytorchModuleOrListType
339339
block_names: list[str]
340340
device_mesh: PydanticDeviceMeshIFType
341-
ep_mesh_dim_name: str | None = None
342341

343342

344343
class DebuggingEnrichedModelConfig(BaseModel):

src/modalities/models/moe/model_factory.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,12 @@
55
import torch.nn as nn
66
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
77
from torch.distributed.device_mesh import DeviceMesh
8+
89
from modalities.models.parallelism.expert_parallelism import ExpertParallel
10+
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method
911
from modalities.util import get_module_class_from_name
1012

1113

12-
# TODO refactor these funtions into a utils
13-
def _resolve_ep_mesh(device_mesh: DeviceMesh, ep_mesh_dim_name: str | None) -> DeviceMesh:
14-
mesh_dim_names = tuple(device_mesh.mesh_dim_names or ())
15-
16-
if ep_mesh_dim_name is not None:
17-
if ep_mesh_dim_name not in mesh_dim_names:
18-
raise ValueError(f"ep_mesh_dim_name='{ep_mesh_dim_name}' not in mesh_dim_names={mesh_dim_names}")
19-
return device_mesh[ep_mesh_dim_name]
20-
21-
if len(mesh_dim_names) <= 1:
22-
return device_mesh
23-
24-
raise ValueError(
25-
"DeviceMesh has multiple dimensions. Pass ep_mesh_dim_name explicitly. "
26-
f"Available dimensions: {mesh_dim_names}"
27-
)
28-
29-
3014
def _validate_moe_block_for_ep(module) -> None:
3115
if not hasattr(module, "experts"):
3216
raise ValueError(f"Module {type(module).__name__} has no 'experts' attribute")
@@ -64,16 +48,10 @@ def _attach_ep_metadata(module, ep_mesh) -> None:
6448
setattr(module, "_ep_rank", ep_mesh.get_local_rank())
6549

6650

67-
def _apply_ep(module, ep_mesh) -> None:
68-
module.experts = ExpertParallel()._apply(module.experts, ep_mesh)
69-
setattr(module.experts, "_ep_enabled", True)
70-
71-
7251
def get_ep_wrapped_model(
7352
model,
7453
block_names: list[str],
7554
device_mesh: DeviceMesh,
76-
ep_mesh_dim_name: str | None = None,
7755
mp_param_dtype=torch.bfloat16,
7856
mp_reduce_dtype=torch.bfloat16,
7957
) -> nn.Module:
@@ -97,7 +75,7 @@ def get_ep_wrapped_model(
9775
if len(block_types) == 0:
9876
raise ValueError(f"None of the requested MoE block names were found: {block_names}")
9977

100-
ep_mesh = _resolve_ep_mesh(device_mesh, ep_mesh_dim_name)
78+
ep_mesh = get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.EP)
10179
MixedPrecisionPolicy(param_dtype=mp_param_dtype, reduce_dtype=mp_reduce_dtype)
10280

10381
wrapped_blocks = 0
@@ -115,7 +93,9 @@ def get_ep_wrapped_model(
11593

11694
_validate_moe_block_for_ep(ep_target_module)
11795
_attach_ep_metadata(ep_target_module, ep_mesh)
118-
_apply_ep(ep_target_module, ep_mesh)
96+
97+
ep_target_module.experts = ExpertParallel()._apply(ep_target_module.experts, ep_mesh)
98+
setattr(ep_target_module.experts, "_ep_enabled", True)
11999

120100
wrapped_blocks += 1
121101

src/modalities/running_env/fsdp/device_mesh.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ class DeviceMeshConfig(BaseModel):
2121
tensor_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1
2222
pipeline_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1
2323
context_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1
24+
expert_parallel_degree: Annotated[int, Field(strict=True, gt=0)] = 1
2425
enable_loss_parallel: Optional[bool] = False
2526
world_size: Annotated[int, Field(strict=True, gt=0)]
2627

2728
@model_validator(mode="after")
2829
def _validate(self):
2930
for d in (
3031
self.context_parallel_degree,
32+
self.expert_parallel_degree,
3133
self.tensor_parallel_degree,
3234
self.pipeline_parallel_degree,
3335
):
@@ -50,6 +52,7 @@ def _validate(self):
5052
self.data_parallel_shard_degree = self.world_size // (
5153
self.data_parallel_replicate_degree
5254
* self.context_parallel_degree
55+
* self.expert_parallel_degree
5356
* self.tensor_parallel_degree
5457
* self.pipeline_parallel_degree
5558
)
@@ -58,12 +61,14 @@ def _validate(self):
5861
self.data_parallel_replicate_degree = self.world_size // (
5962
self.data_parallel_shard_degree
6063
* self.context_parallel_degree
64+
* self.expert_parallel_degree
6165
* self.tensor_parallel_degree
6266
* self.pipeline_parallel_degree
6367
)
6468
if (
6569
self.data_parallel_shard_degree
6670
* self.data_parallel_replicate_degree
71+
* self.expert_parallel_degree
6772
* self.tensor_parallel_degree
6873
* self.pipeline_parallel_degree
6974
* self.context_parallel_degree
@@ -72,6 +77,7 @@ def _validate(self):
7277
raise ConfigError(
7378
f"Invalid parallel dims: data_parallel_shard_degree({self.data_parallel_shard_degree}) * "
7479
f"data_parallel_replicate_degree({self.data_parallel_replicate_degree}) * "
80+
f"expert_parallel_degree({self.expert_parallel_degree}) * "
7581
f"tensor_parallel_degree({self.tensor_parallel_degree}) *"
7682
f"* pipeline_parallel_degree({self.pipeline_parallel_degree}) *"
7783
f"context_parallel_degree({self.context_parallel_degree})!= WORLD_SIZE({self.world_size})"
@@ -85,6 +91,7 @@ class ParallelismDegrees(Enum):
8591
DP_REPLICATE = "dp_replicate"
8692
DP_SHARD = "dp_shard"
8793
CP = "cp"
94+
EP = "ep"
8895
TP = "tp"
8996
PP = "pp"
9097

@@ -96,6 +103,7 @@ def get_device_mesh(
96103
tensor_parallel_degree: int,
97104
pipeline_parallel_degree: int,
98105
context_parallel_degree: int,
106+
expert_parallel_degree: int,
99107
enable_loss_parallel: bool,
100108
world_size: int,
101109
) -> DeviceMesh:
@@ -109,6 +117,7 @@ def get_device_mesh(
109117
tensor_parallel_degree (int): The tensor parallel degree.
110118
pipeline_parallel_degree (int): The pipeline parallel degree.
111119
context_parallel_degree (int): The context parallel degree.
120+
expert_parallel_degree (int): The expert parallel degree.
112121
enable_loss_parallel (bool): Whether to enable loss parallelism.
113122
world_size (int): The world size.
114123
@@ -123,13 +132,15 @@ def get_device_mesh(
123132
data_parallel_replicate_degree,
124133
data_parallel_shard_degree,
125134
context_parallel_degree,
135+
expert_parallel_degree,
126136
tensor_parallel_degree,
127137
],
128138
[
129139
ParallelismDegrees.PP.value,
130140
ParallelismDegrees.DP_REPLICATE.value,
131141
ParallelismDegrees.DP_SHARD.value,
132142
ParallelismDegrees.CP.value,
143+
ParallelismDegrees.EP.value,
133144
ParallelismDegrees.TP.value,
134145
],
135146
strict=True,

0 commit comments

Comments
 (0)