11import json
2+ import inspect
23from enum import Enum
34from pathlib import Path
45
@@ -46,6 +47,7 @@ def __init__(
4647 checkpoint_path : Path ,
4748 experiment_id : str ,
4849 global_rank : int ,
50+ use_gloo_process_group_for_planning : bool = False ,
4951 ):
5052 """
5153 Initializes the FSDPCheckpointSaving class.
@@ -54,13 +56,16 @@ def __init__(
5456 checkpoint_path (Path): folder path to the checkpoint
5557 experiment_id (str): ID of the experiment
5658 global_rank (int): global rank within the current process group
59+ use_gloo_process_group_for_planning (bool): Whether to use a temporary
60+ Gloo process group for DCP planning collectives.
5761
5862 Returns:
5963 None
6064 """
6165 self .checkpoint_path = checkpoint_path
6266 self .global_rank = global_rank
6367 self .experiment_id = experiment_id
68+ self .use_gloo_process_group_for_planning = use_gloo_process_group_for_planning
6469
6570 def _get_checkpointing_path (
6671 self ,
@@ -193,6 +198,7 @@ def __init__(
193198 checkpoint_path : Path ,
194199 experiment_id : str ,
195200 global_rank : int ,
201+ use_gloo_process_group_for_planning : bool = False ,
196202 ):
197203 """
198204 Initializes the FSDP2CheckpointSaving class.
@@ -201,13 +207,16 @@ def __init__(
201207 checkpoint_path (Path): folder path to the checkpoint
202208 experiment_id (str): ID of the experiment
203209 global_rank (int): global rank within the current process group
210+ use_gloo_process_group_for_planning (bool): Whether to use a temporary
211+ Gloo process group for DCP planning collectives.
204212
205213 Returns:
206214 None
207215 """
208216 self .checkpoint_path = checkpoint_path
209217 self .global_rank = global_rank
210218 self .experiment_id = experiment_id
219+ self .use_gloo_process_group_for_planning = use_gloo_process_group_for_planning
211220
212221 def _get_checkpointing_folder_path (
213222 self ,
@@ -243,7 +252,31 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr
243252 distributed_checkpoint_path .mkdir (parents = True , exist_ok = True )
244253 get_logger ().info (f"Saving distributed model checkpoint to { distributed_checkpoint_path } ..." )
245254 state_dict = {"app" : app_state }
246- dcp .save (state_dict , checkpoint_id = distributed_checkpoint_path )
255+ # DCP performs object collectives while planning writes. On some clusters, these
256+ # fail through NCCL; prefer a dedicated Gloo group for planning collectives.
257+ dcp_kwargs = {}
258+ gloo_group = None
259+ try :
260+ save_sig = inspect .signature (dcp .save )
261+ if (
262+ self .use_gloo_process_group_for_planning
263+ and "process_group" in save_sig .parameters
264+ and dist .is_initialized ()
265+ ):
266+ try :
267+ gloo_group = dist .new_group (backend = "gloo" )
268+ dcp_kwargs ["process_group" ] = gloo_group
269+ get_logger ().info ("Using Gloo process group for DCP metadata collectives." )
270+ except Exception as e :
271+ get_logger ().warning (f"Could not create Gloo process group for DCP; using default group. { e } " )
272+
273+ dcp .save (state_dict , checkpoint_id = distributed_checkpoint_path , ** dcp_kwargs )
274+ finally :
275+ if gloo_group is not None :
276+ try :
277+ dist .destroy_process_group (gloo_group )
278+ except Exception as e :
279+ get_logger ().warning (f"Failed to destroy temporary Gloo process group: { e } " )
247280 get_logger ().info ("Distributed checkpoint saved." )
248281
249282 if self .global_rank == 0 :
0 commit comments