Closed
87 commits
e1715aa
Fix DeepSeek V4 MLA prefix cache reuse
jasl May 6, 2026
627d922
Add Blackwell tuning config aliases
jasl May 5, 2026
7708c7f
Add portable sparse MLA Triton kernels
jasl May 6, 2026
b9f87df
Add DeepSeek V4 SM12x fallback ops
jasl May 6, 2026
fc0b613
Route SM12x DeepGEMM fallbacks
jasl May 6, 2026
4408069
Wire SM12x sparse MLA into DeepSeek V4
jasl May 6, 2026
0aa6a9e
Reduce DeepSeek V4 load overhead on GB10
jasl May 6, 2026
cc2a365
Apply weight filter to fast safetensors loading
jasl May 6, 2026
2be6fc6
Warm DeepSeek V4 startup kernels
jasl May 5, 2026
3a8d063
Add SM12x sparse MLA direct decode kernels
jasl May 6, 2026
9c0a8ed
Stabilize DeepSeek V4 MTP scheduling
jasl May 5, 2026
c6e2035
Warm DeepSeek V4 MTP spec-decode kernels
jasl May 8, 2026
43a9c4b
Tune dense FP8 block-scaled GEMM configs for SM12x DSv4
jasl May 11, 2026
38c3813
T1-D: adaptive BLOCK_M for _fp8_paged_mqa_logits_kernel (SM12x)
jasl May 11, 2026
1847275
T2-A: clamp BLOCK_D in sparse MLA finish kernel to head_dim
jasl May 11, 2026
e7f4d28
Extend DeepSeek V4 prefill warmup to max single-chunk size
jasl May 12, 2026
d607014
Extend DeepSeek V4 warmup coverage to multi-request shapes
jasl May 12, 2026
9b19586
Restore rowwise paged-MQA logits kernel for SM12x long context
jasl May 13, 2026
581061c
reasoning: defensive implicit </think> for DeepSeek V4 tool-call stre…
jasl May 13, 2026
36d9d13
sm12x: keep @torch.compile on HC head reduction via free-function wra…
jasl May 14, 2026
80b7c15
sm12x: drop multi-request prefill warmup that crashes CUTeDSL kv-gather
jasl May 14, 2026
55bdbd2
sm12x: drop vestigial cudagraph kill-switch on Triton sparse MLA
jasl May 14, 2026
8667c08
sm12x: harden sparse_attn_indexer seq_lens slice with .contiguous()
jasl May 14, 2026
21aeac5
sm12x: autotune num_warps on fp8_einsum + fused_indexer_q kernels
jasl May 14, 2026
e0220a9
sm12x: autotune num_warps/num_stages on 3 sparse MLA accumulate kernels
jasl May 14, 2026
1c192e9
sm12x: add 3 dense FP8 W8A8 Block configs for RTX PRO 6000 WS Edition
jasl May 14, 2026
214a77f
sm12x: cap C128A metadata kernel loop at effective_topk (no shape cha…
jasl May 14, 2026
5e607d8
sm12x: per-token early-loop-exit on sparse MLA accumulate inner candi…
jasl May 15, 2026
05306e3
sm12x: docs cleanup pass 1 — clarify metadata + MLA manager docstrings
jasl May 15, 2026
099baaf
sm12x: docs cleanup pass 2 — dedupe _upcast_e8m0_to_fp32 + simplify s…
jasl May 15, 2026
68f7b5f
sm12x: docs cleanup pass 3 — drop tautological is_valid in 7 accumula…
jasl May 15, 2026
c2a525b
sm12x: multi-head prefill accumulate kernel + drop fp8 einsum autotune
alexbi29 May 16, 2026
f9a8d15
sm12x: add fused-MoE FP8 W8A8 Block configs for RTX PRO 6000 (4 shape…
jasl May 17, 2026
b250225
sm120: use Triton MQA logits for direct topk fallback
jasl May 18, 2026
765b45a
sm120: use custom row topk for MQA fallback indices
jasl May 18, 2026
5695ee5
sm120: widen FP8 MQA logits tile
jasl May 18, 2026
5a9bf16
sm120: increase FP8 MQA logits row tile
jasl May 18, 2026
bb01b77
Fix DeepSeek V4 MTP sparse SWA reordering
jasl May 19, 2026
53549d9
sm12x: update DeepSeek V4 fallback imports
jasl May 19, 2026
dbc1e50
tests: update DeepSeek V4 MegaMoE refactor assumptions
jasl May 19, 2026
0455029
Fix DeepSeek V4 MLA prompt cache protection
jasl May 19, 2026
794f59c
Clean up DeepSeek V4 upstream rebase leftovers
jasl May 19, 2026
e2f1f28
Fix CUTeDSL availability probe
jasl May 19, 2026
55b5344
Fix DeepSeek V4 MTP small-batch graph hangs
jasl May 19, 2026
58a8410
Remove ineffective DeepSeek V4 mHC warmup
jasl May 19, 2026
73d389a
Tune SM120 FP8 MQA logits row tile
jasl May 19, 2026
b82fc53
Clean up SM120 rebase leftovers
jasl May 20, 2026
9bb9c4c
Remove unused SM120 splitKV decode experiment
jasl May 20, 2026
e14e412
Limit long prefill chunks behind active decode
jasl May 21, 2026
2c4c8d5
Tighten mixed prefill cap for very long prompts
jasl May 21, 2026
3f0be3b
Improve SM120 mixed prefill scheduling
jasl May 21, 2026
e874456
Clean up DeepSeek V4 reasoning parser lint
jasl May 22, 2026
a936769
Add DeepSeek V4 prefix cache pressure regression
jasl May 23, 2026
f8acd45
Keep hybrid prefix cache tail blocks
jasl May 23, 2026
ef7b24f
Stabilize SM12x sparse MLA long prefill
jasl May 24, 2026
fa00406
Tune SM12x sparse MLA single prefill topk
jasl May 24, 2026
e65a437
Protect active decode from very long prefill
jasl May 25, 2026
37a5d30
Clean sparse SWA imports after rebase
jasl May 27, 2026
c5aa455
Guard SM120 FP4 sparse indexer dependency
jasl May 27, 2026
7031dca
Absorb SM120 external Marlin fixes
jasl May 27, 2026
27b1edc
sm120: keep optimized MHC prenorm path without DeepGEMM
jasl May 28, 2026
d0ff141
sm12x: prune fallback tests and tuned config duplicates
jasl May 28, 2026
fef328e
sm12x: clear MXFP4 loading cache after setup
jasl May 29, 2026
8b94c27
sm12x: drop obsolete MHC CustomOp wrapper
jasl May 29, 2026
0b5a61c
Protect running prefills from long prefill starvation
jasl May 31, 2026
f11715c
Add chunked SM120 direct MQA top-k fallback
jasl May 31, 2026
de72673
Protect later running decodes from long prefill starvation
jasl Jun 1, 2026
fbc8666
Protect very-long prefill fairness
jasl Jun 1, 2026
9a95dcf
sm12x: avoid MHC prenorm GEMM JIT per token count
jasl Jun 1, 2026
a9cd769
test: adapt DS4 prefix cache tests to scheduler block size
jasl Jun 2, 2026
270388b
fix: export DeepSeek V4 FusedMoE metadata
jasl Jun 3, 2026
696c756
sched: defer very long prefill under decode pressure
jasl Jun 3, 2026
86594a0
sm12x: support FlashInfer CUTLASS MXFP4 opt-in
jasl Jun 2, 2026
097b913
sm12x: add sparse MLA prefill D512 split prototype
jasl Jun 2, 2026
0a3bb9f
sm12x: warm high-concurrency MTP decode workspace
jasl Jun 3, 2026
da7a7b4
sm12x: enable indexed D512 sparse MLA prefill by default
jasl Jun 4, 2026
d1c0d15
sm12x: retune D512 sparse MLA split tiles
jasl Jun 4, 2026
f178f86
fix: align prefix cache manager signatures after rebase
jasl Jun 4, 2026
72c06ad
sm12x: skip empty D512 sparse MLA tail blocks
jasl Jun 4, 2026
cd14412
sm12x: clean sparse MLA rebase leftovers
jasl Jun 5, 2026
025e942
sm12x: restore DeepSeek V4 O-proj FP8 einsum layout
jasl Jun 5, 2026
2cd728a
sm12x: restore Triton sparse MLA decode dispatch
jasl Jun 5, 2026
5de1c6c
config: skip breakable cudagraph auto-enable on SM121
jasl Jun 5, 2026
d7f6f00
sm12x: restore sparse MLA prefill stats
jasl Jun 5, 2026
d6da156
deepseek-v4: preserve ubatch prefill metadata
jasl Jun 5, 2026
a66054f
deepseek-v4: defunctionalize fused MLA insert op
jasl Jun 5, 2026
a30f7a6
[Bugfix] DeepSeek V4 reasoning parser: don't split DSML tool-call mar…
tobymao Jun 6, 2026
10 changes: 5 additions & 5 deletions csrc/moe/marlin_moe_wna16/ops.cu
@@ -429,6 +429,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   cudaDeviceGetAttribute(&max_shared_mem,
                          cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
   TORCH_CHECK(max_shared_mem > 0);
+  int device_max_shared_mem = max_shared_mem;

   int major_capability, minor_capability;
   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@@ -519,10 +520,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   }

   cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                       max_shared_mem);
+                       device_max_shared_mem);
   // avoid ">>>" being formatted to "> > >"
   // clang-format off
-  kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
+  kernel<<<blocks, num_threads, sh_cache_size, stream>>>(
       A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
       sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
       topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
Expand Down Expand Up @@ -691,9 +692,8 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor c_tmp;
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
(long)size_n * sorted_token_ids.size(0),
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
long max_c_tmp_size =
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n;
if (moe_block_size == 8) max_c_tmp_size *= 2;
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
} else {
Expand Down
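Reviewer note: the last hunk changes how the FP32 reduction scratch buffer `c_tmp` is sized. The arithmetic can be sketched in Python as below. This is an illustration only, not code from the PR: `c_tmp_elems_old` and `c_tmp_elems_new` are hypothetical helper names, and `max_thread_n` is passed in as a parameter rather than assumed to be any particular Marlin constant.

```python
def c_tmp_elems_old(size_n, num_sorted_tokens, sms, moe_block_size, max_thread_n):
    """Pre-change sizing: smaller of the problem-size bound and the SM-based bound."""
    size = min(size_n * num_sorted_tokens, sms * 4 * moe_block_size * max_thread_n)
    if moe_block_size == 8:
        size *= 2
    return size


def c_tmp_elems_new(sms, moe_block_size, max_thread_n):
    """Post-change sizing: always the SM-based bound (at most sms * 4 thread blocks)."""
    size = sms * 4 * moe_block_size * max_thread_n
    if moe_block_size == 8:
        size *= 2
    return size


# Illustrative numbers only; they do not describe any real device.
old = c_tmp_elems_old(size_n=64, num_sorted_tokens=8, sms=4, moe_block_size=8, max_thread_n=256)
new = c_tmp_elems_new(sms=4, moe_block_size=8, max_thread_n=256)
print(old, new)
```

Because the old expression took a `min` that included the SM-based bound, the new size is never smaller than the old one; for small problems the buffer is simply larger than before.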
60 changes: 60 additions & 0 deletions tests/compile/passes/test_functionalization.py
@@ -251,12 +251,72 @@ def ops_not_in_model(self):
         return []


+class TestFusedDeepseekV4QnormRopeKvInsert(torch.nn.Module):
+    OP_REGISTERED = False
+
+    def __init__(self):
+        super().__init__()
+        self.register_test_custom_op()
+
+    @classmethod
+    def register_test_custom_op(cls):
+        if not cls.OP_REGISTERED:
+
+            def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl(
+                q: torch.Tensor,
+                kv: torch.Tensor,
+                k_cache: torch.Tensor,
+            ) -> None:
+                q.add_(kv)
+                k_cache.add_(kv)
+
+            def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake(
+                q: torch.Tensor,
+                kv: torch.Tensor,
+                k_cache: torch.Tensor,
+            ) -> None:
+                return None
+
+            direct_register_custom_op(
+                op_name="fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
+                op_func=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl,
+                mutates_args=["q", "k_cache"],
+                fake_impl=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake,
+            )
+
+            cls.OP_REGISTERED = True
+
+    def forward(
+        self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
+            q, kv, k_cache
+        )
+        return q, k_cache
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        return (
+            torch.rand(num_tokens, hidden_size),
+            torch.rand(num_tokens, hidden_size),
+            torch.rand(num_tokens, hidden_size),
+        )
+
+    def ops_in_model(self, do_fusion):
+        return [
+            torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default
+        ]
+
+    def ops_not_in_model(self):
+        return []
+
+
 MODELS_AND_DO_FUSION = {
     TestSiluMul: [True, False],
     TestFusedAddRMSNorm: [True, False],
     TestRotaryEmbedding: [False],
     TestRotaryEmbeddingSliceScatter: [False],
     TestFunctionWithMutatedArgsAndReturn: [False],
+    TestFusedDeepseekV4QnormRopeKvInsert: [False],
 }


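Reviewer note: the test op above mutates `q` and `k_cache` in place and returns `None`, matching `mutates_args=["q", "k_cache"]`. As a toy analogy in plain Python (lists stand in for tensors; this is not how `torch.compile` implements functionalization, and both helper names are hypothetical), the difference between the mutating form and a functionalized form looks like this:

```python
def insert_inplace(q, kv, k_cache):
    """Mutating form: updates q and k_cache in place and returns nothing."""
    for i, v in enumerate(kv):
        q[i] += v
        k_cache[i] += v


def insert_functional(q, kv, k_cache):
    """Functionalized form: inputs stay untouched, updated copies are returned."""
    new_q, new_cache = list(q), list(k_cache)
    insert_inplace(new_q, kv, new_cache)
    return new_q, new_cache


q, kv, k_cache = [1.0, 2.0], [0.5, 0.5], [10.0, 20.0]

out_q, out_cache = insert_functional(q, kv, k_cache)
assert q == [1.0, 2.0]  # the functional form did not touch its inputs

insert_inplace(q, kv, k_cache)
assert (q, k_cache) == (out_q, out_cache)  # same values, no extra copies
```

The functional form costs two extra copies per call. The test's `ops_in_model` list names the original mutating op (`.default`), which appears to check that this op, rather than a functionalized wrapper, is what remains in the final graph.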
32 changes: 32 additions & 0 deletions tests/compile/test_config.py
@@ -468,6 +468,38 @@ def test_cudagraph_sizes_post_init(
     )


+def test_spec_decode_cudagraph_sizes_keep_small_full_decode_batches_exact():
+    config = CompilationConfig(
+        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
+        cudagraph_capture_sizes=[
+            1,
+            2,
+            4,
+            8,
+            16,
+            24,
+            32,
+            40,
+            48,
+            56,
+            64,
+            72,
+            80,
+            88,
+            96,
+        ],
+        max_cudagraph_capture_size=96,
+    )
+
+    config.adjust_cudagraph_sizes_for_spec_decode(
+        uniform_decode_query_len=3,
+        tensor_parallel_size=1,
+    )
+
+    for num_reqs in range(1, 33):
+        assert 3 * num_reqs in config.cudagraph_capture_sizes
+
+
 @pytest.mark.skipif(
     not current_platform.support_static_graph_mode(),
     reason="Skip if not cudagraph mode supported",
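Reviewer note: the new test requires that, with a uniform decode query length of 3, every batch of 1 to 32 requests maps to an exactly captured size (`3 * num_reqs`). A minimal sketch of one way to satisfy that property is shown below. `add_exact_decode_sizes` is a hypothetical helper written for illustration; it is not the body of `adjust_cudagraph_sizes_for_spec_decode`, whose implementation is not shown in this diff.

```python
def add_exact_decode_sizes(capture_sizes, query_len, max_exact_reqs, max_size):
    """Hypothetical helper: make sure query_len * n is captured for n = 1..max_exact_reqs."""
    sizes = set(capture_sizes)
    for num_reqs in range(1, max_exact_reqs + 1):
        size = query_len * num_reqs
        if size <= max_size:  # never exceed the configured maximum capture size
            sizes.add(size)
    return sorted(sizes)


# Same starting sizes as the test: 1, 2, 4, then multiples of 8 up to 96.
base = [1, 2, 4] + list(range(8, 97, 8))
sizes = add_exact_decode_sizes(base, query_len=3, max_exact_reqs=32, max_size=96)

assert all(3 * n in sizes for n in range(1, 33))
```

Without exact sizes, a decode batch such as 5 requests (15 tokens) would have to be padded up to the next captured size (16 in the starting list).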
50 changes: 50 additions & 0 deletions tests/config/test_deepseek_v4_cudagraph_config.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+from vllm.config.vllm import _should_auto_enable_deepseek_v4_breakable_cudagraph
+from vllm.platforms import current_platform
+
+
+def _model_config(*architectures: str):
+    return SimpleNamespace(architectures=list(architectures))
+
+
+def test_deepseek_v4_auto_enables_breakable_cudagraph_off_sm121(monkeypatch):
+    monkeypatch.setattr(
+        current_platform,
+        "is_device_capability",
+        lambda capability, device_id=0: False,
+    )
+
+    assert _should_auto_enable_deepseek_v4_breakable_cudagraph(
+        _model_config("DeepseekV4ForCausalLM")
+    )
+    assert _should_auto_enable_deepseek_v4_breakable_cudagraph(
+        _model_config("DeepSeekV4MTPModel")
+    )
+
+
+def test_deepseek_v4_skips_breakable_cudagraph_on_sm121(monkeypatch):
+    monkeypatch.setattr(
+        current_platform,
+        "is_device_capability",
+        lambda capability, device_id=0: capability == 121,
+    )
+
+    assert not _should_auto_enable_deepseek_v4_breakable_cudagraph(
+        _model_config("DeepseekV4ForCausalLM")
+    )
+
+
+def test_non_deepseek_v4_does_not_auto_enable_breakable_cudagraph(monkeypatch):
+    monkeypatch.setattr(
+        current_platform,
+        "is_device_capability",
+        lambda capability, device_id=0: False,
+    )
+
+    assert not _should_auto_enable_deepseek_v4_breakable_cudagraph(
+        _model_config("Qwen3ForCausalLM")
+    )
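
Reviewer note: taken together, the three tests pin down a simple contract for the predicate. The sketch below is one function that satisfies them, inferred from the tests alone. `should_auto_enable` and `DEEPSEEK_V4_ARCHS` are hypothetical names, and the real `_should_auto_enable_deepseek_v4_breakable_cudagraph` may check more than this; the capability check is injected as a callable so the sketch runs without vLLM.

```python
from types import SimpleNamespace

# Architecture names taken from the tests above; the set itself is an assumption.
DEEPSEEK_V4_ARCHS = {"DeepseekV4ForCausalLM", "DeepSeekV4MTPModel"}


def should_auto_enable(model_config, is_device_capability):
    """Hypothetical predicate: DeepSeek V4 architectures only, and never on SM121."""
    if not DEEPSEEK_V4_ARCHS.intersection(model_config.architectures):
        return False
    return not is_device_capability(121)


def on_sm121(capability):
    return capability == 121


def off_sm121(capability):
    return False


v4 = SimpleNamespace(architectures=["DeepseekV4ForCausalLM"])
other = SimpleNamespace(architectures=["Qwen3ForCausalLM"])

assert should_auto_enable(v4, off_sm121)
assert not should_auto_enable(v4, on_sm121)
assert not should_auto_enable(other, off_sm121)
```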
173 changes: 173 additions & 0 deletions tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py
@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import sys
+import types
+
+import torch
+
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
+    FusedMoEParallelConfig,
+    RoutingMethodType,
+    mxfp4_mxfp8_moe_quant_config,
+)
+from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import (
+    FlashInferExperts,
+)
+from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
+    Mxfp4MoeBackend,
+    convert_weight_to_mxfp4_moe_kernel_format,
+)
+
+
+def _make_moe_config() -> FusedMoEConfig:
+    return FusedMoEConfig(
+        num_experts=2,
+        experts_per_token=1,
+        hidden_dim=16,
+        intermediate_size_per_partition=16,
+        num_local_experts=2,
+        num_logical_experts=2,
+        activation=MoEActivation.SILU,
+        device="cpu",
+        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+        in_dtype=torch.bfloat16,
+        routing_method=RoutingMethodType.TopK,
+        max_num_tokens=16,
+    )
+
+
+def _make_experts(
+    *,
+    gemm1_alpha: float | None = None,
+    gemm1_beta: float | None = None,
+    gemm1_clamp_limit: float | None = None,
+) -> FlashInferExperts:
+    quant_config = mxfp4_mxfp8_moe_quant_config(
+        w1_scale=torch.ones((2, 32, 1), dtype=torch.float8_e4m3fn),
+        w2_scale=torch.ones((2, 16, 1), dtype=torch.float8_e4m3fn),
+        gemm1_alpha=gemm1_alpha,
+        gemm1_beta=gemm1_beta,
+        gemm1_clamp_limit=gemm1_clamp_limit,
+    )
+    with set_current_vllm_config(
+        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+    ):
+        return FlashInferExperts(
+            moe_config=_make_moe_config(),
+            quant_config=quant_config,
+        )
+
+
+def test_mxfp4_swiglu_parameters_stay_unset_without_quant_config() -> None:
+    experts = _make_experts()
+
+    assert experts.gemm1_alpha is None
+    assert experts.gemm1_beta is None
+    assert experts.gemm1_clamp_limit is None
+
+
+def test_mxfp4_swiglu_parameters_follow_quant_config() -> None:
+    experts = _make_experts(
+        gemm1_alpha=1.25,
+        gemm1_beta=0.75,
+        gemm1_clamp_limit=5.5,
+    )
+
+    torch.testing.assert_close(experts.gemm1_alpha, torch.tensor([1.25, 1.25]))
+    torch.testing.assert_close(experts.gemm1_beta, torch.tensor([0.75, 0.75]))
+    torch.testing.assert_close(
+        experts.gemm1_clamp_limit,
+        torch.tensor([5.5, 5.5]),
+    )
+
+
+def test_cutlass_mxfp8_kernel_format_converts_gate_up_layout(monkeypatch) -> None:
+    monkeypatch.setitem(
+        sys.modules,
+        "flashinfer",
+        types.SimpleNamespace(block_scale_interleave=lambda x: x.contiguous()),
+    )
+
+    num_experts = 1
+    intermediate_size = 64
+    hidden_size = 64
+    packed_hidden_size = hidden_size // 2
+    sf_block_size = 32
+
+    w13_weight = torch.arange(
+        num_experts * 2 * intermediate_size * packed_hidden_size,
+        dtype=torch.uint8,
+    ).reshape(num_experts, 2 * intermediate_size, packed_hidden_size)
+    w2_weight = torch.arange(
+        num_experts * hidden_size * (intermediate_size // 2),
+        dtype=torch.uint8,
+    ).reshape(num_experts, hidden_size, intermediate_size // 2)
+    w13_scale_u8 = torch.arange(
+        num_experts * 2 * intermediate_size * (hidden_size // sf_block_size),
+        dtype=torch.uint8,
+    ).reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
+    w2_scale_u8 = torch.arange(
+        num_experts * hidden_size * (intermediate_size // sf_block_size),
+        dtype=torch.uint8,
+    ).reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
+    w13_bias = torch.arange(
+        num_experts * 2 * intermediate_size,
+        dtype=torch.bfloat16,
+    ).reshape(num_experts, 2 * intermediate_size)
+    w2_bias = torch.arange(
+        num_experts * hidden_size,
+        dtype=torch.bfloat16,
+    ).reshape(num_experts, hidden_size)
+
+    (
+        out_w13,
+        out_w2,
+        out_w13_scale,
+        out_w2_scale,
+        out_w13_bias,
+        out_w2_bias,
+    ) = convert_weight_to_mxfp4_moe_kernel_format(
+        mxfp4_backend=Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
+        layer=torch.nn.Module(),
+        w13_weight=w13_weight,
+        w2_weight=w2_weight,
+        w13_weight_scale=w13_scale_u8.view(torch.float8_e4m3fn),
+        w2_weight_scale=w2_scale_u8.view(torch.float8_e4m3fn),
+        w13_bias=w13_bias,
+        w2_bias=w2_bias,
+    )
+
+    expected_w13 = torch.cat(
+        [
+            w13_weight[:, intermediate_size:, :],
+            w13_weight[:, :intermediate_size, :],
+        ],
+        dim=1,
+    )
+    expected_w13_scale = torch.cat(
+        [
+            w13_scale_u8[:, intermediate_size:, :],
+            w13_scale_u8[:, :intermediate_size, :],
+        ],
+        dim=1,
+    )
+    expected_w13_bias = torch.cat(
+        [
+            w13_bias[:, intermediate_size:],
+            w13_bias[:, :intermediate_size],
+        ],
+        dim=1,
+    )
+
+    assert out_w13.is_contiguous()
+    assert out_w2.is_contiguous()
+    torch.testing.assert_close(out_w13, expected_w13)
+    torch.testing.assert_close(out_w2, w2_weight)
+    torch.testing.assert_close(out_w13_scale, expected_w13_scale)
+    torch.testing.assert_close(out_w2_scale, w2_scale_u8)
+    torch.testing.assert_close(out_w13_bias, expected_w13_bias)
+    torch.testing.assert_close(out_w2_bias, w2_bias)
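
Reviewer note: the last test builds its expected `w13` tensors by concatenating the second half of dimension 1 ahead of the first half, for weights, scales, and bias alike, while `w2` is expected to pass through unchanged. The same reordering can be sketched with nested lists in plain Python. `swap_halves` is a hypothetical helper for illustration, not the conversion function under test, and the sketch makes no claim about which half holds the gate projection and which the up projection.

```python
def swap_halves(w13, intermediate_size):
    """Swap the two halves of each expert's row dimension (dim 1).

    w13 is a nested list shaped [num_experts][2 * intermediate_size][...].
    """
    return [
        expert[intermediate_size:] + expert[:intermediate_size]
        for expert in w13
    ]


# One expert, 2 * intermediate_size = 4 rows, one column per row.
w13 = [[[0], [1], [2], [3]]]
swapped = swap_halves(w13, intermediate_size=2)

assert swapped == [[[2], [3], [0], [1]]]
# Applying the swap twice restores the original layout.
assert swap_halves(swapped, intermediate_size=2) == w13
```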