File tree Expand file tree Collapse file tree
src/modalities/models/parallelism Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1414from torch .distributed .pipelining .schedules import (
1515 PipelineScheduleMulti ,
1616 PipelineScheduleSingle ,
17- ScheduleDualPipeV ,
1817 ScheduleZBVZeroBubble ,
1918 get_schedule_class ,
2019)
2120
21+ try :
22+ from torch .distributed .pipelining .schedules import ScheduleDualPipeV
23+ except ImportError :
24+ ScheduleDualPipeV = None
25+
2226from modalities .loss_functions import Loss
2327from modalities .models .model import NNModel
2428from modalities .models .parallelism .stages_generator import StagesGenerator
@@ -152,7 +156,10 @@ def _get_stage_ids_of_pp_rank(
152156 num_stages : int ,
153157 schedule_class : Type [PipelineScheduleSingle | PipelineScheduleMulti ],
154158 ) -> list [int ]:
155- style = "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
159+ v_schedules = [ScheduleZBVZeroBubble ]
160+ if ScheduleDualPipeV is not None :
161+ v_schedules .append (ScheduleDualPipeV )
162+ style = "v" if schedule_class in tuple (v_schedules ) else "loop"
156163 pp_size = pp_mesh .size ()
157164 pp_rank = pp_mesh .get_local_rank ()
158165 stages_per_rank = num_stages // pp_size
You can’t perform that action at this time.
0 commit comments