Skip to content

Commit 4ea6fab

Browse files
committed
fix: Make pp related import optional
1 parent b889972 commit 4ea6fab

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

src/modalities/models/parallelism/pipeline_parallelism.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
from torch.distributed.pipelining.schedules import (
1515
PipelineScheduleMulti,
1616
PipelineScheduleSingle,
17-
ScheduleDualPipeV,
1817
ScheduleZBVZeroBubble,
1918
get_schedule_class,
2019
)
2120

21+
try:
22+
from torch.distributed.pipelining.schedules import ScheduleDualPipeV
23+
except ImportError:
24+
ScheduleDualPipeV = None
25+
2226
from modalities.loss_functions import Loss
2327
from modalities.models.model import NNModel
2428
from modalities.models.parallelism.stages_generator import StagesGenerator
@@ -152,7 +156,10 @@ def _get_stage_ids_of_pp_rank(
152156
num_stages: int,
153157
schedule_class: Type[PipelineScheduleSingle | PipelineScheduleMulti],
154158
) -> list[int]:
155-
style = "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
159+
v_schedules = [ScheduleZBVZeroBubble]
160+
if ScheduleDualPipeV is not None:
161+
v_schedules.append(ScheduleDualPipeV)
162+
style = "v" if schedule_class in tuple(v_schedules) else "loop"
156163
pp_size = pp_mesh.size()
157164
pp_rank = pp_mesh.get_local_rank()
158165
stages_per_rank = num_stages // pp_size

0 commit comments

Comments
 (0)