Skip to content

Commit 4819a2c

Browse files
meena-at-work authored and bryanfarrell committed
[SM120/121] Add FlashInfer b12x NVFP4 MoE + GEMM backends
Carries upstream PR vllm-project#40082 (vllm-project/vllm) onto the fiosco release line as a single net-change commit. Adds the FlashInfer b12x CuteDSL NVFP4 fused-MoE and FP4 GEMM backends targeting SM120/121 desktop and Spark Blackwell parts.

The b12x backends are excluded from auto-selection and require explicit opt-in via VLLM_FUSED_MOE_BACKEND=flashinfer-b12x and the matching linear-backend flag (VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x). Compatible with FlashInfer 0.6.11+ (Sm120 b12x kernel-class rename absorbed).

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Signed-off-by: Bryan Farrell <12701870+bryanfarrell@users.noreply.github.com>
1 parent ad7125a commit 4819a2c

10 files changed

Lines changed: 595 additions & 4 deletions

File tree

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import pytest
5+
import torch
6+
7+
from vllm.platforms import current_platform
8+
9+
# Module-level guard: everything below targets the SM12x device family, so the
# whole file is skipped on any other platform before the FlashInfer-dependent
# imports further down are attempted.
if not current_platform.is_device_capability_family(120):
    pytest.skip(
        allow_module_level=True,
        reason=(
            "FlashInfer CuteDSL SM12x MoE requires SM120 (RTX Pro 6000 / "
            "DGX Spark)."
        ),
    )
15+
16+
from vllm.utils.flashinfer import has_flashinfer_b12x_moe
17+
18+
# Module-level guard: skip the whole file when the installed FlashInfer does
# not provide the b12x fused-MoE entry points that the imports below rely on.
if not has_flashinfer_b12x_moe():
    pytest.skip(
        allow_module_level=True,
        reason=(
            "FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout not "
            "available in installed FlashInfer (needs PRs #3051 and #3066)."
        ),
    )
26+
27+
# Import fp4_quantize after the skip guard — FlashInfer must be installed.
28+
from flashinfer.fp4_quantization import fp4_quantize
29+
30+
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
31+
from tests.kernels.moe.utils import make_dummy_moe_config
32+
from tests.kernels.utils import torch_moe
33+
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
34+
from vllm.model_executor.layers.fused_moe import fused_topk
35+
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
36+
from vllm.model_executor.layers.fused_moe.all2all_utils import (
37+
maybe_make_prepare_finalize,
38+
)
39+
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
40+
from vllm.model_executor.layers.fused_moe.experts.flashinfer_b12x_moe import (
41+
FlashInferB12xExperts,
42+
)
43+
from vllm.utils.flashinfer import flashinfer_convert_sf_to_mma_layout
44+
from vllm.utils.torch_utils import set_random_seed
45+
46+
# (m, n, k) problem sizes for the parametrised test. k is always a multiple of
# 256 and n a multiple of 128 so the FP4 alignment requirements hold, and the
# shapes are small enough to keep the test fast.
MNK_FACTORS = [(2, 128, 256), (2, 256, 512), (16, 128, 256), (64, 256, 512)]
54+
55+
56+
def _reorder_gate_up_to_up_gate(
57+
w: torch.Tensor,
58+
w_s: torch.Tensor,
59+
) -> tuple[torch.Tensor, torch.Tensor]:
60+
"""Swap gate and up-projection halves along dim=1 to [up, gate] order.
61+
62+
The SM12x kernel expects weights in [up (w3), gate (w1)] order while the
63+
BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering
64+
done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``.
65+
"""
66+
n = w.shape[1] // 2
67+
return (
68+
torch.cat([w[:, n:, :], w[:, :n, :]], dim=1),
69+
torch.cat([w_s[:, n:, :], w_s[:, :n, :]], dim=1),
70+
)
71+
72+
73+
def _quantize_experts_to_nvfp4(
    weights: torch.Tensor,
    global_scale: torch.Tensor,
    sf_vec_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantise stacked per-expert weights with FlashInfer's ``fp4_quantize``.

    The ``(experts, rows, cols)`` input is flattened to 2D for the call and
    both results are viewed back per expert.

    Returns:
        ``(packed, block_scales)``: ``packed`` is viewed as
        ``(experts, rows, cols // 2)`` (uint8, packed FP4) and
        ``block_scales`` as ``(experts, rows, <scale columns>)`` (float8).
        The scales are requested with ``is_sf_swizzled_layout=True``, i.e. the
        swizzled 2D layout that ``convert_sf_to_mma_layout`` consumes, so no
        extra ``swizzle_blockscale()`` step is needed.
    """
    num_experts, rows, cols = weights.shape
    packed_2d, scales_2d = fp4_quantize(
        weights.reshape(num_experts * rows, cols),
        global_scale=global_scale,
        sf_vec_size=sf_vec_size,
        is_sf_swizzled_layout=True,
    )
    return (
        packed_2d.view(num_experts, rows, cols // 2),
        scales_2d.view(num_experts, rows, scales_2d.shape[1]),
    )


def _block_scales_to_mma_layout(block_scales: torch.Tensor) -> torch.Tensor:
    """Convert per-expert block scales to the kernel's MMA layout.

    ``block_scales`` is ``(experts, rows, scale_cols)``; it is flattened over
    the expert dimension and each scale column is taken to cover 16 elements
    of the quantised K dimension.
    """
    groups, rows, sf_cols = block_scales.shape
    return flashinfer_convert_sf_to_mma_layout(
        block_scales.reshape(groups * rows, sf_cols),
        m=rows,
        k=sf_cols * 16,
        num_groups=groups,
    )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [8, 16])
@pytest.mark.parametrize("topk", [1, 2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_b12x_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    workspace_init,
):
    """Compare ``FlashInferB12xExperts`` with a BF16 ``torch_moe`` reference.

    The SM12x kernel consumes BF16 hidden states directly and fuses token
    dispatch, the W1 GEMM, SwiGLU and the W2 GEMM into a single call. Because
    activations and weights are quantised to FP4 internally, the comparison
    uses deliberately loose tolerances.

    Scale convention
    ----------------
    ``launch_sm120_moe`` uses one parameter, ``w1_alpha``, both as the
    activation-quantisation global scale and as the weight dequantisation
    factor, so the two must be the same value. Quantising with
    ``global_scale = 1.0`` lets ``w1_alpha = ones`` fill both roles. vLLM's
    usual convention (a large ``w_gs`` baked into the block scales and undone
    with ``g1_alphas = 1/w_gs``) is incompatible with this kernel.

    Args:
        m: Number of tokens.
        n: Intermediate size per partition.
        k: Hidden size.
        e: Number of experts.
        topk: Experts selected per token.
        dtype: Dtype of the hidden states and reference weights.
        workspace_init: Fixture requested for its side effect only; it is not
            referenced in the body (defined outside this file).
    """
    set_random_seed(7)
    vllm_config = VllmConfig(
        parallel_config=ParallelConfig(pipeline_parallel_size=1)
    )
    with set_current_vllm_config(vllm_config):
        device = "cuda"
        hidden_states = torch.randn((m, k), device=device, dtype=dtype) / 10

        # BF16 reference weights in [gate, up] order:
        # w1 is (e, 2n, k), w2 is (e, k, n).
        w1_bf16 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
        w2_bf16 = torch.randn((e, k, n), device=device, dtype=dtype) / 15

        # FlashInfer convention: global_scale = 1.0, so
        # block_scale = max_abs_block / fp4_max and no extra global factor
        # has to be compensated for (w1_alpha = 1.0).
        unit_global_scale = torch.ones(1, device=device, dtype=torch.float32)
        sf_vec_size = 16

        # W1 is reordered from [gate, up] to [up, gate] before quantisation;
        # the result keeps shape (e, 2n, k).
        gate_rows = w1_bf16[:, :n, :]
        up_rows = w1_bf16[:, n:, :]
        w1_up_gate = torch.cat([up_rows, gate_rows], dim=1)
        w1_q, w1_blockscale = _quantize_experts_to_nvfp4(
            w1_up_gate, unit_global_scale, sf_vec_size
        )

        # The down-projection needs no row reordering.
        w2_q, w2_blockscale = _quantize_experts_to_nvfp4(
            w2_bf16, unit_global_scale, sf_vec_size
        )

        # With global_scale = 1.0 every per-expert alpha / gscale is 1.0.
        unit_alphas = torch.ones(e, device=device, dtype=torch.float32)
        quant_config = nvfp4_moe_quant_config(
            g1_alphas=unit_alphas,
            g2_alphas=unit_alphas,
            a1_gscale=unit_alphas,
            a2_gscale=unit_alphas,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        moe_config = make_dummy_moe_config(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            in_dtype=dtype,
        )

        experts = FlashInferB12xExperts(
            moe_config=moe_config,
            quant_config=quant_config,
        )
        # Production code derives the MMA layouts in
        # process_weights_after_loading after normalising the block scales.
        # Here the scales are already final (global_scale = 1.0), so the
        # layouts are computed directly and attached to the experts object.
        experts.w1_sf_mma = _block_scales_to_mma_layout(w1_blockscale)
        experts.w2_sf_mma = _block_scales_to_mma_layout(w2_blockscale)

        prepare_finalize = maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        )
        kernel = mk.FusedMoEKernel(prepare_finalize, experts, inplace=False)

        score = torch.randn((m, e), device=device, dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states, score, topk, renormalize=False
        )

        sm12x_output = kernel.apply(
            hidden_states=hidden_states,
            w1=w1_q,
            w2=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_num_experts=e,
            activation=MoEActivation.SILU,
            apply_router_weight_on_input=False,
            expert_map=None,
        )

        # The reference runs on the original BF16 weights: torch_moe's
        # SiluAndMul expects [gate, up] order, which is what w1_bf16 holds.
        torch_output = torch_moe(hidden_states, w1_bf16, w2_bf16, score, topk)

        torch.testing.assert_close(
            sm12x_output, torch_output, atol=2e-1, rtol=2e-1
        )
226+
227+
228+
if __name__ == "__main__":
    # Run this module through pytest instead of calling the test directly:
    # ``test_flashinfer_b12x_moe`` takes the ``workspace_init`` fixture as a
    # required parameter, so a plain six-argument call raises ``TypeError``.
    # Note this runs every parametrisation, not just one shape.
    raise SystemExit(pytest.main([__file__]))

tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from vllm.platforms import current_platform
1414
from vllm.utils.flashinfer import (
1515
flashinfer_scaled_fp4_mm,
16+
has_flashinfer_b12x_gemm,
1617
)
1718
from vllm.utils.torch_utils import set_random_seed
1819

@@ -74,7 +75,7 @@ def get_ref_results(
7475
@pytest.mark.parametrize("shape", SHAPES)
7576
@pytest.mark.parametrize("seed", SEEDS)
7677
@pytest.mark.parametrize("device", CUDA_DEVICES)
77-
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm"])
78+
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "trtllm", "b12x"])
7879
@pytest.mark.parametrize("autotune", [False, True])
7980
@torch.inference_mode()
8081
def test_flashinfer_nvfp4_gemm(
@@ -87,6 +88,10 @@ def test_flashinfer_nvfp4_gemm(
8788
) -> None:
8889
if "trtllm" in backend and dtype == torch.float16:
8990
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
91+
if backend == "b12x" and not current_platform.has_device_capability(120):
92+
pytest.skip("b12x FP4 GEMM requires SM120+ (CC 12.0+)")
93+
if backend == "b12x" and not has_flashinfer_b12x_gemm():
94+
pytest.skip("b12x FP4 GEMM backend not available in installed FlashInfer")
9095

9196
set_random_seed(seed)
9297
m, n, packed_k = shape
@@ -105,8 +110,7 @@ def test_flashinfer_nvfp4_gemm(
105110

106111
# ops.scaled_fp4_quant returns swizzled scales, while weights
107112
# from checkpoints are in linear scales.
108-
# So instead of needing to swizzle for cutlass as in modelopt.py,
109-
# we need to unswizzle for trtllm here.
113+
# cutlass and b12x use swizzled scales directly; trtllm needs them unswizzled.
110114
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
111115
a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
112116
)

vllm/config/kernel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def with_default(
117117
"flashinfer_trtllm",
118118
"flashinfer_cutlass",
119119
"flashinfer_cutedsl",
120+
"flashinfer_b12x",
120121
"marlin",
121122
"humming",
122123
"triton_unfused",
@@ -149,6 +150,8 @@ class KernelConfig:
149150
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels
150151
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
151152
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
153+
- "flashinfer_b12x": Use FlashInfer CuteDSL fused MoE for SM12x
154+
(RTX Pro 6000 / DGX Spark)
152155
- "marlin": Use Marlin kernels (weight-only quantization)
153156
- "humming": Use Humming Mixed Precision kernels
154157
- "triton_unfused": Use Triton unfused MoE kernels

vllm/envs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,7 @@ def _get_or_set_default() -> str:
15091509
"VLLM_NVFP4_GEMM_BACKEND",
15101510
None,
15111511
[
1512+
"flashinfer-b12x",
15121513
"flashinfer-cudnn",
15131514
"flashinfer-trtllm",
15141515
"flashinfer-cutlass",

vllm/model_executor/kernels/linear/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
FbgemmNvFp4LinearKernel,
8989
)
9090
from vllm.model_executor.kernels.linear.nvfp4.flashinfer import (
91+
FlashInferB12xNvFp4LinearKernel,
9192
FlashInferCudnnNvFp4LinearKernel,
9293
FlashInferCutlassNvFp4LinearKernel,
9394
FlashInferTrtllmNvFp4LinearKernel,
@@ -263,6 +264,9 @@
263264

264265
_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
265266
PlatformEnum.CUDA: [
267+
# FlashInferB12xNvFp4LinearKernel excluded from auto-selection until
268+
# upstream CUTLASS SM121 MMA op guard is resolved; use
269+
# VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x to opt in explicitly.
266270
FlashInferCutlassNvFp4LinearKernel,
267271
CutlassNvFp4LinearKernel,
268272
MarlinNvFp4LinearKernel,
@@ -607,6 +611,7 @@ def init_wfp8_a16_linear_kernel(
607611

608612
# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
609613
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
614+
"flashinfer-b12x": FlashInferB12xNvFp4LinearKernel,
610615
"flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
611616
"cutlass": CutlassNvFp4LinearKernel,
612617
"marlin": MarlinNvFp4LinearKernel,
@@ -784,6 +789,7 @@ def register_linear_kernel(
784789
"CutlassNvFp4LinearKernel",
785790
"EmulationNvFp4LinearKernel",
786791
"FbgemmNvFp4LinearKernel",
792+
"FlashInferB12xNvFp4LinearKernel",
787793
"FlashInferCutlassNvFp4LinearKernel",
788794
"FlashInferTrtllmNvFp4LinearKernel",
789795
"FlashInferCudnnNvFp4LinearKernel",

0 commit comments

Comments
 (0)