Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ settings:
num_seen_steps: 0
num_seen_samples: 0
last_step: -1
debugging:
component_key: debugging
variant_key: settings
config:
enable_determinism: false
forward_hooks:
- instance_key: error_on_nan
pass_type: BY_REFERENCE
- instance_key: print_forward_hook
pass_type: BY_REFERENCE

collate_fn:
component_key: collate_fn
Expand Down Expand Up @@ -398,6 +408,8 @@ gradient_clipper:

progress_subscriber:
component_key: progress_subscriber
# variant_key: dummy
# config: {}
variant_key: rich
config:
global_rank: ${settings.cuda_env.global_rank}
Expand Down Expand Up @@ -432,4 +444,22 @@ mfu_calculator:
pass_type: BY_REFERENCE
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE
pass_type: BY_REFERENCE

error_on_nan:
component_key: model_debugging_hook
variant_key: nan_hook
config:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
raise_exception: true

print_forward_hook:
component_key: model_debugging_hook
variant_key: print_forward_hook
config:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
print_shape_only: true
8 changes: 4 additions & 4 deletions src/modalities/checkpointing/checkpoint_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class CheckpointSaving:
"""Class for saving checkpoints based on a savig and execution strategy."""
"""Class for saving checkpoints based on a saving and execution strategy."""

def __init__(
self,
Expand All @@ -28,7 +28,7 @@ def save_checkpoint(
training_progress: TrainingProgress,
evaluation_result: dict[str, EvaluationResultBatch],
app_state: AppState,
early_stoppping_criterion_fulfilled: bool = False,
early_stopping_criterion_fulfilled: bool = False,
):
"""
Saves a checkpoint of the model and optimizer.
Expand All @@ -37,13 +37,13 @@ def save_checkpoint(
training_progress (TrainingProgress): The training progress.
evaluation_result (dict[str, EvaluationResultBatch]): The evaluation result.
app_state (AppState): The application state to be checkpointed.
early_stoppping_criterion_fulfilled (bool, optional):
early_stopping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.
"""
checkpointing_instruction = self.checkpoint_saving_strategy.get_checkpoint_instruction(
training_progress=training_progress,
evaluation_result=evaluation_result,
early_stoppping_criterion_fulfilled=early_stoppping_criterion_fulfilled,
early_stopping_criterion_fulfilled=early_stopping_criterion_fulfilled,
)

self.checkpoint_saving_execution.run_checkpoint_instruction(
Expand Down
12 changes: 6 additions & 6 deletions src/modalities/checkpointing/checkpoint_saving_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
evaluation_result: Optional[dict[str, EvaluationResultBatch]] = None,
early_stoppping_criterion_fulfilled: bool = False,
early_stopping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Returns the checkpointing instruction.
Expand All @@ -24,7 +24,7 @@ def get_checkpoint_instruction(
training_progress (TrainingProgress): The training progress.
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
early_stopping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.

Returns:
Expand Down Expand Up @@ -53,7 +53,7 @@ def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
evaluation_result: dict[str, EvaluationResultBatch] | None = None,
early_stoppping_criterion_fulfilled: bool = False,
early_stopping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Generates a checkpointing instruction based on the given parameters.
Expand All @@ -62,7 +62,7 @@ def get_checkpoint_instruction(
training_progress (TrainingProgress): The training progress.
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
early_stopping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.

Returns:
Expand Down Expand Up @@ -102,7 +102,7 @@ def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
evaluation_result: dict[str, EvaluationResultBatch] | None = None,
early_stoppping_criterion_fulfilled: bool = False,
early_stopping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Returns a CheckpointingInstruction object.
Expand All @@ -111,7 +111,7 @@ def get_checkpoint_instruction(
training_progress (TrainingProgress): The training progress.
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
early_stopping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.

Returns:
Expand Down
2 changes: 2 additions & 0 deletions src/modalities/config/instantiation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PydanticAppStateType,
PydanticCheckpointSavingIFType,
PydanticDatasetIFType,
PydanticDebuggingType,
PydanticDeviceMeshIFType,
PydanticGradientClipperIFType,
PydanticLLMDataLoaderIFType,
Expand Down Expand Up @@ -98,6 +99,7 @@ class DCPWarmstartCheckpointPaths(BaseModel):
training_target: TrainingTarget
training_progress: TrainingProgress
warmstart_checkpoint_paths: Optional[WarmstartCheckpointPaths | DCPWarmstartCheckpointPaths] = None
debugging: Optional[PydanticDebuggingType] = None

@model_validator(mode="after")
def _check_tokens_per_step_conistency(self) -> "TrainingComponentsInstantiationModel.Settings":
Expand Down
5 changes: 5 additions & 0 deletions src/modalities/config/pydantic_if_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.utils.debug_components import Debugging
from modalities.utils.mfu import MFUCalculatorABC
from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF
from modalities.utils.profilers.steppable_components import SteppableComponentIF
Expand Down Expand Up @@ -90,3 +91,7 @@ def __get_pydantic_core_schema__(
PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)]
PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)]
PydanticSteppableComponentIFType = Annotated[SteppableComponentIF, PydanticThirdPartyTypeIF(SteppableComponentIF)]
PydanticRemovableHandleType = Annotated[
torch.utils.hooks.RemovableHandle, PydanticThirdPartyTypeIF(torch.utils.hooks.RemovableHandle)
]
PydanticDebuggingType = Annotated[Debugging, PydanticThirdPartyTypeIF(Debugging)]
2 changes: 1 addition & 1 deletion src/modalities/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _run_checkpointing(
training_progress=training_progress,
evaluation_result=None, # TODO implement checkpointing based on preceding evaluation results
app_state=app_state,
early_stoppping_criterion_fulfilled=False, # TODO: implement early stopping
early_stopping_criterion_fulfilled=False, # TODO: implement early stopping
)

def _run_evaluation(
Expand Down
14 changes: 12 additions & 2 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,19 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
super().__init__()
# this also holds when using TP, since n_embd is the total embedding size and
# n_head is the number of heads globally
dim_model = n_embd // n_head
self.dim_model = n_embd // n_head
self.seq_length_dim = seq_length_dim
inv_freq = 1.0 / (base_freq ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.base_freq = base_freq

self.reset_parameters()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have to call this explicitly in the constructor? Wouldn't our weight init routines call it?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still want the constructor to build a valid and complete instance of the module. In particular for unit testing or checkpoint loading where this initialization component is not used.


def reset_parameters(self):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that init_weights() is a non-PyTorch function that they call in train.py; it then initialises all modules that contain weights in a recursive fashion.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our architecture, it might be interesting to add something like this but with some "initializer" parameter. So that our initialization component just calls this on the top-level model and gives the chosen initialization method as parameter.
Forcing this method to exist would hopefully help to prevent future modules or models from suffering from the same bug as the rotary transforms did.

# If previously initialized on or moved to a device, reuse that device.
# Otherwise, use the default device of the current environment.
device = self.inv_freq.device if hasattr(self, "inv_freq") else None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment explaining why there are two cases would be great:

  1. inv_freq exists and has a device attribute
  2. it does not exist in which case device is set to None

Also, the impact of setting device to None below should be documented.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

inv_freq = 1.0 / (
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
)
self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
Expand Down
11 changes: 11 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
FSDP2DummyGradientClipperConfig,
FSDP2GradientClipperConfig,
)
from modalities.utils.debug_components import Debugging, HookRegistration
from modalities.utils.debugging_configs import DebuggingConfig, NaNHookConfig, PrintForwardHookConfig
from modalities.utils.mfu import GPT2MFUCalculator
from modalities.utils.number_conversion import (
LocalNumBatchesFromNumSamplesConfig,
Expand Down Expand Up @@ -431,4 +433,13 @@ class ComponentEntity:
SteppableForwardPass,
SteppableForwardPassConfig,
),
# Debugging components
ComponentEntity("debugging", "settings", Debugging, DebuggingConfig),
ComponentEntity("model_debugging_hook", "nan_hook", HookRegistration.register_nan_hooks, NaNHookConfig),
ComponentEntity(
"model_debugging_hook",
"print_forward_hook",
HookRegistration.register_print_forward_hooks,
PrintForwardHookConfig,
),
]
100 changes: 100 additions & 0 deletions src/modalities/utils/debug.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we use these hooks? Could not find a reference in the code. Would be good to clarify how to use them.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now turned these utility functions into components and also added them to one of the example configs.

Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
import os
from contextlib import contextmanager
from typing import Any

import torch

logger = logging.getLogger(__name__)


@contextmanager
def enable_deterministic_cuda():
    """Context manager to enable deterministic CUDA operations and restore previous state.

    On entry, the cuDNN ``deterministic`` flag is switched on, the cuDNN
    ``benchmark`` flag is switched off, deterministic algorithms are enabled
    globally in torch, and the ``CUBLAS_WORKSPACE_CONFIG`` environment variable
    is set to ``":4096:8"``.

    On exit (normal or via exception), every one of these settings is put back
    to the value it had before entering, including the ``warn_only`` mode of
    the deterministic-algorithms switch. ``CUBLAS_WORKSPACE_CONFIG`` is removed
    again if it was not set beforehand.

    Yields:
        None
    """
    # Snapshot everything that is about to be modified.
    prev_cudnn_deterministic = torch.backends.cudnn.deterministic
    prev_cudnn_benchmark = torch.backends.cudnn.benchmark
    prev_algos = torch.are_deterministic_algorithms_enabled()
    # Without this, a previously active `warn_only=True` would be silently
    # downgraded to the strict (raising) mode when the context is left.
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    prev_cublas_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG")

    try:
        # The setup lives inside the try block so that a failure half-way
        # through still restores the settings that were already changed.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        yield
    finally:
        torch.backends.cudnn.deterministic = prev_cudnn_deterministic
        torch.backends.cudnn.benchmark = prev_cudnn_benchmark
        torch.use_deterministic_algorithms(prev_algos, warn_only=prev_warn_only)
        if prev_cublas_cfg is None:
            os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
        else:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = prev_cublas_cfg


def _iter_nested_tensors(target: Any):
    """Yield every tensor contained in ``target``.

    Descends recursively into lists, tuples and dict values. Anything that is
    neither a tensor nor one of these containers is ignored.
    """
    if isinstance(target, torch.Tensor):
        yield target
    elif isinstance(target, (list, tuple)):
        for item in target:
            yield from _iter_nested_tensors(item)
    elif isinstance(target, dict):
        for item in target.values():
            yield from _iter_nested_tensors(item)


def _detect_nan(
    module: torch.nn.Module,
    module_path: str | None,
    target: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, Any],
    target_name: str,
    raise_exception: bool,
):
    """Log an error (and optionally raise) if ``target`` contains a NaN value.

    This is meant to be called from a forward hook. A hook cannot hand a flag
    back to its caller, so the presence of NaNs is signalled through the logger
    and, if requested, through an exception instead of a return value.

    Args:
        module (torch.nn.Module): The module whose input/output is inspected;
            only its class name is used, for the messages.
        module_path (str | None): Optional path of the module, logged in
            addition to the class name when given.
        target: A tensor or a (possibly nested) list/tuple/dict of tensors.
            Non-tensor entries are skipped.
        target_name (str): Label used in the messages, e.g. "input" or "output".
        raise_exception (bool): If True, a ValueError is raised after logging.

    Raises:
        ValueError: If a NaN was found and ``raise_exception`` is True.
    """
    # Nested containers are flattened so that e.g. a dict passed to forward()
    # (which may reach a hook wrapped in a tuple) is inspected as well.
    has_nan = any(bool(torch.isnan(tensor).any()) for tensor in _iter_nested_tensors(target))
    if not has_nan:
        return

    module_name = module.__class__.__name__
    logger.error(f"NaN detected in {target_name} {module_name}")
    if module_path:
        logger.error(f"Module path: {module_path}")
    if raise_exception:
        raise ValueError(f"NaN detected in {target_name} of module {module_name}")


def debug_nan_hook(
    module: torch.nn.Module,
    input: torch.Tensor | tuple[torch.Tensor, ...],
    output: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor],
    module_path: str | None = None,
    raise_exception: bool = False,
):
    """Forward hook that checks a module's input and output for NaN values.

    The input is checked first, then the output; each check is delegated to
    ``_detect_nan`` with the same ``module_path`` and ``raise_exception``.

    Args:
        module (torch.nn.Module): The module the hook was invoked for.
        input: The value(s) passed into the module's forward pass.
        output: The value(s) produced by the module's forward pass.
        module_path (str | None, optional): Path of the module, forwarded to
            ``_detect_nan``. Defaults to None.
        raise_exception (bool, optional): Forwarded to ``_detect_nan``.
            Defaults to False.
    """
    # Order matters: the input is inspected before the output.
    checks = (("input", input), ("output", output))
    for label, value in checks:
        _detect_nan(
            module,
            module_path,
            target=value,
            target_name=label,
            raise_exception=raise_exception,
        )


def print_forward_hook(
module: torch.nn.Module,
input: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | dict[str, Any],
output: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | dict[str, Any],
module_path: str | None = None,
print_shape_only: bool = False,
):
"""Hook to print input and output shapes during forward pass"""
if isinstance(input, (list, tuple)):
input_shapes = [inp.shape for inp in input if isinstance(inp, torch.Tensor)]
elif isinstance(input, torch.Tensor):
input_shapes = [input.shape]
else:
input_shapes = []

if isinstance(output, (list, tuple)):
output_shapes = [out.shape for out in output if isinstance(out, torch.Tensor)]
elif isinstance(output, torch.Tensor):
output_shapes = [output.shape]
else:
output_shapes = []

print(
f"Module: {module.__class__.__name__}, "
f"Path: {module_path}, "
f"Input shapes: {input_shapes}, "
f"Output shapes: {output_shapes}"
)
if not print_shape_only:
print(f">>> Input:\n{input}")
print(f">>> Output:\n{output}")
Loading