Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 106 additions & 35 deletions vllm/distributed/device_communicators/all2all.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from typing import Any, Optional

import torch
import torch.distributed as dist
Expand All @@ -9,17 +9,15 @@
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe

logger = init_logger(__name__)

Expand Down Expand Up @@ -67,7 +65,6 @@ def dispatch(
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

hidden_states = self.naive_multicast(
Expand All @@ -84,7 +81,6 @@ def combine(
ep_rank = self.rank if is_sequence_parallel else self.dp_rank

dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

Expand Down Expand Up @@ -117,10 +113,7 @@ def dispatch(
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
Expand All @@ -137,10 +130,7 @@ def combine(
"""
Reduce-scatter hidden_states across all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
Expand All @@ -165,7 +155,7 @@ def __init__(self, cpu_group):
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
from pplx_kernels.nvshmem import (
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
Expand All @@ -192,7 +182,7 @@ def __init__(self, cpu_group):
self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
import pplx_kernels as pplx

return self.handle_cache.get_or_create(
kwargs,
Expand All @@ -218,9 +208,7 @@ def destroy(self):
handle.destroy()

if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
from pplx_kernels.nvshmem import nvshmem_finalize

logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
Expand Down Expand Up @@ -277,7 +265,7 @@ def _make_all2all_kwargs(self) -> dict[Any, Any]:
num_rdma_bytes = None
num_qps_per_rank = None

if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
if self.internode:
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2
else:
Expand All @@ -300,7 +288,7 @@ def get_handle(self, kwargs):
"args are computed in the Manager itself."
)

import deep_ep # type: ignore[import-not-found]
import deep_ep

buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
Expand All @@ -310,7 +298,7 @@ def get_handle(self, kwargs):
return handle

def set_num_sms(self, num_sms: int):
import deep_ep # type: ignore[import-not-found]
import deep_ep

# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMS used
Expand Down Expand Up @@ -344,7 +332,7 @@ def _make_all2all_kwargs(
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep # type: ignore[import-not-found]
import deep_ep

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
Expand All @@ -363,16 +351,14 @@ def _make_all2all_kwargs(
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
)

def get_handle(self, kwargs):
"""
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep # type: ignore[import-not-found]
import deep_ep

buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
Expand All @@ -382,7 +368,7 @@ def get_handle(self, kwargs):
return handle

# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> int | None:
def max_sms_used(self) -> Optional[int]:
return 0


Expand All @@ -391,11 +377,6 @@ class FlashInferAllToAllManager(All2AllManagerBase):
All2All communication based on flashinfer kernels.
"""

# This type lint could be removed after all of the work in
# https://github.com/vllm-project/vllm/issues/26533 done.
rank: int
world_size: int

def __init__(self, cpu_group):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
Expand Down Expand Up @@ -488,3 +469,93 @@ def cleanup(self):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False


class UCCLEPLLAll2AllManager(All2AllManagerBase):
"""
All2All communication based on UCCL-EP Low-Latency kernels.
Uses the same interface as DeepEP, so we can reuse DeepEPLLPrepareAndFinalize.
"""

def __init__(self, cpu_group):
from vllm.utils import has_uccl_ep

assert has_uccl_ep(), (
"UCCL-EP kernels not found. Please install uccl_ep package."
)
super().__init__(cpu_group)
self.handle_cache = Cache()

# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20

def _make_all2all_kwargs(
self,
max_num_tokens_per_dp_rank: int,
token_hidden_size: int,
num_ep_ranks: int,
num_global_experts: int,
num_local_experts: int,
) -> dict[Any, Any]:
"""
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
can dispatch all the ranks must hold the same value.
token_hidden_size: the hidden dimension of each token.
num_ep_ranks: the number of EP group ranks.
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import uccl_ep

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = num_local_experts
num_rdma_bytes = uccl_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts,
)

assert num_rdma_bytes is not None

# Debug logging
logger.info(
f"[UCCL-EP DEBUG] Buffer initialization params:\n"
f" max_num_tokens_per_dp_rank: {max_num_tokens_per_dp_rank}\n"
f" token_hidden_size: {token_hidden_size}\n"
f" num_ep_ranks: {num_ep_ranks}\n"
f" num_global_experts: {num_global_experts}\n"
f" num_local_experts: {num_local_experts}\n"
f" num_rdma_bytes: {num_rdma_bytes}"
)

return dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
)

def get_handle(self, kwargs):
"""
The kwargs for UCCLEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
Returns a uccl_ep.Buffer instance.
Since UCCL-EP has the same API as DeepEP, this buffer can be directly
used by DeepEPLLPrepareAndFinalize without any modifications.
"""
import uccl_ep

buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("UCCL-EP all2all args %s", buffer_kwargs)
handle: uccl_ep.Buffer = self.handle_cache.get_or_create(
buffer_kwargs, uccl_ep.Buffer
)
return handle

# UCCL-EP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
return 0
Loading