Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,8 @@ def prepare(self, *args, device_placement=None):

if self.parallelism_config and self.parallelism_config.cp_enabled:
args = self._prepare_cp(*args)
if self.parallelism_config and self.parallelism_config.sp_enabled:
args = self._prepare_sp(*args)
# for megatron-lm, we don't need to prepare TE AO at this moment
if self.distributed_type != DistributedType.MEGATRON_LM:
if self.fp8_backend == FP8BackendType.TE:
Expand All @@ -1558,6 +1560,8 @@ def prepare(self, *args, device_placement=None):
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
)
result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
if self.parallelism_config and self.parallelism_config._sequence_sharding_enabled:
result = self._wrap_sequence_sharding_dataloaders(result)
if tpu_should_fix_optimizer:
# 2. grabbing new model parameters
new_named_params = self._get_named_parameters(*result)
Expand Down Expand Up @@ -1670,6 +1674,40 @@ def _prepare_cp(self, *args):

return args

def _prepare_sp(self, *args):
"""Native ('accelerate') Ulysses sequence parallelism: register the Ulysses attention on each
model over the `sp` mesh dim. No-op for the deepspeed sp backend (handled in
`_prepare_deepspeed`)."""
if self.parallelism_config.sp_backend != "accelerate":
return args

from .utils.sequence_parallel import enable_ulysses_sp

self._sp_attention = None
sp_group = self.torch_device_mesh["sp"].get_group()
for arg in args:
if isinstance(arg, torch.nn.Module):
# the handler is stashed so `prepare` can wire a packed dataloader to push per-step
# cu_seqlens onto it (`handler.set_varlen`).
self._sp_attention = enable_ulysses_sp(arg, sp_group)

return args

def _wrap_sequence_sharding_dataloaders(self, result):
"""Wrap each prepared DataLoader in a `SequenceShardingDataLoader` that shards the sequence
contiguously over the `sp` group. `prepare_data_loader` already hands sp ranks the same
sample (dp-aware sharding divides out tp*cp*sp), so the wrapper only does the sequence
split; packed cu_seqlens are pushed onto the handler (`set_varlen`)."""
from .utils.sequence_parallel import SequenceShardingDataLoader

shard_group = self.torch_device_mesh["sp"].get_group()
wrapped = []
for obj in result:
if isinstance(obj, torch.utils.data.DataLoader) and not isinstance(obj, SequenceShardingDataLoader):
obj = SequenceShardingDataLoader(obj, shard_group, attention=self._sp_attention)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_sp_attention is a bit ugly, this might change. lmk @winglian if you have a better design for this in order to deal with packed sequences.

wrapped.append(obj)
return tuple(wrapped)

def _prepare_fsdp2(self, *args):
# First pass: prepare everything except schedulers (and model, which is prepared separately below)
result = [
Expand Down Expand Up @@ -2219,12 +2257,17 @@ def _prepare_deepspeed(self, *args):
}
# This block is skipped when preparing just a model and DL is absent from current call's args
if batch_size_per_device is not None:
# `sp_size` divides only for the deepspeed/ALST backend, where the mpu tells DeepSpeed
# that sp is model-parallel (DP world = num_processes // sp). The native "accelerate"
# backend sets up no mpu, so DeepSpeed counts the sp ranks as data-parallel and expects
# train_batch_size == micro * grad_acc * num_processes (no // sp).
sp_divisor = sp_size if sp_backend == "deepspeed" else 1
config_kwargs["train_micro_batch_size_per_gpu"] = batch_size_per_device
config_kwargs["train_batch_size"] = (
batch_size_per_device
* deepspeed_plugin.get_value("gradient_accumulation_steps")
* self.num_processes
// sp_size
// sp_divisor
)

model = None
Expand Down Expand Up @@ -2383,12 +2426,7 @@ def _prepare_deepspeed(self, *args):
os.environ["DEEPSPEED_USE_HPU"] = "true"

mpu = None
if sp_size > 1:
if sp_backend != "deepspeed":
raise ValueError(
f"In order to use the configured {sp_size=} with DeepSpeed, you need to configure sp_backend='deepspeed', yet you configured it to be {sp_backend=}."
)

if sp_size > 1 and sp_backend == "deepspeed":
ver_min_required = "0.18.2"
if not compare_versions("deepspeed", ">=", ver_min_required):
raise ImportError(
Expand Down Expand Up @@ -2726,6 +2764,7 @@ def prepare_data_loader(
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=device_mesh,
parallelism_config=self.parallelism_config,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
Expand Down
24 changes: 16 additions & 8 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ def prepare_data_loader(
non_blocking: bool = False,
use_stateful_dataloader: bool = False,
torch_device_mesh=None,
parallelism_config=None,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -1127,15 +1128,17 @@ def prepare_data_loader(
if torch_device_mesh:
if state.distributed_type == DistributedType.DEEPSPEED:
# In DeepSpeed, the optimizer sharing level in DP is determined by the config file.
# Only considers "dp" and "tp".
# Given a device mesh (dp, tp) = (2, 3):
# tp ranks receive the SAME batch. Given a device mesh (dp, tp) = (2, 3):
# - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
# - Processes with the same DP rank will receive the same batch.
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
process_index = process_index // submesh_tp_size
num_processes = num_processes // submesh_tp_size
inner = torch_device_mesh["tp"].size() if "tp" in torch_device_mesh.mesh_dim_names else 1
# Native ("accelerate") sp/cp shard the sequence, so those ranks also get the SAME batch.
# The DeepSpeed (ALST) sp backend manages its own data adapter + mpu, so leave it untouched.
if parallelism_config is not None:
if parallelism_config.sp_backend == "accelerate" and "sp" in torch_device_mesh.mesh_dim_names:
inner *= torch_device_mesh["sp"].size()
process_index = process_index // inner
num_processes = num_processes // inner
else:
# when device mesh is used, specifically with TP
# then there is need to update process_index and num_processes
Expand All @@ -1151,15 +1154,20 @@ def prepare_data_loader(
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
submesh_sp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "sp" in torch_device_mesh.mesh_dim_names:
submesh_sp_size = torch_device_mesh["sp"].size()
if "dp_replicate" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp_replicate"].size()
if "dp_shard" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)
# tp/cp/sp are non-data-parallel inner dims: ranks within them get the SAME batch
# (cp/sp then shard its sequence). Only dp_replicate x dp_shard get distinct batches.
process_index = process_index // (submesh_tp_size * submesh_cp_size * submesh_sp_size)
num_processes = submesh_fsdp_size * submesh_dp_size

# Sanity check
Expand Down
103 changes: 87 additions & 16 deletions src/accelerate/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DeepSpeedSequenceParallelConfig,
DistributedType,
TorchContextParallelConfig,
AccelerateSequenceParallelConfig,
TorchTensorParallelConfig,
)
from accelerate.utils.versions import is_torch_version
Expand Down Expand Up @@ -53,10 +54,15 @@ class ParallelismConfig:
for downstream libraries.
cp_backend (`str`, defaults to `torch`):
Which CP backend to use: `torch` (FSDP2)
cp_handler (`~utils.TorchContextParallelConfig`, *optional*, defaults to `None`): The handler for the context parallel group.
sp_size (`int`, defaults to `1`):
The size of the sequence parallel group.
sp_backend (`str`, defaults to `deepspeed`):
Which SP backend to use:`deepspeed` (ALST/Ulysses)
The size of the sequence parallel group. If `1`, SP is not used.
sp_backend (`str`, *optional*):
Which SP backend to use: `"accelerate"` (our native Ulysses) or `"deepspeed"`
(ALST/Ulysses, which needs the DeepSpeed engine). If left unset it auto-resolves: `"deepspeed"`
under the DeepSpeed engine, else `"accelerate"`.
sp_handler (`~utils.AccelerateSequenceParallelConfig` or `~utils.DeepSpeedSequenceParallelConfig`, *optional*, defaults to `None`):
Config for the sequence parallel group;

You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
together:
Expand All @@ -73,12 +79,12 @@ class ParallelismConfig:
cp_size: Optional[int] = None
cp_backend: Literal["torch"] = None
sp_size: Optional[int] = None
sp_backend: Literal["deepspeed"] = None
sp_backend: Literal["deepspeed", "accelerate"] = None

# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
tp_handler: Union[None, TorchTensorParallelConfig] = None
cp_handler: Union[None, TorchContextParallelConfig] = None
sp_handler: Union[None, DeepSpeedSequenceParallelConfig] = None
sp_handler: Union[None, DeepSpeedSequenceParallelConfig, AccelerateSequenceParallelConfig] = None

device_mesh = None

Expand All @@ -94,15 +100,16 @@ def __repr__(self):
f"\tsp_backend={self.sp_backend},\n"
f"\ttotal_size={self.total_size}\n"
f"\ttp_handler={self.tp_handler},\n"
f"\tcp_handler={self.cp_handler})\n"
f"\tcp_handler={self.cp_handler},\n"
f"\tsp_handler={self.sp_handler})\n"
)

def to_json(self):
import copy

_non_serializable_fields = ["device_mesh"]

copy.deepcopy(
return copy.deepcopy(
{
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
for k, v in self.__dict__.items()
Expand Down Expand Up @@ -132,26 +139,42 @@ def non_dp_dim_names(self):
dims += ["sp"]
return dims

@property
def _sequence_sharding_enabled(self):
"""Whether the native (accelerate) sequence-sharding path is active and `prepare` should wrap
dataloaders in a `SequenceShardingDataLoader`. Currently Ulysses SP on the accelerate backend;
ring CP will OR its clause in here once added."""
return self.sp_enabled and self.sp_backend == "accelerate"

@property
def dp_shard_cp_dim_names(self):
"""Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
"""The FSDP shard axis — flattened into one mesh dim (`dp_shard_cp`, accelerate's name) that
`fully_shard` shards params / reduce-scatters grads over. Despite the name it spans `dp_shard`
**plus the sequence-parallel dims `cp` and `sp`**: both shard the same sample's sequence, so
FSDP must shard over them for its reduce-scatter to average their grads."""
dims = []
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
if self.sp_enabled:
dims += ["sp"]
return dims

@property
def dp_cp_dim_names(self):
"""Names of enabled dimensions across which loss should be averaged"""
"""Dims over which loss/grads are averaged, flattened to the `dp_cp` mesh dim. Despite the
name it spans the data-parallel dims (`dp_replicate`, `dp_shard`) **plus the sequence-parallel
dims `cp` and `sp`**."""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
if self.sp_enabled:
dims += ["sp"]
return dims

@property
Expand All @@ -170,12 +193,12 @@ def total_size(self):

@property
def non_data_parallel_size(self):
"""The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
"""Product of the non-data-parallel sizes (tensor x context/ring x sequence/Ulysses)."""
return self.tp_size * self.cp_size * self.sp_size

@property
def data_parallel_size(self):
"""The size of the data parallel dimensions, which is the product of data parallel replication and"""
"""Product of the data-parallel sizes (replication x shard)."""
return self.dp_replicate_size * self.dp_shard_size

@property
Expand All @@ -200,7 +223,7 @@ def cp_enabled(self):

@property
def sp_enabled(self):
"""True if context parallelism is enabled, i.e. `sp_size > 1`."""
"""True if sequence parallelism is enabled, i.e. `sp_size > 1`."""
return self.sp_size > 1

@property
Expand Down Expand Up @@ -272,6 +295,11 @@ def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
return tuple(zip(*sorted_items))

def __post_init__(self):
# Track whether the user explicitly chose a backend (vs. leaving it to auto-resolve from
# the training engine in `_resolve_backends`). Env-set counts as explicit. Must be read
# BEFORE the defaulting below overwrites None.
self._sp_backend_explicit = self.sp_backend is not None or "PARALLELISM_CONFIG_SP_BACKEND" in os.environ

# Basic size validation
if self.dp_replicate_size is None:
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
Expand All @@ -281,10 +309,10 @@ def __post_init__(self):
self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
if self.cp_size is None:
self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
if self.cp_backend is None:
self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
if self.sp_size is None:
self.sp_size = int(os.environ.get("PARALLELISM_CONFIG_SP_SIZE", "1"))
if self.cp_backend is None:
self.cp_backend = os.environ.get("PARALLELISM_CONFIG_CP_BACKEND", "torch")
if self.sp_backend is None:
self.sp_backend = os.environ.get("PARALLELISM_CONFIG_SP_BACKEND", "deepspeed")

Expand All @@ -306,7 +334,21 @@ def __post_init__(self):

if self.sp_size > 1:
if self.sp_handler is None:
self.sp_handler = DeepSpeedSequenceParallelConfig()
self.sp_handler = (
AccelerateSequenceParallelConfig()
if self.sp_backend == "accelerate"
else DeepSpeedSequenceParallelConfig()
)
else:
sp_backends_config_map = dict(
deepspeed=DeepSpeedSequenceParallelConfig,
accelerate=AccelerateSequenceParallelConfig,
)
if not isinstance(self.sp_handler, sp_backends_config_map[self.sp_backend]):
raise ValueError(
f"ParallelismConfig's sp_backend={self.sp_backend} requires "
f"{sp_backends_config_map[self.sp_backend]}, but sp_handler was set to {type(self.sp_handler)}"
)
if self.dp_replicate_size < 1:
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
if self.dp_shard_size < 1:
Expand All @@ -321,7 +363,7 @@ def __post_init__(self):

if self.sp_size < 1:
raise ValueError(f"sp_size must be at least 1, but got {self.sp_size}")
valid_sp_backends = ["deepspeed"]
valid_sp_backends = ["deepspeed", "accelerate"]
if self.sp_backend not in valid_sp_backends:
raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")

Expand Down Expand Up @@ -352,8 +394,37 @@ def _set_size(self, parallelism: str, size: int):
self._sizes[parallelism] = size
setattr(self, f"{parallelism}_size", size)

def _resolve_backends(self, accelerator: "Accelerator"):
"""Resolve the sp (Ulysses) backend against the training engine, once ``distributed_type``
is known. ``"deepspeed"`` (ALST) only runs under the DeepSpeed engine; ``"accelerate"``
(native Ulysses) runs under any engine (DeepSpeed ZeRO, FSDP2, or DDP). When the user
didn't pin a backend it defaults to ``"deepspeed"`` under DeepSpeed and ``"accelerate"``
otherwise; an explicit ``"deepspeed"`` without the DeepSpeed engine falls back to
``"accelerate"``. ``sp_handler`` is set to match the resolved backend."""
if self.sp_size <= 1:
return

is_deepspeed = accelerator.distributed_type == DistributedType.DEEPSPEED

if not self._sp_backend_explicit:
# ALST is the purpose-built path under DeepSpeed; native Ulysses under FSDP2 / DDP.
self.sp_backend = "deepspeed" if is_deepspeed else "accelerate"
elif self.sp_backend == "deepspeed" and not is_deepspeed:
# ALST needs the DeepSpeed engine; fall back to native Ulysses (runs under FSDP2 / DDP).
self.sp_backend = "accelerate"

handler_cls = (
DeepSpeedSequenceParallelConfig
if self.sp_backend == "deepspeed"
else AccelerateSequenceParallelConfig
)
if not isinstance(self.sp_handler, handler_cls):
self.sp_handler = handler_cls()

def _validate_accelerator(self, accelerator: "Accelerator"):
_warnings = set()
# Resolve auto backends + enforce backend<->engine compatibility now that the engine is known.
self._resolve_backends(accelerator)
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
Expand Down
Loading
Loading