Skip to content

Commit e88b8aa

Browse files
committed
chore: minor cleanup
1 parent 32c8146 commit e88b8aa

3 files changed

Lines changed: 3 additions & 8 deletions

File tree

src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self, global_rank: int, allow_partial_load: bool = False):
109109
110110
Args:
111111
global_rank (int): The global rank of the process.
112-
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to True.
112+
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False.
113113
Returns:
114114
None
115115
"""

src/modalities/checkpointing/stateful/app_state_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_raw_app_state(
4747
def get_dcp_checkpointed_app_state_(
4848
raw_app_state: AppState,
4949
checkpoint_dir_path: Path,
50-
allow_partial_load: bool = True,
50+
allow_partial_load: bool = False,
5151
) -> AppState:
5252
"""Loads the checkpointed state dict into the raw AppState object
5353
(i.e., non-checkpoint loaded AppState) in-place.
@@ -56,7 +56,7 @@ def get_dcp_checkpointed_app_state_(
5656
raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy
5757
determines which components are restored.
5858
checkpoint_dir_path (Path): The path to the checkpoint directory.
59-
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to True.
59+
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False.
6060
6161
Raises:
6262
RuntimeError: Raises an error if the state dict has already been loaded.

src/modalities/config/config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,6 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:
126126
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)
127127

128128

129-
# class DCPCheckpointLoadingConfig(BaseModel):
130-
# global_rank: Annotated[int, Field(strict=True, ge=0)]
131-
# allow_partial_load: bool = True
132-
133-
134129
class FSDP1CheckpointSavingConfig(BaseModel):
135130
checkpoint_path: Path
136131
global_rank: Annotated[int, Field(strict=True, ge=0)]

0 commit comments

Comments
 (0)