[Fix][Kernel] gqa_decode: cap tile config for Metax 64 KiB shared memory #35
ventijing wants to merge 2 commits into
Code Review
This pull request introduces helper functions to estimate dynamic shared memory usage and adjusts GQA decode kernel configurations specifically for MetaX MACA devices to respect their 64 KiB shared memory limit. The feedback highlights two critical issues: first, the default configuration on MetaX can still exceed the shared memory limit depending on the head dimension, which would cause runtime failures; second, the autotuning search space should be expanded to include smaller block sizes (such as 32) on MetaX to ensure that valid, launchable configurations can always be found.
      def default_config(self) -> dict:
+         block_H = 64
+         num_split = self._default_num_split()
+         if _is_metax():
+             # Metax MACA devices expose 64 KiB shared memory per block. The
+             # CUDA-oriented default (block_N=128, num_stages=2) double-buffers
+             # K/V and exceeds that budget.
+             block_N = 64
+             num_stages = 0
+         else:
+             block_N = 128
+             num_stages = 2
          return {
-             "block_H": 64,
-             "block_N": 128,
-             "num_split": self._default_num_split(),
-             "num_stages": 2,
+             "block_H": block_H,
+             "block_N": block_N,
+             "num_split": num_split,
+             "num_stages": num_stages,
              "threads": 128,
          }
On MetaX MACA devices, the shared memory limit is strictly 64 KiB (65,536 bytes). With the default configuration of block_H = 64, block_N = 64, and num_stages = 0, if the head dimension dim is 128 and valid_block_H is 64, the estimated shared memory is 16384 (Q) + 32768 (K/V) + 16384 (O) + 64 (split_bytes) = 65600 bytes, which exceeds the 64 KiB limit. If dim is 256, it exceeds the limit even further. This will cause a runtime shared memory limit error when launching the kernel with the default configuration.
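The arithmetic above can be reproduced with a small sketch. The function below is a hypothetical stand-in, not the PR's `_gqa_decode_shared_memory_bytes`: it assumes one Q tile and one O tile of `block_H x dim` elements, one K tile and one V tile of `block_N x dim` elements per pipeline stage, that `num_stages = 0` still needs a single K/V buffer, and a fixed 64-byte split term as quoted in the comment.

```python
KIB = 1024
METAX_SMEM_LIMIT = 64 * KIB  # 65,536 bytes, the limit quoted in this review


def estimate_smem_bytes(block_H, block_N, dim, num_stages,
                        elem_bytes=2, split_bytes=64):
    """Hypothetical estimator that mirrors the breakdown in the comment above."""
    # Assumption: num_stages=0 still allocates one K/V buffer.
    stages = max(1, num_stages)
    q_bytes = block_H * dim * elem_bytes
    # One K tile plus one V tile for every pipeline stage.
    kv_bytes = 2 * block_N * dim * elem_bytes * stages
    o_bytes = block_H * dim * elem_bytes
    return q_bytes + kv_bytes + o_bytes + split_bytes


if __name__ == "__main__":
    for dim in (64, 128, 256):
        total = estimate_smem_bytes(block_H=64, block_N=64, dim=dim, num_stages=0)
        verdict = "fits" if total <= METAX_SMEM_LIMIT else "exceeds"
        print(f"dim={dim}: {total} bytes ({verdict} 64 KiB)")
```

Under these assumptions the `dim = 128` case lands at 65600 bytes, 64 bytes over the limit, which matches the figure in the comment. The real helper may count buffers differently, so the numbers are only as good as the stated layout.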
We should dynamically adjust block_H and block_N in default_config to ensure the default configuration is always safe and launchable.
    @property
    def default_config(self) -> dict:
        block_H = 64
        num_split = self._default_num_split()
        if _is_metax():
            # Metax MACA devices expose 64 KiB shared memory per block. The
            # CUDA-oriented default (block_N=128, num_stages=2) double-buffers
            # K/V and exceeds that budget.
            block_N = 64
            num_stages = 0
            # Dynamically reduce block sizes if the default config exceeds the shared memory limit
            kv_group_num = self.heads // self.groups
            valid_block_H = min(block_H, kv_group_num)
            elem_bytes = 2 if self.dtype_str in ("float16", "bfloat16") else 4
            smem = _gqa_decode_shared_memory_bytes(
                block_H,
                block_N,
                self.dim,
                num_stages,
                valid_block_H=valid_block_H,
                num_split=num_split,
                elem_bytes=elem_bytes,
            )
            smem_limit = _shared_memory_limit_bytes()
            if smem > smem_limit:
                block_H = 32
                smem = _gqa_decode_shared_memory_bytes(
                    block_H,
                    block_N,
                    self.dim,
                    num_stages,
                    valid_block_H=min(block_H, kv_group_num),
                    num_split=num_split,
                    elem_bytes=elem_bytes,
                )
                if smem > smem_limit:
                    block_N = 32
        else:
            block_N = 128
            num_stages = 2
        return {
            "block_H": block_H,
            "block_N": block_N,
            "num_split": num_split,
            "num_stages": num_stages,
            "threads": 128,
        }

          block_N = [64, 128]
          block_H = [64]
          num_split = [ns for ns in [2, 4, 8, 16, 32] if ns <= self.seqlen_kv] or [1]
-         num_stages = [1, 2, 3]
+         num_stages = [0, 1, 2, 3] if _is_metax() else [1, 2, 3]
          threads = [128]
          _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
On MetaX MACA devices, if the head dimension dim is 128 or 256, all configurations with block_H = 64 and block_N = 64 or 128 might exceed the 64 KiB shared memory limit. If all autotune configurations are filtered out, autotune_configs will fall back to [self.default_config], which also exceeds the limit and will fail at runtime.
To prevent this, we should expand the autotuning search space on MetaX to include smaller block sizes (e.g., block_H = [32, 64] and block_N = [32, 64, 128]). This allows the autotuner to find a valid, high-performing configuration that fits within the hardware limits.
    @property
    def autotune_configs(self) -> list[dict]:
        block_N = [32, 64, 128] if _is_metax() else [64, 128]
        block_H = [32, 64] if _is_metax() else [64]
        num_split = [ns for ns in [2, 4, 8, 16, 32] if ns <= self.seqlen_kv] or [1]
        num_stages = [0, 1, 2, 3] if _is_metax() else [1, 2, 3]
        threads = [128]
Use block_N=64 and num_stages=0 on Metax/MACA by default, and filter autotune candidates by estimated shared memory so decode kernels launch within the device limit.