import torch
from vllm.third_party.deep_gemm import get_paged_mqa_logits_metadata

batch_size = 8192
block_size = 64
num_sms = torch.cuda.get_device_properties(0).multi_processor_count

context_lens = torch.full((batch_size, 1), 4096, dtype=torch.int32, device="cuda")
indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") // 2

# Fails while the metadata kernel is being JIT-compiled
get_paged_mqa_logits_metadata(
    context_lens,
    block_size,
    num_sms,
    indices=indices,
)
torch.cuda.synchronize()
NVCC compilation failed: ptxas error : Entry function '_ZN9deep_gemm5sched30smxx_paged_mqa_logits_metadataILj8192ELj256ELj152ELb1EEEvjjbPKjS3_Pj' uses too much shared data (0x18004 bytes, 0xc000 max)
Traceback (most recent call last):
File "/home/thien/debug.py", line 12, in <module>
get_paged_mqa_logits_metadata(
RuntimeError: Assertion error (/workspace/.deps/deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/compiler.hpp:228): false and "NVCC compilation failed"
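This is caused by the kernel's static shared memory usage growing linearly with batch size. At batch size 8192 it requests 0x18004 bytes, which exceeds the 0xc000-byte (48 KiB) limit reported by ptxas:
__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];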
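To make the numbers concrete, here is a rough back-of-the-envelope check using only the figures from the ptxas message. It assumes that the first template argument in the mangled kernel name (`Lj8192E`) is `kAlignedBatchSize`, and that the requested bytes are made up of whole `uint32_t` arrays of that length plus a small scalar remainder. Only two such arrays are quoted above, so the array count it derives is an inference from the byte total, not something read from the source file.

```python
# Figures copied from the ptxas error message.
used_bytes = 0x18004        # static shared memory requested by the kernel
limit_bytes = 0xC000        # static shared memory limit reported by ptxas (48 KiB)

# Assumption: the first template argument (Lj8192E) is kAlignedBatchSize.
aligned_batch_size = 8192
bytes_per_entry = 4         # sizeof(uint32_t)

# Assumption: usage = N arrays of kAlignedBatchSize entries + scalar remainder.
per_array_bytes = aligned_batch_size * bytes_per_entry
num_arrays, scalar_bytes = divmod(used_bytes, per_array_bytes)

# Largest array length that would still fit under the static limit
# if the layout stayed the same.
max_batch = (limit_bytes - scalar_bytes) // (num_arrays * bytes_per_entry)

print(f"requested: {used_bytes} bytes, limit: {limit_bytes} bytes")
print(f"inferred layout: {num_arrays} arrays of {per_array_bytes} bytes + {scalar_bytes} bytes")
print(f"largest array length under the limit: {max_batch}")
```

Under these assumptions the total is consistent with three per-batch arrays plus 4 bytes, so the static layout stops fitting somewhere below 4096 entries, well short of the 8192 needed here.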
Would it be possible to support large batch sizes, either by switching to dynamic shared memory or by reimplementing the kernel so that it doesn't need that much shared memory?
For context, I'm trying to use fp8_fp4_paged_mqa_logits() for the prefill case, which requires calling this metadata preparation kernel. Thank you.
The shared arrays are declared in DeepGEMM/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh, lines 19 to 20 at commit 891d57b.