Add tuned MoE Triton configs for MetaX C500 (MXC500) #313

Open
LindseyMei wants to merge 1 commit into MetaX-MACA:releases/v0.13.0 from LindseyMei:feat/moe-c500-configs

Summary

Add tuned two-stage Triton MoE configs for MetaX C500 (device_name=MXC500) covering three popular MoE shapes that currently fall back to the generic default config and emit the "Using default MoE config. Performance might be sub-optimal!" warning.

| Shape (H, E, N) | Top-K | Representative model | Architecture |
| --- | --- | --- | --- |
| 2048, 60, 1408 | 4 | Qwen/Qwen1.5-MoE-A2.7B | Qwen2MoeForCausalLM |
| 2048, 64, 1408 | 6 | deepseek-ai/DeepSeek-V2-Lite | DeepseekV2ForCausalLM |
| 2048, 128, 768 | 8 | Qwen/Qwen3-30B-A3B | Qwen3MoeForCausalLM |

Problem

vllm_metax selects the fused-MoE Triton tile from JSON configs keyed by (H, E, N, device_name[, dtype]). When no matching tuned config exists, it falls back to get_default_config, which is noticeably slower for prefill / large-batch shapes.
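
As an illustration of the lookup key, the sketch below builds the JSON file name for one shape. The helper name `moe_config_filename` is hypothetical, and the `,dtype=...` suffix is an assumption modelled on upstream vLLM's naming; only the dtype-less form is confirmed by the file names in this PR.

```python
def moe_config_filename(H, E, N, device_name, dtype=None):
    """Build the JSON file name for one (H, E, N, device_name[, dtype]) key.

    Hypothetical helper: the dtype suffix format is an assumption.
    """
    dtype_part = f",dtype={dtype}" if dtype else ""
    return f"H={H},E={E},N={N},device_name={device_name}{dtype_part}.json"


# One of the three files added by this PR.
print(moe_config_filename(2048, 64, 1408, "MXC500"))
```

If no file with the resulting name exists, the loader has nothing to match and the default tile is used.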

Method

The upstream benchmarks/kernels/benchmark_moe.py tunes the upstream fused_experts, but on MACA the runtime path goes through the OOT vllm_metax.model_executor.layers.fused_moe.fused_moe.fused_experts (verified to be a different object). Therefore I wrote a small micro-benchmark (moe_tune.py) that:

  • Forces a candidate tile via vllm.model_executor.layers.fused_moe.override_config.
  • Times a single vllm_metax fused_experts() call with random weights (no model download needed).
  • Uses CUDA events + torch.cuda.synchronize(), warmup + median over repeated iters.
  • Wraps the winning flat tile into the MetaX two-stage schema (stage1/stage2 + ACCF32/SPLIT_K/pipeline/scenario).
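
The timing step in the list above could look roughly like the following. This is a sketch, not the actual moe_tune.py: the helper name `time_ms` and its default iteration counts are hypothetical, and the `time.perf_counter` fallback exists only so the snippet also runs on a machine without a GPU or without PyTorch.

```python
import statistics
import time

try:
    import torch
except ImportError:  # lets the sketch run where PyTorch is not installed
    torch = None


def time_ms(fn, warmup=5, iters=20):
    """Return the median latency of fn() in milliseconds."""
    use_cuda = torch is not None and torch.cuda.is_available()

    # Warm-up runs absorb JIT compilation and cache effects.
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()

    samples = []
    for _ in range(iters):
        if use_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            # Wait for the kernel to finish before reading the events.
            torch.cuda.synchronize()
            samples.append(start.elapsed_time(end))  # already in ms
        else:
            t0 = time.perf_counter()
            fn()
            samples.append((time.perf_counter() - t0) * 1e3)

    # The median is less sensitive to occasional slow iterations than the mean.
    return statistics.median(samples)
```

In a tuning loop, `fn` would be a closure that calls fused_experts() with the candidate tile forced, and the candidate with the lowest median for a given M would be kept.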

Correctness was verified with torch.allclose(rtol=2e-2, atol=2e-2) between default-tile and tuned-tile outputs, and config pickup was confirmed with get_moe_configs().
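
For readers unfamiliar with the tolerance semantics, torch.allclose(input, other) passes when every element satisfies |input - other| <= atol + rtol * |other|. The pure-Python sketch below applies the same rule to flat lists; the function name `within_tolerance` is hypothetical and the sample values are made up for illustration, not taken from the PR's measurements.

```python
def within_tolerance(out, ref, rtol=2e-2, atol=2e-2):
    """Element-wise |out - ref| <= atol + rtol * |ref|, as in torch.allclose."""
    return all(abs(o - r) <= atol + rtol * abs(r) for o, r in zip(out, ref))


# Illustrative values only.
reference = [1.0, -2.0]
print(within_tolerance([1.03, -2.05], reference))  # inside the bound
print(within_tolerance([1.05, -2.05], reference))  # first element is outside
```

With rtol = atol = 2e-2, an element near 1.0 may differ by up to 0.04, which is a loose but common bound for half-precision kernels whose accumulation order changes with the tile shape.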

Results (kernel-level, MetaX C500)

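| Model | Shape | M | Default (ms) | Tuned (ms) | Speedup |
| --- | --- | --- | --- | --- | --- |
| Qwen1.5-MoE | E=60,N=1408 | 64 | 1.154 | 0.915 | 1.26x |
| | | 256 | 1.291 | 1.025 | 1.26x |
| | | 1024 | 2.340 | 1.567 | 1.49x |
| | | 2048 | 4.559 | 2.589 | 1.76x |
| DeepSeek-V2-Lite | E=64,N=1408 | 128 | 1.260 | 1.046 | 1.21x |
| | | 512 | 1.642 | 1.311 | 1.25x |
| | | 1024 | 3.163 | 1.720 | 1.84x |
| | | 2048 | 6.793 | 3.081 | 2.20x |
| Qwen3-30B-A3B | E=128,N=768 | 256 | 1.505 | 1.189 | 1.27x |
| | | 1024 | 2.650 | 1.850 | 1.43x |
| | | 2048 | 4.424 | 2.755 | 1.61x |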

Small-M shapes are kept at or near the default tile to avoid decode-latency regressions.
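
Each config file holds one entry per tuned batch size, which is why the small-M entries can stay near the default while the large-M entries change. The sketch below shows nearest-key selection as upstream vLLM does it; that vllm_metax selects entries the same way is an assumption, and the function name `pick_config` and the placeholder entries are hypothetical.

```python
def pick_config(configs, M):
    """Return the entry whose batch-size key is closest to M.

    JSON object keys are strings, so they are converted before comparing.
    """
    nearest = min(configs, key=lambda k: abs(int(k) - M))
    return configs[nearest]


# Placeholder entries; the real files hold two-stage tile dictionaries.
configs = {"1": "decode tile", "64": "mid tile", "1024": "prefill tile"}
print(pick_config(configs, 2))
print(pick_config(configs, 600))
```

Under this rule a decode step with M = 2 resolves to the "1" entry, so leaving that entry close to the default limits the risk of a decode-latency regression.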

Testing environment

  • MetaX C500 (currently visible as an sGPU slice: 16 GB VRAM / 25% compute due to the test container configuration)
  • MACA 3.3.0.15
  • vllm 0.13.1.dev0 / vllm_metax 0.13.0 / mcoplib 0.3.1
  • torch 2.8.0+metax3.3.0.2

The tile shapes (BLOCK_SIZE / warps / stages) are per-SM properties and should transfer to a full C500; the grid-level parameters (GROUP_SIZE_M=1, SPLIT_K=1) are conservative. I would welcome a maintainer running the same moe_tune.py on a full C500 to cross-check, but the current data already shows clear, reproducible wins on real MACA hardware.

Notes

Signed-off-by: LindseyMei <648816901@qq.com>

Add two-stage Triton configs for shapes missing on MXC500:
- H=2048,E=60,N=1408  (Qwen1.5-MoE-A2.7B, Qwen2MoeForCausalLM)
- H=2048,E=64,N=1408  (DeepSeek-V2-Lite, DeepseekV2ForCausalLM)
- H=2048,E=128,N=768  (Qwen3-30B-A3B, Qwen3MoeForCausalLM)

Configs were tuned on the actual vllm_metax fused-MoE Triton kernel
(rather than upstream fused_experts), using CUDA events with
synchronization. Speedups vs default tile range from 1.17x to 2.20x
for M>=64; correctness verified with torch.allclose.

Signed-off-by: LindseyMei <648816901@qq.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces three new Triton kernel configuration JSON files for Fused MoE on the MXC500 device. The feedback points out several parameter inconsistencies in the H=2048,E=64,N=1408,device_name=MXC500.json configuration. Specifically, it is recommended to reduce GROUP_SIZE_M from 16 to 1 for batch size 1, and to decrease num_warps from 8 to 4 for batch sizes 1, 8, and 64 to ensure consistency with other small batch sizes and avoid low occupancy.

"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,

high

The GROUP_SIZE_M parameter is set to 16 for batch size 1, which is inconsistent with all other batch sizes (which use 1) and has no benefit since there is only one tile in the M dimension. It should be set to 1.

Suggested change
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
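
The "no benefit with one tile in M" argument can be checked against the grouped program-id mapping used in the Triton matmul tutorial, emulated here in plain Python. This is a sketch under assumptions: the function name `tile_order` is hypothetical, and the fused-MoE kernel's M dimension counts padded per-expert token blocks, so the real number of M tiles at batch size 1 may be larger than one.

```python
def tile_order(num_pid_m, num_pid_n, group_size_m):
    """List (pid_m, pid_n) in launch order for a grouped tile mapping."""
    order = []
    num_pid_in_group = group_size_m * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_size_m
        # The last group may be smaller than group_size_m.
        rows = min(num_pid_m - first_pid_m, group_size_m)
        pid_m = first_pid_m + (pid % num_pid_in_group) % rows
        pid_n = (pid % num_pid_in_group) // rows
        order.append((pid_m, pid_n))
    return order


# With a single tile in M, the group size does not change the order.
print(tile_order(1, 22, 16) == tile_order(1, 22, 1))
# With several tiles in M, it does.
print(tile_order(4, 4, 2) == tile_order(4, 4, 1))
```

When the mapping is identical, any timing difference between the two settings would come from measurement noise rather than from the schedule itself.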

"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,

high

The GROUP_SIZE_M parameter is set to 16 for batch size 1 in stage 2, which is inconsistent with all other batch sizes (which use 1) and has no benefit since there is only one tile in the M dimension. It should be set to 1.

Suggested change
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 1, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 1 in stage 2, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 8, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 8 in stage 2, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 64, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,

"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,

high

The num_warps parameter is set to 8 for batch size 64 in stage 2, which is inconsistent with other small batch sizes like 16 and 32 (which use 4) and can lead to low occupancy and overhead. It should be set to 4.

Suggested change
- "num_warps": 8,
+ "num_warps": 4,
