get_paged_mqa_logits_metadata() does not work with large batch size due to excessive static shared memory impl #322

@gau-nernst

Description

import torch

from vllm.third_party.deep_gemm import get_paged_mqa_logits_metadata

batch_size = 8192
block_size = 64
num_sms = torch.cuda.get_device_properties(0).multi_processor_count

context_lens = torch.full((batch_size, 1), 4096, dtype=torch.int32, device="cuda")
indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") // 2

get_paged_mqa_logits_metadata(
    context_lens,
    block_size,
    num_sms,
    indices=indices,
)
torch.cuda.synchronize()
Running this fails during JIT compilation:

NVCC compilation failed: ptxas error   : Entry function '_ZN9deep_gemm5sched30smxx_paged_mqa_logits_metadataILj8192ELj256ELj152ELb1EEEvjjbPKjS3_Pj' uses too much shared data (0x18004 bytes, 0xc000 max)

Traceback (most recent call last):
  File "/home/thien/debug.py", line 12, in <module>
    get_paged_mqa_logits_metadata(
RuntimeError: Assertion error (/workspace/.deps/deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/compiler.hpp:228): false and "NVCC compilation failed"

This is caused by the kernel's static shared memory growing linearly with batch size:

__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
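To make the numbers concrete, here is a rough back-of-the-envelope check in Python. It is only a sketch: it assumes `kAlignedBatchSize` is 8192 for this run (which matches the first template argument in the mangled kernel name), and it counts only the two arrays quoted above, so it is a lower bound on what the kernel actually requests. The helper name `two_array_bytes` is made up for illustration.

```python
# Figures copied from the ptxas message in this issue.
REPORTED_BYTES = 0x18004      # shared data the kernel asked for
STATIC_LIMIT_BYTES = 0xC000   # static shared memory limit reported by ptxas

UINT32_BYTES = 4


def two_array_bytes(aligned_batch_size: int) -> int:
    """Bytes used by the two uint32 arrays quoted above (lower bound only)."""
    return 2 * aligned_batch_size * UINT32_BYTES


# Assumption: kAlignedBatchSize == 8192 for batch_size = 8192.
lower_bound = two_array_bytes(8192)

print(f"two arrays alone: {lower_bound} bytes")
print(f"static limit:     {STATIC_LIMIT_BYTES} bytes")
print(f"over the limit:   {lower_bound > STATIC_LIMIT_BYTES}")

# Whatever is left over comes from other shared allocations in the kernel,
# which are not shown in this issue.
print(f"other shared use: {REPORTED_BYTES - lower_bound} bytes")

# Largest aligned batch size for which these two arrays alone would still fit.
print(f"max batch (two arrays only): {STATIC_LIMIT_BYTES // (2 * UINT32_BYTES)}")
```

So the two arrays by themselves already exceed the 48 KiB static limit at this batch size, before counting anything else the kernel keeps in shared memory.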

Would it be possible to support large batch sizes instead? Either by switching to dynamic shared memory, or by reimplementing the kernel in a way that doesn't need that much shared memory.

The context is that I'm trying to use fp8_fp4_paged_mqa_logits() for the prefill case, which requires calling this metadata preparation kernel. Thank you.
