Parameter cache for CUDA kernel launch #3598

Open
guoqingbao wants to merge 1 commit into huggingface:main from guoqingbao:parameter_cache

@guoqingbao

Contributor

Summary

Candle copies kernel launch parameters to the device with host-to-device (htod) transfers. This makes CUDA graph capture inefficient, because each forward pass records tens of host-to-device operations in the graph. This PR removes that kernel launch bottleneck by caching kernel parameters. Every model benchmarked below runs faster than with conventional CUDA graph capture, by 3.6% to 14.6% (roughly 9% on average), as validated in vLLM.rs (xinfer): guoqingbao/xinfer#371.
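
To make the idea concrete, here is a minimal, self-contained sketch of a parameter cache. It is not Candle's implementation: `FakeDeviceBuffer`, `ParamCache`, and `get_or_upload` are hypothetical names, and the "device" is simulated on the host. It only illustrates the principle that repeated launches with identical parameters can reuse one upload instead of issuing a new host-to-device copy each time.

```rust
use std::collections::HashMap;

/// Stand-in for a device allocation (hypothetical; a real backend would
/// hold a device pointer instead of a host-side Vec).
#[derive(Debug, Clone, PartialEq)]
pub struct FakeDeviceBuffer {
    pub id: usize,
    pub contents: Vec<usize>,
}

/// Hypothetical parameter cache: maps host-side parameter values to a
/// buffer that was already uploaded, so repeated launches skip the copy.
#[derive(Default)]
pub struct ParamCache {
    entries: HashMap<Vec<usize>, FakeDeviceBuffer>,
    /// Number of simulated host-to-device copies performed so far.
    pub htod_copies: usize,
}

impl ParamCache {
    /// Returns the cached buffer for `params`, uploading only on a miss.
    pub fn get_or_upload(&mut self, params: &[usize]) -> FakeDeviceBuffer {
        if let Some(buf) = self.entries.get(params) {
            // Cache hit: no host-to-device copy is needed.
            return buf.clone();
        }
        // Cache miss: simulate one host-to-device copy and remember it.
        self.htod_copies += 1;
        let buf = FakeDeviceBuffer {
            id: self.htod_copies,
            contents: params.to_vec(),
        };
        self.entries.insert(params.to_vec(), buf.clone());
        buf
    }
}
```

In this toy model, a warmup pass fills `entries`, so a later pass with the same shapes and strides triggers no further simulated copies.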

Usage:

Before CUDA graph capture, enable the parameter cache and run a warmup step. This caches the kernel launch parameters:

let _guard = candle_core::cuda_backend::cuda_param_cache_scope(true);
for b in 0..bs {
    let _ = self.model.forward(&input_ids_bs, ...)?;
}
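
The `_guard` binding above suggests a scope guard. As a sketch of how such a guard commonly works in Rust, and assuming (this PR description does not confirm it) that the setting is restored when the guard is dropped, the pattern looks like the following. `param_cache_scope`, `param_cache_enabled`, and `ParamCacheGuard` are hypothetical names, not Candle's API.

```rust
use std::cell::Cell;

thread_local! {
    // Hypothetical per-thread switch for the parameter cache.
    static PARAM_CACHE_ENABLED: Cell<bool> = Cell::new(false);
}

/// Reports whether the (hypothetical) cache is enabled on this thread.
pub fn param_cache_enabled() -> bool {
    PARAM_CACHE_ENABLED.with(|flag| flag.get())
}

/// Guard that remembers the previous setting and restores it on drop.
pub struct ParamCacheGuard {
    previous: bool,
}

/// Sets the switch and returns a guard that undoes the change later.
pub fn param_cache_scope(enabled: bool) -> ParamCacheGuard {
    let previous = PARAM_CACHE_ENABLED.with(|flag| flag.replace(enabled));
    ParamCacheGuard { previous }
}

impl Drop for ParamCacheGuard {
    fn drop(&mut self) {
        // Restore whatever was set before the scope was entered.
        PARAM_CACHE_ENABLED.with(|flag| flag.set(self.previous));
    }
}
```

With a guard like this, the binding name matters: `let _guard = ...` keeps the guard alive until the end of the enclosing block, whereas `let _ = ...` drops it immediately and the setting would be reverted before the warmup runs.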

Then perform the actual CUDA graph capture:

for b in 0..bs {
    self.model.start_capture(b)?;
    let out = self.model.forward(&input_ids_bs, ...)?;
    // Save out.
    self.model.end_capture()?;
}

Results from xinfer:

| Model | Quantization | Size | GPU | Conventional CUDA graph | CUDA graph + parameter cache | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| Qwen3-30B-A3B | NVFP4 | 30B MoE | RTX 5090 | 175.30 tokens/s | 181.59 tokens/s | 3.59% |
| Gemma4-26B-A4B | NVFP4 | 26B MoE | RTX 5090 | 131.00 tokens/s | 137.23 tokens/s | 4.76% |
| Ministral-3-3B (Multimodal) | ISQ (BF16→Q4K) | 3B | A100 | 171.92 tokens/s | 193.67 tokens/s | 12.65% |
| DeepSeek-R1-0528-Qwen3-8B | Q4_K_M | 8B | A100 | 124.87 tokens/s | 139.25 tokens/s | 11.52% |
| Llama-3.1-8B | ISQ (BF16→Q4K) | 8B | A100 | 120.74 tokens/s | 133.10 tokens/s | 10.24% |
| Qwen3-VL-8B-Instruct (Multimodal) | Q8_0 | 8B | A100 | 105.31 tokens/s | 112.51 tokens/s | 6.84% |
| Qwen3.6-35B-A3B (Multimodal) | FP8 | 35B MoE | Hopper | 102.00 tokens/s | 110.00 tokens/s | 7.84% |
| Qwen3-30B-A3B | NVFP4 | 30B MoE | V100 | 67.10 tokens/s | 72.86 tokens/s | 8.58% |
| GLM-4-9B-0414 | Q4_K_M | 9B | A100 | 70.38 tokens/s | 77.48 tokens/s | 10.09% |
| MiniMax-M2.5 | NVFP4 | 229B MoE | Hopper ×2 | 62.00 tokens/s | 64.50 tokens/s | 4.03% |
| Qwen3.5-27B (Multimodal) | Q4_K_M | 27B Dense | Hopper | 45.20 tokens/s | 49.33 tokens/s | 9.14% |
| Qwen3.5-27B/Qwen3.6-27B | FP8 | 27B Dense | Hopper | 42.00 tokens/s | 45.00 tokens/s | 7.14% |
| QwQ-32B | Q4_K_M | 32B | A100 | 41.36 tokens/s | 46.02 tokens/s | 11.27% |
| Gemma4-31B | ISQ (BF16→Q4K) | 31B Dense | Hopper | 41.00 tokens/s | 47.00 tokens/s | 14.63% |
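
The Speedup figures are consistent with the relative throughput gain of the parameter-cache run over the conventional CUDA graph run. A small helper (the name `speedup_percent` is illustrative, not part of this PR) shows the arithmetic:

```rust
/// Relative throughput gain in percent:
/// (cached - baseline) / baseline * 100.
pub fn speedup_percent(baseline_tps: f64, cached_tps: f64) -> f64 {
    (cached_tps - baseline_tps) / baseline_tps * 100.0
}
```

For example, the Llama-3.1-8B row gives (133.10 - 120.74) / 120.74 ≈ 10.24%.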

Note: This PR does not benefit candle-examples by default. It is intended for downstream projects that rely on Candle and support CUDA graphs.

@guoqingbao

Contributor Author

@EricLBuehler Do you have time to try this? I think you recently added CUDA graph support to mistral.rs?
