optimize for mlu grouprmsnorm #285

Open: yuer-cn wants to merge 2 commits into master from opt_grouprmsnorm

@yuer-cn yuer-cn commented May 8, 2026

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request optimizes the RMSNorm kernel by introducing autotuning for BLOCK_M, vectorizing operations with 2D blocks, and adding a tl.dot optimization path for MLU hardware. Feedback suggests constraining the autotuner to prevent hardware resource exhaustion when both block dimensions are large. Additionally, it is recommended to transpose the vector rather than the data matrix in the dot product for better efficiency and to correct a misleading error message in the new input size assertion.

@triton.jit

def cfggen():
    block_m = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

Contributor

medium

The autotuning configuration includes BLOCK_M values up to 2048. When combined with BLOCK_N up to 8192, the resulting 2D block size (2048 * 8192 = 16M elements) is likely to exceed the hardware's LRAM capacity or Triton's internal limits for register allocation, which may lead to compilation failures or significant performance degradation due to spilling. Consider capping BLOCK_M or adding a constraint to the autotuner to keep the total block size within reasonable limits (e.g., BLOCK_M * BLOCK_N <= 65536).
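A minimal sketch of the kind of constraint suggested here, in plain Python so it runs without Triton. The `BLOCK_N` candidate list is hypothetical (the excerpt only shows the `BLOCK_M` list and states that `BLOCK_N` goes up to 8192), the helper name `capped_block_shapes` is illustrative and not part of the PR, and 65536 is simply the example cap from this comment, not a measured MLU limit.

```python
from itertools import product

# BLOCK_M candidates as quoted from cfggen() in the diff.
BLOCK_M_CANDIDATES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Hypothetical BLOCK_N candidates; the PR excerpt only says "up to 8192".
BLOCK_N_CANDIDATES = [1024, 2048, 4096, 8192]


def capped_block_shapes(block_ms, block_ns, max_elems=65536):
    """Return (BLOCK_M, BLOCK_N) pairs whose element count stays within max_elems."""
    return [(m, n) for m, n in product(block_ms, block_ns) if m * n <= max_elems]


shapes = capped_block_shapes(BLOCK_M_CANDIDATES, BLOCK_N_CANDIDATES)
for m, n in shapes:
    print(f"BLOCK_M={m:<5} BLOCK_N={n:<5} elements={m * n}")
```

Each surviving pair could then be wrapped in a `triton.Config({"BLOCK_M": m, "BLOCK_N": n})` entry for the autotuner. The right value of the cap would have to come from profiling on the target hardware.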

Collaborator Author

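The MLU needs to process a sufficiently large amount of data per pass to use its on-chip resources effectively; tuning will select a configuration that is executable and gives the best performance.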


if USE_DOT:
    sq = x_f32 * x_f32
    acc = tl.reshape(tl.dot(ones, tl.trans(sq), allow_tf32=False), (BLOCK_M,))

Contributor

medium

In Triton kernels, transposing a large data matrix in registers can be expensive. Since ones is a vector, it is more efficient to transpose ones instead of sq to achieve the same reduction sum. This avoids the overhead of transposing the (BLOCK_M, BLOCK_N) matrix while still leveraging the tl.dot optimization for MLU hardware.

acc = tl.reshape(tl.dot(sq, tl.trans(ones), allow_tf32=False), (BLOCK_M,))
References
  1. On MLU hardware, use tl.dot() for matrix multiplication in Triton kernels, as triton-mlu can utilize tensor cores for acceleration, even for small matrix dimensions.
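To illustrate why the two operand orders yield the same reduction, here is a small plain-Python sketch. It assumes `ones` has shape `(1, BLOCK_N)`, which the excerpt does not show, and uses the hypothetical helpers `matmul` and `transpose` as stand-ins for `tl.dot` and `tl.trans`. It checks only the arithmetic and says nothing about which form is faster on MLU.

```python
def transpose(a):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def matmul(a, b):
    """Naive matrix product of two list-of-rows matrices."""
    b_cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


def flatten(a):
    """Flatten a matrix into a 1-D list, mimicking the reshape to (BLOCK_M,)."""
    return [v for row in a for v in row]


# Toy tile with BLOCK_M = 2 rows and BLOCK_N = 3 columns of squared values.
sq = [[1.0, 4.0, 9.0],
      [16.0, 25.0, 36.0]]
ones = [[1.0, 1.0, 1.0]]  # assumed shape (1, BLOCK_N)

original = flatten(matmul(ones, transpose(sq)))   # (1, N) x (N, M) -> (1, M)
suggested = flatten(matmul(sq, transpose(ones)))  # (M, N) x (N, 1) -> (M, 1)
row_sums = [sum(row) for row in sq]

print(original, suggested, row_sums)
```

Both products reduce each row of `sq` to its sum, so the choice between them is purely about which operand the hardware handles more efficiently in each position.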

Collaborator Author

For performance reasons, sq needs to take part in the computation as the weight matrix.

Comment thread mojo_opset/backends/ttx/kernels/mlu/group_rmsnorm.py Outdated
@yuer-cn yuer-cn force-pushed the opt_grouprmsnorm branch from e3a355f to f0f06a4 Compare May 15, 2026 01:22