@@ -21,13 +21,15 @@ class DeviceMeshConfig(BaseModel):
2121 tensor_parallel_degree : Annotated [int , Field (strict = True , gt = 0 )] = 1
2222 pipeline_parallel_degree : Annotated [int , Field (strict = True , gt = 0 )] = 1
2323 context_parallel_degree : Annotated [int , Field (strict = True , gt = 0 )] = 1
24+ expert_parallel_degree : Annotated [int , Field (strict = True , gt = 0 )] = 1
2425 enable_loss_parallel : Optional [bool ] = False
2526 world_size : Annotated [int , Field (strict = True , gt = 0 )]
2627
2728 @model_validator (mode = "after" )
2829 def _validate (self ):
2930 for d in (
3031 self .context_parallel_degree ,
32+ self .expert_parallel_degree ,
3133 self .tensor_parallel_degree ,
3234 self .pipeline_parallel_degree ,
3335 ):
@@ -50,6 +52,7 @@ def _validate(self):
5052 self .data_parallel_shard_degree = self .world_size // (
5153 self .data_parallel_replicate_degree
5254 * self .context_parallel_degree
55+ * self .expert_parallel_degree
5356 * self .tensor_parallel_degree
5457 * self .pipeline_parallel_degree
5558 )
@@ -58,12 +61,14 @@ def _validate(self):
5861 self .data_parallel_replicate_degree = self .world_size // (
5962 self .data_parallel_shard_degree
6063 * self .context_parallel_degree
64+ * self .expert_parallel_degree
6165 * self .tensor_parallel_degree
6266 * self .pipeline_parallel_degree
6367 )
6468 if (
6569 self .data_parallel_shard_degree
6670 * self .data_parallel_replicate_degree
71+ * self .expert_parallel_degree
6772 * self .tensor_parallel_degree
6873 * self .pipeline_parallel_degree
6974 * self .context_parallel_degree
@@ -72,6 +77,7 @@ def _validate(self):
7277 raise ConfigError (
7378 f"Invalid parallel dims: data_parallel_shard_degree({ self .data_parallel_shard_degree } ) * "
7479 f"data_parallel_replicate_degree({ self .data_parallel_replicate_degree } ) * "
80+ f"expert_parallel_degree({ self .expert_parallel_degree } ) * "
7581 f"tensor_parallel_degree({ self .tensor_parallel_degree } ) *"
7682 f"* pipeline_parallel_degree({ self .pipeline_parallel_degree } ) *"
7783 f"context_parallel_degree({ self .context_parallel_degree } )!= WORLD_SIZE({ self .world_size } )"
@@ -85,6 +91,7 @@ class ParallelismDegrees(Enum):
8591 DP_REPLICATE = "dp_replicate"
8692 DP_SHARD = "dp_shard"
8793 CP = "cp"
94+ EP = "ep"
8895 TP = "tp"
8996 PP = "pp"
9097
@@ -96,6 +103,7 @@ def get_device_mesh(
96103 tensor_parallel_degree : int ,
97104 pipeline_parallel_degree : int ,
98105 context_parallel_degree : int ,
106+ expert_parallel_degree : int ,
99107 enable_loss_parallel : bool ,
100108 world_size : int ,
101109) -> DeviceMesh :
@@ -109,6 +117,7 @@ def get_device_mesh(
109117 tensor_parallel_degree (int): The tensor parallel degree.
110118 pipeline_parallel_degree (int): The pipeline parallel degree.
111119 context_parallel_degree (int): The context parallel degree.
120+ expert_parallel_degree (int): The expert parallel degree.
112121 enable_loss_parallel (bool): Whether to enable loss parallelism.
113122 world_size (int): The world size.
114123
@@ -123,13 +132,15 @@ def get_device_mesh(
123132 data_parallel_replicate_degree ,
124133 data_parallel_shard_degree ,
125134 context_parallel_degree ,
135+ expert_parallel_degree ,
126136 tensor_parallel_degree ,
127137 ],
128138 [
129139 ParallelismDegrees .PP .value ,
130140 ParallelismDegrees .DP_REPLICATE .value ,
131141 ParallelismDegrees .DP_SHARD .value ,
132142 ParallelismDegrees .CP .value ,
143+ ParallelismDegrees .EP .value ,
133144 ParallelismDegrees .TP .value ,
134145 ],
135146 strict = True ,
0 commit comments