Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/models/quantization/test_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm import LLM, SamplingParams

from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_b12x_gemm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

Expand Down Expand Up @@ -105,6 +106,7 @@ def test_models(example_prompts, model_name) -> None:
"flashinfer_cudnn",
"flashinfer_trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used
"flashinfer_cutlass",
"flashinfer_b12x",
],
)
def test_nvfp4(vllm_runner, model, eager, backend):
Expand All @@ -115,6 +117,11 @@ def test_nvfp4(vllm_runner, model, eager, backend):
pytest.skip(
f"The backend {backend} is not supported with current_platform.has_device_capability(100) == False"
)
if backend == "flashinfer_b12x" and (
not current_platform.has_device_capability(120)
or not has_flashinfer_b12x_gemm()
):
pytest.skip(f"The backend {backend} requires SM120+ and FlashInfer B12x GEMM")

with vllm_runner(model, enforce_eager=eager, linear_backend=backend) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
Expand Down
2 changes: 2 additions & 0 deletions vllm/config/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def with_default(
"flashinfer_cutlass",
"flashinfer_trtllm",
"flashinfer_cudnn",
"flashinfer_b12x",
"marlin",
"triton",
"deep_gemm",
Expand Down Expand Up @@ -197,6 +198,7 @@ class KernelConfig:
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_trtllm": Use FlashInfer with TensorRT-LLM kernels
- "flashinfer_cudnn": Use FlashInfer with cuDNN kernels
- "flashinfer_b12x": Use FlashInfer B12x NVFP4 GEMM kernels for SM12x
- "marlin": Use Marlin kernels
- "triton": Use Triton-based kernels
- "deep_gemm": Use DeepGEMM kernels
Expand Down
2 changes: 2 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1662,9 +1662,11 @@ def _resolve_rust_frontend_path() -> str | None:
int(os.getenv("VLLM_HAS_FLASHINFER_CUBIN", "0"))
),
# Supported options:
# - "flashinfer-b12x": use flashinfer b12x GEMM backend (SM120+)
# - "flashinfer-cudnn": use flashinfer cudnn GEMM backend
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - "cutlass": use vLLM cutlass GEMM backend
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
# - "emulation":
# use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations.
Expand Down
15 changes: 12 additions & 3 deletions vllm/model_executor/kernels/linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def _get_linear_backend() -> str:
"flashinfer_cudnn": {
FlashInferCudnnNvFp4LinearKernel,
},
"flashinfer_b12x": {
FlashInferB12xNvFp4LinearKernel,
},
"marlin": {
MarlinFP8ScaledMMLinearKernel,
MarlinLinearKernel,
Expand Down Expand Up @@ -375,9 +378,6 @@ def _filter_kernels_by_backend(

_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
PlatformEnum.CUDA: [
# FlashInferB12xNvFp4LinearKernel excluded from auto-selection until
# upstream CUTLASS SM121 MMA op guard is resolved; use
# VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x to opt in explicitly.
FlashInferCutlassNvFp4LinearKernel,
CutlassNvFp4LinearKernel,
MarlinNvFp4LinearKernel,
Expand All @@ -391,6 +391,13 @@ def _filter_kernels_by_backend(
],
}

# NVFP4 kernels that are never auto-selected, keyed by platform.
#
# FlashInferB12xNvFp4LinearKernel stays out of the auto-selection list until
# the upstream CUTLASS SM121 MMA op guard is resolved. It is only considered
# when the user opts in explicitly, e.g. with --linear-backend flashinfer_b12x.
_EXPLICIT_ONLY_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferB12xNvFp4LinearKernel,
    ],
}

_POSSIBLE_MXFP4_KERNELS: dict[PlatformEnum, list[type[MxFp4LinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferMxFp4LinearKernel,
Expand Down Expand Up @@ -906,6 +913,8 @@ def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:

# Apply --linear-backend filtering when set.
if linear_backend != "auto":
possible.extend(_EXPLICIT_ONLY_NVFP4_KERNELS.get(platform, []))
possible = list(dict.fromkeys(possible))
filtered = _filter_kernels_by_backend(linear_backend, possible)
if not filtered:
raise ValueError(
Expand Down
Loading