[Feature] TRITON_MLA_SPARSE backend for SM8x/11x/12x DSA Sparse MLA Support #38476
haosdent wants to merge 1 commit into
Documentation preview: https://vllm--38476.org.readthedocs.build/en/38476/ |
Code Review
This pull request introduces the TRITON_MLA_SPARSE attention backend, providing a Triton-based fallback for sparse MLA on GPUs like NVIDIA Ampere. It also refactors FP8 MQA logit fallbacks into a new module and updates the sparse attention indexer to use these PyTorch implementations when DeepGEMM is unsupported. A review comment suggests moving a module-level import to the top of the file to comply with PEP 8 guidelines.
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
from vllm.utils.math_utils import cdiv
To adhere to PEP 8 guidelines, module-level imports should be placed at the top of the file. Please move this import statement to the top of the module, for example, after the `from vllm.platforms import current_platform` import. This improves code readability and consistency.
References
- PEP 8: E402 module level import not at top of file. Imports should be at the top of the file, just after any module comments and docstrings, and before module globals and constants.
FYI, we don’t plan to support a torch native mqa_logits implementation. I also question whether it’s necessary to support sparse MLA on SM80. |
Update: This PR is no longer a torch-native
Ampere is still extremely commonly used... We need this for DS3.2 or GLM-5. |
Perhaps we can integrate Triton now? |
|
Hello, after I modified the code according to your PR, the GLM5 model service started normally. However, responses are very slow, at only about 3 tokens per second. My device is also an A800 with 80 GB of memory. Is this normal? |
|
@workcode-del I believe that it's only expected to be anywhere remotely fast when no PyTorch fallbacks exist. |
Could you please explain how to achieve the condition where there are no PyTorch fallbacks? |
|
This pull request has merge conflicts that must be resolved before it can be |
|
@workcode-del |
|
I was able to use this patch to run GLM-5.1 on an 8-node DGX Spark cluster. Performance is obviously not stellar (~5 t/s) but it's a great first step with compatibility. |
17f68af to
6760f0c
|
Thanks @ianlevesque, your 8 x DGX Spark setup is incredible! I just added new Triton kernels to try to address the performance issue. Could you test again when you are available? |
|
@haosdent retried with the new patch, it did improve to 10 t/s or so. |
|
|
"Does this patch support turboquant-vllm? Since turboquant-vllm enables KV cache compression, I'm interested in its compatibility. I actually opened this issue: https://github.com/varjoranta/turboquant-vllm/issues/56. Could you please help me look into it?" @haosdent |
|
@haosdent Can we add support for GLM-5.2-NVFP4? It seems that it has a specific shape that needs to be added manually (576 MLA heads?). |
|
I think merging the main branch might solve it. |
Does a single 8xA100 SXM machine work at a concurrency of 96?
|
Thanks for this backend, @haosdent. It's the missing piece for sparse-MLA on Ampere. We used it to bring up GLM-5.2 (
1. GLM-5.2 needs IndexShare (this PR builds the indexer per-layer)
GLM-5.2 uses a per-layer
The good news: the machinery already exists here

    # GLM-5.2 IndexShare: layers with indexer_types[i] == "shared" reuse the previous
    # "full" layer's top-k indices (via the shared topk_indices_buffer) instead of
    # running their own indexer. Matches modeling_glm_moe_dsa.py:
    #   self.skip_topk = config.indexer_types[layer_idx] == "shared"
    indexer_types = getattr(config, "indexer_types", None)
    if indexer_types is not None:
        _layer_idx = extract_layer_index(prefix)
        if 0 <= _layer_idx < len(indexer_types) and indexer_types[_layer_idx] == "shared":
            _skip_topk = True

2. Heads-up: a
|
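The IndexShare rule described in the comment above (a "shared" layer reuses the top-k indices of the most recent "full" layer) can be sketched in plain Python. This is a minimal illustration and not vLLM code: `plan_index_share` is a hypothetical helper, and only the `"shared"` / `"full"` values of `indexer_types` are taken from the comment.

```python
def plan_index_share(indexer_types):
    """Map each layer to the layer whose top-k indices it consumes.

    Hypothetical helper: a "shared" layer skips its own indexer and reuses
    the indices of the closest preceding non-shared ("full") layer.
    """
    sources = []
    last_full = None
    for layer_idx, kind in enumerate(indexer_types):
        if kind == "shared":
            if last_full is None:
                # A shared layer needs an earlier full layer to borrow from.
                raise ValueError(f"layer {layer_idx} is shared but no full layer precedes it")
            sources.append(last_full)
        else:
            last_full = layer_idx
            sources.append(layer_idx)
    return sources


# Example layout (made up for illustration): full, shared, shared, full, shared.
plan = plan_index_share(["full", "shared", "shared", "full", "shared"])
print(plan)
```

Layers whose entry differs from their own index are the ones that would set the skip flag and read from the shared buffer.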
|
@Kasempiternal Thank you so much for the contribution! |
|
Validation run on our 8×A100 box; investigation + write-up done with Claude Code.
Independent confirmation: GLM-5.2 (
Setup
Proof it routes correctly (server startup log):
Throughput (512-tok prompt / 128-tok greedy gen,
~56 tok/s single-stream decode, plateauing near 625 tok/s aggregate at 32-way concurrency, TTFT sub-second throughout. For reference, llama.cpp (GGUF UD-Q4_K_XL) on the same 8x A100 does ~24.5 tok/s single-stream and saturates around 70 tok/s aggregate, so this path is ~2.3x single-stream and ~9x aggregate.
Output is coherent (verified on arithmetic, Rayleigh scattering, and a "who wrote Hamlet" prompt that closed its
Full write-up (weights -> cherry-pick + conflict resolution -> precompiled overlay install -> serve -> benchmarks): https://gist.github.com/timinar/c8d2eca4e2ea7d11db57a1e6e62d06a2. |
Is 32-way concurrency meaningful on SM8x? A 1M context is unlikely to work at that concurrency. |
Not really. It was max 32k context, I believe. So too short for multi-turn agentic tasks, but could be useful for some other types of local work. |
|
Actually, pipeline parallelism will multiply the context capacity by the number of GPUs, because GLM-5.x uses MLA.
I realized this topic assumes INT4 (Q4) rather than NVFP4. With any 4-bit scheme, a 1M context window on an 8xA100 SXM node is still possible. However, vLLM would additionally need to implement context parallelism rather than rely on pipeline or tensor parallelism. Don't be discouraged, and keep the enhancements coming. The AIME benchmark requires at least 160K context to complete reasoning, so 32K isn't long enough. My current difficulty is reproducing their 99% score on AIME-2026. |
|
In MLA models, the number of KV cache heads is considered 1. But in pipeline parallelism, there is no KV cache head duplication, so the KV cache gets multiplied by the number of GPUs. |
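A rough back-of-the-envelope sketch of this argument, assuming every GPU reserves the same number of bytes for KV cache. All numbers here (80 layers, a 576-wide latent per token, 2-byte elements, a 10 GiB budget) are hypothetical placeholders rather than GLM-5.x values, and `mla_token_capacity` is an illustrative helper, not a vLLM function.

```python
import math


def mla_token_capacity(kv_budget_bytes, num_layers, latent_dim, bytes_per_elem, pp_size=1):
    """Tokens that fit in one GPU's KV budget (hypothetical model).

    Tensor parallelism cannot shard the single MLA latent "head", so each
    rank stores it for all layers (pp_size=1). Pipeline parallelism gives
    each rank only its own slice of the layers.
    """
    layers_on_rank = math.ceil(num_layers / pp_size)
    bytes_per_token = layers_on_rank * latent_dim * bytes_per_elem
    return kv_budget_bytes // bytes_per_token


BUDGET = 10 * 1024**3  # assumed per-GPU KV budget, in bytes
tp8_tokens = mla_token_capacity(BUDGET, 80, 576, 2, pp_size=1)  # TP=8: latent replicated
pp8_tokens = mla_token_capacity(BUDGET, 80, 576, 2, pp_size=8)  # PP=8: layers split
print(tp8_tokens, pp8_tokens, pp8_tokens / tp8_tokens)
```

Under these assumptions the pipeline-parallel layout holds about eight times as many tokens, which is the multiplication by GPU count described above. The sketch ignores weight memory and activation buffers.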
Pipeline parallelism supports 1M context but suffers from low cross-GPU utilization. Tutel images recently added support for paged Context-Sparse MLA: https://hub.docker.com/r/tutelgroup/deepseek-671b |
|
Is there a way to use proper tool calling and reasoning parsers with Tutel? |
Connecting Claude Code to Tutel will directly trigger GLM's tool calling. Without tool calls, I get only 90% on AIME-26 with GLM5.2-NVFP4, whereas GLM-5-NVFP4 and GLM-5.1-NVFP4 scored >96%. I don't know whether NVFP4 is a poor fit for GLM-5.2, so I wonder what score AIME-26 would give with Q4. Maybe the officially claimed 99% on AIME-26 is not reproducible; however, I don't have a strong enough GPU environment (Hxx/Bxx) to evaluate GLM-5.2-BF16 or GLM-5.2-FP8. |
|
tool call parse fix |
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.TRITON_MLA_SPARSE,
This backend shouldn't be prioritized over FlashMLA sparse; doing so will hurt SM90 performance. Please swap these two.
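To illustrate the concern, here is a minimal first-match selector. It is a sketch and not vLLM's actual selection logic: `pick_backend` and the per-GPU support sets are assumptions made for the example, and the backend names are plain strings that only echo the ones discussed in this thread.

```python
def pick_backend(priority, supported):
    """Return the first backend in priority order that the device supports."""
    for name in priority:
        if name in supported:
            return name
    raise RuntimeError("no supported sparse MLA backend")


# Hypothetical support sets, for illustration only.
SM90_SUPPORTED = {"FLASHMLA_SPARSE", "TRITON_MLA_SPARSE"}
SM80_SUPPORTED = {"TRITON_MLA_SPARSE"}

triton_first = ["TRITON_MLA_SPARSE", "FLASHMLA_SPARSE"]
flashmla_first = ["FLASHMLA_SPARSE", "TRITON_MLA_SPARSE"]

# With Triton listed first, an SM90 device never reaches FlashMLA sparse.
print(pick_backend(triton_first, SM90_SUPPORTED))
# After the swap, SM90 gets FlashMLA sparse and SM80 still falls back to Triton.
print(pick_backend(flashmla_first, SM90_SUPPORTED))
print(pick_backend(flashmla_first, SM80_SUPPORTED))
```

With a first-match scheme like this, the fallback belongs after the faster backend, since devices that lack the faster one skip past it anyway.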
|
Great job! |
|
Hi @timinar, how is the accuracy recovery of cyankiwi/GLM-5.2-AWQ-INT4? Is it as smart as the non-quantized version of the model?
I tested the NVFP4 version instead of INT4. Without tool use, AIME-2026 scores only 90-93%; enabling tools improves it to >99%. |
|
Has anyone encountered the following issue when trying to deploy GLM-5.2 on 8x A100 GPUs with Nvidia 535 drivers (cu129)?
This workaround feels too hacky and diverges significantly from the reference behavior in the gist (where things reportedly work out-of-the-box on 0.90 memory limit without Triton capturing crashes). Has anyone successfully run this without such JIT-warmup patches on the newer vLLM V1 engine? Any insights on why the memory footprint/graph profiling behavior differs so much from the reference would be greatly appreciated! UPD: The problem was with the container/main versions. Here's the working Dockerfile. |
|
@RefalMachine Thank you so much for sharing the |
The current code successfully deploys the model for me, however, after N minutes of operation under high load, NaN values randomly start appearing in the logits (it’s unclear where or why), resulting in the output degenerating into continuous '!!!' tokens. This can only be resolved by restarting. Unfortunately, I don't know how to fix this issue yet. If anyone has encountered this before, I would be grateful for any leads. |
|
I am able to run it on 40 GB A100 GPUs. Using the latest
As for the steps I had to combine a bunch of scattered pieces to make it work on my system:
|
Do you run this with 16 A100s? |
|
It seems @haosdent hasn’t been active on this project for some time. Really looking forward to the GLM-5.2 and DeepSeek V4 support once the rebase is done. |






Purpose
Closes #38006. Enables sparse MLA models (GLM-5, DeepSeek-V3.2) on SM80 (A100/A800) and SM121 (GB10/DGX Spark), where DeepGEMM / FlashMLA-Sparse / FlashInfer-MLA-Sparse are unavailable.
Changes
- `is_deep_gemm_supported()` (SM90+ check) replaces `has_deep_gemm()` in `sparse_attn_indexer.py` / `indexer.py`. Stops DeepGEMM kernels from being invoked on SM80/SM121.
- `fp8_mqa_logits` for the indexer. `mqa_logits_triton.py` reproduces DeepGEMM's prefill + paged MQA logits. Prefill takes bf16 q/k (pre-decoded from FP8 in the Python wrapper) and feeds a straight `tl.dot`; paged decode keeps a 256-entry bf16 LUT for in-kernel FP8 decode. K-side scale applied to the fp32 dot output, per-row K-tile early-exit on the chunked-prefill path.
- `TRITON_MLA_SPARSE` backend. `triton_mla_sparse_kernel.py` adds a split-KV decode with N-way online-softmax merge plus a single-pass fast path. Autotune is warmed at init using indexer-derived `(n_head, head_dim)`. Masked-out sentinel is `-1e30` to avoid `NaN` from `(-inf) − (-inf)` on all-masked tiles.
- `TritonMLASparseMetadataBuilder` advertises `AttentionCGSupport.UNIFORM_BATCH`; flips A100 TP=8 back to `FULL_AND_PIECEWISE`.
- `mxfp4_experts_quant` / `silu_and_mul_mxfp4_experts_quant` stubs in `nvfp4_quant_entry.cu`. Real impls are SM10.x-only in CMake but `torch_bindings.cpp` references them unconditionally, which breaks source builds on SM 8.x.
Benchmarks
8×A100 SXM TP=8, `lukealonso/GLM-5.1-NVFP4`, single prompt, decode 200 tokens. `cold` = first request on a fresh prompt; `warm` = repeat (prefix cache hit):
Tests
- `tests/kernels/attention/test_mqa_logits_triton.py`: 41 cases (DeepGEMM reference + clean/dirty `clean_logits` + 256-byte FP8 decode).
- `tests/kernels/attention/test_triton_mla_sparse_kernel.py`: 53 cases (split vs single-pass + auto-heuristic + short-prefill no-NaN).
Limitations (follow-up): BF16 KV cache only on SM80/SM121; `VLLM_BATCH_INVARIANT` should force `num_kv_splits=1` (not wired).
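The split-KV merge and the `-1e30` sentinel listed under Changes can be illustrated in plain Python. This is a minimal scalar sketch of the general online-softmax technique, assuming one output channel and list inputs; `partial_softmax` and `merge_partials` are hypothetical names and do not mirror the Triton kernel's code.

```python
import math

MASKED = -1e30  # finite sentinel for masked-out scores, as in the PR description


def partial_softmax(scores, values):
    """Process one KV split: return (max score, exp-sum, weighted value sum)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # s - m <= 0, so no overflow
    return m, sum(weights), sum(w * v for w, v in zip(weights, values))


def merge_partials(partials):
    """N-way merge: rescale every split to the global max, normalize once."""
    global_max = max(m for m, _, _ in partials)
    total_l = 0.0
    total_acc = 0.0
    for m, l, acc in partials:
        scale = math.exp(m - global_max)  # underflows to 0.0 for an all-masked split
        total_l += scale * l
        total_acc += scale * acc
    return total_acc / total_l


splits = [
    ([1.0, 2.0], [10.0, 20.0]),
    ([MASKED, MASKED], [99.0, 99.0]),  # all-masked tile: must not affect the result
    ([0.5], [30.0]),
]
merged = merge_partials([partial_softmax(s, v) for s, v in splits])

# With -inf as the mask value, the all-masked tile would compute (-inf) - (-inf).
inf_diff = float("-inf") - float("-inf")
print(merged, inf_diff)
```

The finite sentinel keeps `s - m` equal to zero inside an all-masked tile, and the merge step then scales that tile's contribution to zero. The sketch does not handle the case where every split is fully masked; a real kernel would need to treat that separately.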