[ttx/npu] Optimize lightning_indexer_kernel by deferring k_scale to post-dot#317
[ttx/npu] Optimize lightning_indexer_kernel by deferring k_scale to post-dot#317lyujheng wants to merge 2 commits into
Conversation
…st-dot scaling Move k_scale from pre-scaling K to post-dot application to better utilize AIC/AIV core memory bandwidth.
There was a problem hiding this comment.
Code Review
This pull request optimizes the lightning_indexer_kernel by moving the k_scale multiplication to occur after the dot product instead of before. This change reduces the computational overhead of element-wise multiplications by performing the scaling on the result of the dot product. I have no further feedback to provide.
|
Hi @wwens7, could you help review this PR when you have a chance? This PR optimizes The change is straightforward — moving the scaling from pre-dot element-wise multiply to post-dot broadcast, reducing AIC-AIV data transfer overhead by avoiding global memory round-trips. Thanks! |
Description
Optimize
lightning_indexer_kernelby deferringk_scaleapplication from pre-scaling K tensorto post-dot product scaling, fully leveraging AIC and AIV core memory bandwidth.
Changes
k = k * k_scale[:, None])k_scaleafter QK dot product as a lightweight broadcast multiplyPerformance
Benchmark on Ascend 910C NPU and triton-ascend 3.2.0 (device latency):
Average speedup: ~3.6x (bf16/fp16 gains significantly higher than fp32).
Accuracy
All accuracy tests passed across all test shapes and dtypes (bf16/fp16/fp32).