Skip to content

optimization qwen3-vl-4b TTFT for gfx1150 with 2 448x448 image and 256 text token input #1012

Open
qingxuamd wants to merge 2 commits into
gfx11from
qingxu/qwen3-vl-optimize2
Open

optimization qwen3-vl-4b TTFT for gfx1150 with 2 448x448 image and 256 text token input #1012
qingxuamd wants to merge 2 commits into
gfx11from
qingxu/qwen3-vl-optimize2

Conversation

@qingxuamd

Copy link
Copy Markdown

Specifically optimize the kernel tile according to qwen3-vl-4b and input shape, without this PR, TTFT=1231 ms, with this PR, TTFT = 980 ms.

This optimization is for gfx1150.
The model of Qwen3-VL-4B-Instruct-AWQ-4bit-lm_head_int8, for
triton_w4a16_skinny_fmt_kernel, it cost more then 60% latency
in prefill. The required input = 2 448x448 img + 256 token,
then, tok num ~=660 tok.

M: ~660 (2x448x448 + 256 token  prefill)
N,K(hot shape):
660 x 19456 x 2560
660 x 2560 x 9728
660 x 6144 x 2560
660 x 2560 x 4096
warps, currently 8 is best num

so:
BLOCK_M=64
BLOCK_N=256
BLOCK_K=32
num_warps=8

GEMM:M=671,N/K as above
kernel tile:64,256,32,8

Signed-off-by: Xu Qing <qing.xu2@amd.com>
Add a Qwen3-VL-4B prefill shape guard in Triton unified attention on gfx11
and apply BM64/T32/W4/S1/EU4 defaults to reduce TTFT. Also fix prefill
tile override wiring so VLLM_UA_PREFILL_TILE_SIZE is honored instead of
being overwritten by the default path.

Signed-off-by: Xu Qing <qing.xu2@amd.com>
Comment on lines +278 to +287
# Profile-guided default for Qwen3-VL-like multimodal prefill.
qwen3_prefill_shapes = {
(19456, 2560), # gate_up_proj-like
(2560, 9728), # down_proj-like
(6144, 2560), # qkv_proj-like
(2560, 4096), # o_proj-like
}
if 576 <= M <= 832 and (N, K) in qwen3_prefill_shapes:
return 64, 256, min(32, group_size), 8

@mgehre-amd mgehre-amd Jun 22, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change already landed in https://github.com/ROCm/vllm/pull/1009/changes for bfloat16.
Could you please put this in the same shape, i.e.

        if on_gfx1103() and M > 256:
            # Tested on Qwen3-VL-4B-AWQ
            block_m, block_n, block_k, num_warps = 64, 256, 64, 8

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, would update it later

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants