Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,8 @@ def fsdp_config(self, value: FSDPConfig | FSDP2Config):
self._fsdp_config = value
self._fsdp2_config = None
elif isinstance(value, FSDP2Config):
self._fsdp2_config = value
self._fsdp_config = None
self._fsdp2_config = value
else:
raise TypeError(f'Expected value to be of type FSDPConfig or FSDP2Config, but got {type(value)}.')

Expand Down
31 changes: 2 additions & 29 deletions composer/distributed/dist_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_mixed_precision,
set_custom_fsdp_module_kwargs,
)
from composer.distributed.shared_utils import add_fsdp_oom_hooks
from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple, get_device

__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']
Expand Down Expand Up @@ -262,22 +263,6 @@ def prepare_fsdp_module(
# Handles of FSDP sync hooks if automicrobatching is on
hook_handles = []

# Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
# may happen when close to memory limit or with uneven memory usage across ranks. Since we
# need to do this before the model weights are gathered for the next FSDP block, we wrap every
# FSDP block with a hook that checks if any other rank OOMed.
def sync_hook(*args):
# Check if any other rank hit an OOM
found_cuda_oom_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
found_cuda_oom = found_cuda_oom_tensor.item()
# Signal current rank is still in batch
all_ranks_finished_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

if found_cuda_oom == 1:
raise RuntimeError('CUDA out of memory encountered on a different rank')

# Necessary variables for optimizers with multiple param groups in FSDP
param_name_to_group_num = None
group_num_to_opt_group_info = None
Expand Down Expand Up @@ -492,20 +477,8 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
log.info(f'Calling prepare_te_modules_for_fsdp to enable TE weights sharding')
prepare_te_modules_for_fsdp(fsdp_obj)

# The following sync hooks are added to prevent FSDP deadlocks that are caused when some ranks OOM
# and other ranks do not OOM, leading to OOMing ranks calling all_reduce to wait on the non-OOMing
# ranks and the non-OOMing ranks calling all_gather_base to continue with FSDP training:
#
# forward_pre_hook: before forwards of FSDP modules
# full_backward_pre_hook: before backwards of FSDP modules
# full_backward_hook: before a prefetched unshard called by FSDP's `post_backward_reshard`
if auto_microbatching:
for _, module in fsdp_obj.named_modules():
if isinstance(module, FullyShardedDataParallel):
hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True))
hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True))
else:
hook_handles.append(module.register_full_backward_hook(sync_hook))
hook_handles = add_fsdp_oom_hooks(fsdp_obj, device=device)
fsdp_obj_named_modules.update(dict(fsdp_obj.named_modules()))

if hasattr(fsdp_obj, '_exec_order_data'):
Expand Down
157 changes: 157 additions & 0 deletions composer/distributed/shared_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Shared utilities for distributed training."""

import functools
from typing import Callable, Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.utils.hooks import RemovableHandle
from torchmetrics import Metric, MetricCollection

from composer.devices import Device
from composer.models import ComposerModel
from composer.utils import dist, get_device


def get_direct_children_from_composer_model(model: ComposerModel) -> list[torch.nn.Module]:
    """Collect the direct child modules of a ComposerModel, excluding metric objects.

    Children that are torchmetrics ``Metric`` or ``MetricCollection`` instances are
    skipped, since those are bookkeeping containers rather than model layers.

    Args:
        model (ComposerModel): The model whose direct children are collected.

    Returns:
        list: Direct children of ``model`` that are not metric objects.
    """
    assert isinstance(model, ComposerModel)
    # Comprehension form of the filter loop: keep each direct child that is not a metric container.
    return [child for child in model.children() if not isinstance(child, (Metric, MetricCollection))]


def generate_oom_hook(device: Device) -> Callable:
    """Generate a hook that checks if any other rank hit an OOM.

    Note: This isn't supported for FSDP2 yet. For more details view the draft PR:
    https://github.com/mosaicml/composer/pull/3866

    We check if other ranks OOMed after forward/backward pass when using auto microbatching. This
    may happen when close to memory limit or with uneven memory usage across ranks. Since we
    need to do this before the model weights are gathered for the next FSDP1 block, we wrap every
    FSDP1 block with a hook that checks if any other rank OOMed.

    Here's an example of why this is needed using a simple 2-GPU setup and how it handles OOM issues during auto microbatching.

    Note: The line numbers below can be (slightly) off based on future changes made to the code.

    - Rank 0: Layer 1 works fine
    - Rank 1: Layer 1 works fine
    - Rank 0: Layer 2 OOMs
        - Rank 0 raises an error _is_cuda_oom() [[trainer.py:2756]]
        - Rank 0 sets found_cuda_oom to 1 [[trainer.py:2758]]
        - Rank 0 creates found_cuda_oom_tensor = [1] and calls all_reduce on it with reduce_operation='MAX' [[trainer.py:2773]]
    - Rank 1: Layer 2 works fine until a hook handle is hit
        - Rank 1 sets found_cuda_oom_tensor = [0] [[shared_utils.py:85]]
        - Rank 1 calls all_reduce to set found_cuda_oom_tensor to max([0, 1]) = 1 [[shared_utils.py:86]]
        - Rank 1 sees that found_cuda_oom == 1 [[shared_utils.py:87]]
    - Rank 0:
        - Rank 0 creates all_ranks_finished_tensor = [1] and calls all_reduce on it with reduce_operation='MIN' [[trainer.py:2780]]
        - Rank 0 sees that all_ranks_finished == 0 (since rank 1 is still in mid-batch) [[trainer.py:2781]]
        - Rank 0 continues in the (while not all_ranks_finished) loop [[trainer.py:2771]]
    - Rank 1:
        - Rank 1 creates all_ranks_finished_tensor = [0] and calls all_reduce on it with reduce_operation='MIN' [[shared_utils.py:89]]
        - Rank 1 sees that all_ranks_finished == 0 (since this rank is still in the batch) [[shared_utils.py:90]]
        - Rank 1 sees that found_cuda_oom == 1, so it raises an error saying that a different rank OOMed [[shared_utils.py:93]]
    - Rank 0:
        - In the next round of the while loop, found_cuda_oom_tensor = [1] and calls all_reduce on it with reduce_operation='MAX' [[trainer.py:2773]]
    - Rank 1:
        - Rank 1 sees the error that was raised earlier (OOM on other rank) and sets found_cuda_oom to 1 [[trainer.py:2755]]
        - Rank 1 creates found_cuda_oom_tensor = [1] and calls all_reduce on it with reduce_operation='MAX' [[trainer.py:2773]]
        - As expected, found_cuda_oom == 1 [[trainer.py:2776]]
    - Rank 0:
        - Rank 0 creates all_ranks_finished_tensor = [1] (since it's in the same while loop as before) and calls all_reduce on it with reduce_operation='MIN' [[trainer.py:2780]]
        - Rank 0 sees that all_ranks_finished = 1 (as we are in the same part of the trainer code as Rank 1, Rank 1 returns the same value) [[trainer.py:2782]]
        - Rank 0 exits the while loop and adjusts the device_train_microbatch_size to half of the previous value [[trainer.py:2790]]
    - Rank 1:
        - Rank 1 creates all_ranks_finished_tensor = [1] (since it's finished the batch with an error) and calls all_reduce on it with reduce_operation='MIN' [[trainer.py:2780]]
        - Rank 1 sees that all_ranks_finished == 1 (since this rank is finished the batch) [[trainer.py:2781]]
        - Rank 1 exits the while loop and adjusts the device_train_microbatch_size to half of the previous value [[trainer.py:2790]]

    Args:
        device (Device): The Composer device used to allocate the tensors that take part in the
            cross-rank all-reduces.

    Returns:
        Callable: The hook that checks if any other rank hit an OOM.
    """

    # `device` is captured by the closure, so no `functools.partial` binding is needed.
    def sync_hook(*args):
        # Check if any other rank hit an OOM
        found_cuda_oom_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
        dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
        found_cuda_oom = found_cuda_oom_tensor.item()
        # Signal current rank is still in batch
        all_ranks_finished_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
        dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

        if found_cuda_oom == 1:
            raise RuntimeError('CUDA out of memory encountered on a different rank')

    return sync_hook


def add_fsdp_oom_hooks(model: torch.nn.Module, device: Optional[Device] = None) -> list[RemovableHandle]:
    """Add OOM hooks to the FSDP1-wrapped model and return the list of handles.

    Note: This isn't supported for FSDP2 yet. For more details view the draft PR:
    https://github.com/mosaicml/composer/pull/3866

    The following sync hooks are added to prevent FSDP1 deadlocks that are caused when some ranks OOM
    and other ranks do not OOM, leading to OOMing ranks calling all_reduce to wait on the non-OOMing
    ranks and the non-OOMing ranks calling all_gather to continue with FSDP training:

    forward_pre_hook: before forwards of FSDP1 modules
    full_backward_pre_hook: before backwards of FSDP1 modules
    full_backward_hook: before a prefetched unshard called by FSDP1's `post_backward_reshard`

    View https://github.com/mosaicml/composer/pull/3510 for more details.

    Args:
        model (torch.nn.Module): The model to add the hooks to. If this is a ComposerModel, the hooks
            are added to its valid (non-metric) direct children instead of the wrapper itself.
        device (Optional[Device]): The device that the module is on. If None, the current rank's
            device will be used.

    Returns:
        list[RemovableHandle]: The list of RemovableHandles for the hooks, so callers can later
            remove them.
    """
    if device is None:
        device = get_device()
    hook = generate_oom_hook(device)

    # Hooks are registered on the valid children if the input is a ComposerModel; otherwise on the
    # model itself.
    if isinstance(model, ComposerModel):
        root_modules_for_hooks = get_direct_children_from_composer_model(model)
    else:
        root_modules_for_hooks = [model]

    # TODO: In FSDP1, we might not need the non-FSDP wrapped backward hook either, but we'll keep it for now until further investigation.
    # TODO: If we want to reduce as many potential deadlocks as possible, we may need to add hooks before all blocking collectives:
    # - register_forward_pre_hook (before blocking all_gather)
    # - register_full_backward_pre_hook (before blocking all_gather)
    # - register_full_backward_hook (before blocking reduce_scatter)
    # In all of these cases, some combination of no activation checkpointing/offloading, reshard_after_forward=False, or high gradient memory cost
    # could result in edge-case OOMs and deadlocks.
    hook_handles: list[RemovableHandle] = []
    for root_module in root_modules_for_hooks:
        for module in root_module.modules():
            if isinstance(module, FullyShardedDataParallel):
                hook_handles.append(module.register_forward_pre_hook(hook, prepend=True))  # type: ignore
                hook_handles.append(module.register_full_backward_pre_hook(hook, prepend=True))  # type: ignore
            else:
                hook_handles.append(module.register_full_backward_hook(hook))  # type: ignore

    return hook_handles
48 changes: 17 additions & 31 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.shared_utils import generate_oom_hook
from composer.loggers import (
ConsoleLogger,
Logger,
Expand Down Expand Up @@ -427,32 +428,6 @@ def _update_num_consecutive_thrashes(state: State, num_consecutive_thrashes: int
return num_consecutive_thrashes


def _create_sync_hook(state: State):
"""Check if other ranks OOMed after forward/backward pass when using auto microbatching.

This may happen when close to memory limit or with uneven memory usage across ranks. Since we
need to do this before the model weights are gathered for the next FSDP block, we wrap every
FSPD block with a hook that checks if any other rank OOMed.

This wrapper method is needed because PyTorch FSDP doesn't take `state` as an argument in hooks
that are registered using methods such as `register_forward_pre_hook`.
"""

def sync_hook(*args):
# Check if any other rank hit an OOM
found_cuda_oom_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
found_cuda_oom = found_cuda_oom_tensor.item()
# Signal current rank is still in batch
all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

if found_cuda_oom == 1:
raise RuntimeError()

return sync_hook


def _readd_fsdp_sync_hooks(fsdp_modules: dict[str, torch.nn.Module], sync_hook):
"""Readds previously removed sync hooks back to FSDP modules.

Expand Down Expand Up @@ -1220,6 +1195,11 @@ def __init__(

# Distributed
parallelism_config = self._parse_parallelism_config(parallelism_config)
if parallelism_config is not None and parallelism_config.fsdp is None and auto_microbatching:
raise ValueError(
'Auto microbatching is not supported outside of FSDP1. '
'Please set a reasonable microbatch size manually or enable FSDP1.',
)
if parallelism_config is not None or dist.get_world_size() > 1:
# FSDP requires torch.distributed to be initialized, even if the world size is 1
# And torch.distributed is always required for multi-rank training
Expand Down Expand Up @@ -1859,6 +1839,7 @@ def _wrap_model_for_distributed(
self.state.seed,
)
case 2:
assert not auto_microbatching
Comment thread
rithwik-db marked this conversation as resolved.
parallelize_composer_model(
model,
optimizers,
Expand Down Expand Up @@ -2713,7 +2694,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
device_batch = self.state.batch

# Define sync hook for FSDP modules if automicrobatching is on
sync_hook = _create_sync_hook(self.state)
sync_hook = generate_oom_hook(self.state.device)

original_microbatch_size = self.state.device_train_microbatch_size
oom_found_this_batch = False
Expand Down Expand Up @@ -3619,9 +3600,12 @@ def _eval_loop(

# If training occurs after evaluation, readd hooks in case of memory spike
if self.state.auto_microbatching:
sync_hook = _create_sync_hook(self.state)
sync_hook = generate_oom_hook(self.state.device)
if self.state.fsdp_enabled and len(self.state.automicrobatch_fsdp_hook_handles) == 0:
self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(self.state.fsdp_modules, sync_hook)
self.state.automicrobatch_fsdp_hook_handles = _readd_fsdp_sync_hooks(
self.state.fsdp_modules,
sync_hook,
)
self.num_consecutive_non_OOM_batches = 0

def _use_grad_scaling(self, precision: Union[str, Precision], scaler: Optional[GradScaler]) -> bool:
Expand Down Expand Up @@ -3821,10 +3805,12 @@ def _parse_parallelism_config(
if parallelism_config is not None and not isinstance(parallelism_config, ParallelismConfig):
parallelism_config_args = {}
if 'fsdp' in parallelism_config and parallelism_config['fsdp'] is not None:
if isinstance(parallelism_config['fsdp'], FSDPConfig | FSDP2Config):
if isinstance(parallelism_config['fsdp'], FSDPConfig):
parallelism_config_args['fsdp'] = parallelism_config['fsdp']
elif isinstance(parallelism_config['fsdp'], FSDP2Config):
parallelism_config_args['fsdp2'] = parallelism_config['fsdp']
elif os.environ.get('FSDP_VERSION', '1') == '2':
parallelism_config_args['fsdp'] = FSDP2Config.from_compatible_attrs(parallelism_config['fsdp'])
parallelism_config_args['fsdp2'] = FSDP2Config.from_compatible_attrs(parallelism_config['fsdp'])
Comment thread
rithwik-db marked this conversation as resolved.
else:
parallelism_config_args['fsdp'] = FSDPConfig(**parallelism_config['fsdp'])
if 'tp' in parallelism_config and parallelism_config['tp'] is not None:
Expand Down
18 changes: 18 additions & 0 deletions tests/trainer/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def create_trainer_with_model(
optimizer: Optional[torch.optim.Optimizer] = None,
activation_checkpointing: bool = False,
activation_cpu_offload: bool = False,
auto_microbatching: bool = False,
) -> Trainer:
"""Helper function to create a Trainer with a model, dataloader, and FSDP2 configuration."""
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
Expand All @@ -53,6 +54,7 @@ def create_trainer_with_model(
train_dataloader=dataloader,
max_duration=max_duration,
parallelism_config=parallelism_config,
device_train_microbatch_size='auto' if auto_microbatching else None,
)
return trainer

Expand Down Expand Up @@ -303,3 +305,19 @@ def test_fsdp2_optimizer_raises_error_when_optimizer_modules_dont_match(
# We check with `optimizer.param_id.` (with the period) since `optimizer.param_id` exists
# by default in the error message's legend
assert 'optimizer.param_id.' in str(e.value)


@pytest.mark.gpu
@world_size(2)  # Using world_size(2) for consistency with other FSDP2 tests in this file although not needed
@pytest.mark.filterwarnings("ignore:`device_train_microbatch_size='auto'` may potentially fail with unexpected.*")
def test_fsdp2_auto_microbatching_raises_error(
    world_size: int,
):
    """Test FSDP2 raises an error when auto microbatching is used."""
    del world_size

    model = SimpleComposerMLP(num_features=10, device='cuda', num_classes=10)
    model.add_fsdp_wrap_attribute_to_children()
    # Building the trainer with `device_train_microbatch_size='auto'` under FSDP2 should be rejected.
    with pytest.raises(ValueError, match='Auto microbatching is not supported outside of FSDP1'):
        create_trainer_with_model(model=model, num_classes=10, use_fsdp2=True, auto_microbatching=True)
Loading