Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
"full",
"relax",
] = "relax"
VLLM_MLA_FORCE_DENSE: bool = False
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
Expand Down Expand Up @@ -1263,6 +1264,11 @@ def _get_or_set_default() -> str:
"relax",
],
),
# Force MLA to use dense attention, disabling the sparse attention
# indexer. Useful on architectures where DeepGEMM is not supported.
"VLLM_MLA_FORCE_DENSE": lambda: bool(
int(os.getenv("VLLM_MLA_FORCE_DENSE", "0"))
),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool(
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
is_deep_gemm_supported,
m_grouped_fp8_gemm_nt_contiguous,
)
from vllm.utils.import_utils import has_deep_gemm

logger = init_logger(__name__)

Expand All @@ -54,7 +53,7 @@ def _valid_deep_gemm(
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm():
if not is_deep_gemm_supported():
logger.debug_once("DeepGemm disabled: deep_gemm not available.")
return False

Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/sparse_attn_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, is_deep_gemm_supported
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
Expand Down Expand Up @@ -317,9 +317,11 @@ def __init__(
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
if current_platform.is_cuda() and not has_deep_gemm():
if current_platform.is_cuda() and not is_deep_gemm_supported():
raise RuntimeError(
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
"Sparse Attention Indexer CUDA op requires DeepGEMM "
"to be installed and supported on this architecture. "
"Set VLLM_MLA_FORCE_DENSE=1 to use dense attention instead."
)

def forward_native(
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from transformers import DeepseekV2Config, DeepseekV3Config

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
Expand Down Expand Up @@ -967,7 +968,7 @@ def __init__(
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

self.is_v32 = hasattr(config, "index_topk")
self.is_v32 = hasattr(config, "index_topk") and not envs.VLLM_MLA_FORCE_DENSE

if self.is_v32:
self.indexer_rope_emb = get_rope(
Expand Down Expand Up @@ -1181,7 +1182,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.device = current_platform.device_type

self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
self.is_v32 = hasattr(config, "index_topk") and not envs.VLLM_MLA_FORCE_DENSE
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
Expand Down Expand Up @@ -1508,6 +1509,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if "rotary_emb.inv_freq" in name:
continue

if not self.model.is_v32 and "indexer." in name:
continue

spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/mla/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
has_deep_gemm,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
Expand Down Expand Up @@ -553,7 +553,7 @@ def build(
)

# DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and has_deep_gemm():
if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
Expand Down
5 changes: 2 additions & 3 deletions vllm/v1/worker/gpu_ubatch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.deep_gemm import set_num_sms as deep_gemm_set_num_sms
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.deep_gemm import is_deep_gemm_supported, set_num_sms as deep_gemm_set_num_sms
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

Expand Down Expand Up @@ -158,7 +157,7 @@ def _create_sm_control_context(vllm_config: VllmConfig):

# TODO(lucas): support other kernels besides DeepGEMM
set_compute_sms = lambda sms: None
if has_deep_gemm() and comm_sms > 0:
if is_deep_gemm_supported() and comm_sms > 0:
set_compute_sms = lambda sms: deep_gemm_set_num_sms(sms)

return SMControlContextManager(
Expand Down
Loading