Skip to content

[Fix][Kernel] gqa_decode: cap tile config for Metax 64 KiB shared memory#35

Open
ventijing wants to merge 2 commits into
MetaX-MACA:devfrom
ventijing:gqa_dev
Open

[Fix][Kernel] gqa_decode: cap tile config for Metax 64 KiB shared memory#35
ventijing wants to merge 2 commits into
MetaX-MACA:devfrom
ventijing:gqa_dev

Conversation

@ventijing

Copy link
Copy Markdown

Use block_N=64 and num_stages=0 on Metax/MACA by default, and filter autotune candidates by estimated shared memory so decode kernels launch within the device limit.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces helper functions to estimate dynamic shared memory usage and adjusts GQA decode kernel configurations specifically for MetaX MACA devices to respect their 64 KiB shared memory limit. The feedback highlights two critical issues: first, the default configuration on MetaX can still exceed the shared memory limit depending on the head dimension, which would cause runtime failures; second, the autotuning search space should be expanded to include smaller block sizes (such as 32) on MetaX to ensure that valid, launchable configurations can always be found.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines 446 to 464
def default_config(self) -> dict:
block_H = 64
num_split = self._default_num_split()
if _is_metax():
# Metax MACA devices expose 64 KiB shared memory per block. The
# CUDA-oriented default (block_N=128, num_stages=2) double-buffers
# K/V and exceeds that budget.
block_N = 64
num_stages = 0
else:
block_N = 128
num_stages = 2
return {
"block_H": 64,
"block_N": 128,
"num_split": self._default_num_split(),
"num_stages": 2,
"block_H": block_H,
"block_N": block_N,
"num_split": num_split,
"num_stages": num_stages,
"threads": 128,
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

On MetaX MACA devices, the shared memory limit is strictly 64 KiB (65,536 bytes). With the default configuration of block_H = 64, block_N = 64, and num_stages = 0, if the head dimension dim is 128 and valid_block_H is 64, the estimated shared memory is 16384 (Q) + 32768 (K/V) + 16384 (O) + 64 (split_bytes) = 65600 bytes, which exceeds the 64 KiB limit. If dim is 256, it exceeds the limit even further. This will cause a runtime shared memory limit error when launching the kernel with the default configuration.

We should dynamically adjust block_H and block_N in default_config to ensure the default configuration is always safe and launchable.

    @property
    def default_config(self) -> dict:
        block_H = 64
        num_split = self._default_num_split()
        if _is_metax():
            # Metax MACA devices expose 64 KiB shared memory per block.  The
            # CUDA-oriented default (block_N=128, num_stages=2) double-buffers
            # K/V and exceeds that budget.
            block_N = 64
            num_stages = 0
            
            # Dynamically reduce block sizes if the default config exceeds the shared memory limit
            kv_group_num = self.heads // self.groups
            valid_block_H = min(block_H, kv_group_num)
            elem_bytes = 2 if self.dtype_str in ("float16", "bfloat16") else 4
            smem = _gqa_decode_shared_memory_bytes(
                block_H,
                block_N,
                self.dim,
                num_stages,
                valid_block_H=valid_block_H,
                num_split=num_split,
                elem_bytes=elem_bytes,
            )
            smem_limit = _shared_memory_limit_bytes()
            if smem > smem_limit:
                block_H = 32
                smem = _gqa_decode_shared_memory_bytes(
                    block_H,
                    block_N,
                    self.dim,
                    num_stages,
                    valid_block_H=min(block_H, kv_group_num),
                    num_split=num_split,
                    elem_bytes=elem_bytes,
                )
                if smem > smem_limit:
                    block_N = 32
        else:
            block_N = 128
            num_stages = 2
        return {
            "block_H": block_H,
            "block_N": block_N,
            "num_split": num_split,
            "num_stages": num_stages,
            "threads": 128,
        }

Comment on lines 482 to 487
block_N = [64, 128]
block_H = [64]
num_split = [ns for ns in [2, 4, 8, 16, 32] if ns <= self.seqlen_kv] or [1]
num_stages = [1, 2, 3]
num_stages = [0, 1, 2, 3] if _is_metax() else [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

On MetaX MACA devices, if the head dimension dim is 128 or 256, all configurations with block_H = 64 and block_N = 64 or 128 might exceed the 64 KiB shared memory limit. If all autotune configurations are filtered out, autotune_configs will fall back to [self.default_config], which also exceeds the limit and will fail at runtime.

To prevent this, we should expand the autotuning search space on MetaX to include smaller block sizes (e.g., block_H = [32, 64] and block_N = [32, 64, 128]). This allows the autotuner to find a valid, high-performing configuration that fits within the hardware limits.

Suggested change
block_N = [64, 128]
block_H = [64]
num_split = [ns for ns in [2, 4, 8, 16, 32] if ns <= self.seqlen_kv] or [1]
num_stages = [1, 2, 3]
num_stages = [0, 1, 2, 3] if _is_metax() else [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
@property
def autotune_configs(self) -> list[dict]:
block_N = [32, 64, 128] if _is_metax() else [64, 128]
block_H = [32, 64] if _is_metax() else [64]
num_split = [ns for ns in [2, 4, 8, 16, 32] if ns <= self.seqlen_kv] or [1]
num_stages = [0, 1, 2, 3] if _is_metax() else [1, 2, 3]
threads = [128]

Use block_N=64 and num_stages=0 on Metax/MACA by default, and filter autotune candidates by estimated shared memory so decode kernels launch within the device limit.
@ventijing ventijing changed the title [MetaXGPU][Fix][Kernel] gqa_decode: cap tile config for Metax 64 KiB shared memory [Fix][Kernel] gqa_decode: cap tile config for Metax 64 KiB shared memory Jul 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant