Skip to content

Commit b889972

Browse files
committed
feat: Allow gloo process group for checkpointing
1 parent 7337fe4 commit b889972

2 files changed

Lines changed: 35 additions & 1 deletion

File tree

src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import inspect
23
from enum import Enum
34
from pathlib import Path
45

@@ -46,6 +47,7 @@ def __init__(
4647
checkpoint_path: Path,
4748
experiment_id: str,
4849
global_rank: int,
50+
use_gloo_process_group_for_planning: bool = False,
4951
):
5052
"""
5153
Initializes the FSDPCheckpointSaving class.
@@ -54,13 +56,16 @@ def __init__(
5456
checkpoint_path (Path): folder path to the checkpoint
5557
experiment_id (str): ID of the experiment
5658
global_rank (int): global rank within the current process group
59+
use_gloo_process_group_for_planning (bool): Whether to use a temporary
60+
Gloo process group for DCP planning collectives.
5761
5862
Returns:
5963
None
6064
"""
6165
self.checkpoint_path = checkpoint_path
6266
self.global_rank = global_rank
6367
self.experiment_id = experiment_id
68+
self.use_gloo_process_group_for_planning = use_gloo_process_group_for_planning
6469

6570
def _get_checkpointing_path(
6671
self,
@@ -193,6 +198,7 @@ def __init__(
193198
checkpoint_path: Path,
194199
experiment_id: str,
195200
global_rank: int,
201+
use_gloo_process_group_for_planning: bool = False,
196202
):
197203
"""
198204
Initializes the FSDP2CheckpointSaving class.
@@ -201,13 +207,16 @@ def __init__(
201207
checkpoint_path (Path): folder path to the checkpoint
202208
experiment_id (str): ID of the experiment
203209
global_rank (int): global rank within the current process group
210+
use_gloo_process_group_for_planning (bool): Whether to use a temporary
211+
Gloo process group for DCP planning collectives.
204212
205213
Returns:
206214
None
207215
"""
208216
self.checkpoint_path = checkpoint_path
209217
self.global_rank = global_rank
210218
self.experiment_id = experiment_id
219+
self.use_gloo_process_group_for_planning = use_gloo_process_group_for_planning
211220

212221
def _get_checkpointing_folder_path(
213222
self,
@@ -243,7 +252,31 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr
243252
distributed_checkpoint_path.mkdir(parents=True, exist_ok=True)
244253
get_logger().info(f"Saving distributed model checkpoint to {distributed_checkpoint_path}...")
245254
state_dict = {"app": app_state}
246-
dcp.save(state_dict, checkpoint_id=distributed_checkpoint_path)
255+
# DCP performs object collectives while planning writes. On some clusters, these
256+
# fail through NCCL; when enabled via the flag, use a dedicated Gloo group for planning collectives.
257+
dcp_kwargs = {}
258+
gloo_group = None
259+
try:
260+
save_sig = inspect.signature(dcp.save)
261+
if (
262+
self.use_gloo_process_group_for_planning
263+
and "process_group" in save_sig.parameters
264+
and dist.is_initialized()
265+
):
266+
try:
267+
gloo_group = dist.new_group(backend="gloo")
268+
dcp_kwargs["process_group"] = gloo_group
269+
get_logger().info("Using Gloo process group for DCP metadata collectives.")
270+
except Exception as e:
271+
get_logger().warning(f"Could not create Gloo process group for DCP; using default group. {e}")
272+
273+
dcp.save(state_dict, checkpoint_id=distributed_checkpoint_path, **dcp_kwargs)
274+
finally:
275+
if gloo_group is not None:
276+
try:
277+
dist.destroy_process_group(gloo_group)
278+
except Exception as e:
279+
get_logger().warning(f"Failed to destroy temporary Gloo process group: {e}")
247280
get_logger().info("Distributed checkpoint saved.")
248281

249282
if self.global_rank == 0:

src/modalities/config/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class DCPCheckpointSavingConfig(BaseModel):
138138
checkpoint_path: Path
139139
global_rank: Annotated[int, Field(strict=True, ge=0)]
140140
experiment_id: str
141+
use_gloo_process_group_for_planning: bool = False
141142

142143

143144
class CheckpointSavingConfig(BaseModel):

0 commit comments

Comments
 (0)