Skip to content

fix layernorm_fwd_kernel#347

Open
Todobe wants to merge 1 commit into
XPU-Forces:masterfrom
Todobe:fix_layernorm
Open

fix layernorm_fwd_kernel#347
Todobe wants to merge 1 commit into
XPU-Forces:masterfrom
Todobe:fix_layernorm

Conversation

@Todobe

@Todobe Todobe commented Jun 4, 2026

Copy link
Copy Markdown

This operator has an issue with tail block processing. It needs to calculate sum((x - mean(x))^2). When the dimension of x is not an integer multiple of the block size, the last block of x after partitioning becomes [normal value, normal value, normal value, 0]. Then, (x - mean(x)) ^2 becomes [normal value, normal value, normal value, mean(x)^2]. When performing the sum, this final mean(x)^2 is also included, causing an accuracy error.

Moreover, since the mean of x in the test cases is basically 0, this issue was not detected.

This issue can be resolved by adding a tl.where statement after the subtraction.

To reproduce this error, you need to modify the test file as follows(to make the mean of x not equal to 0)):

--- a/mojo_opset/tests/accuracy/operators/test_normalization.py
+++ b/mojo_opset/tests/accuracy/operators/test_normalization.py
@@ -75,7 +75,7 @@ def test_rmsnorm(shape, dtype, eps):
 @pytest.mark.parametrize("eps", [1e-5])
 @bypass_not_implemented
 def test_layernorm(shape, dtype, eps):
-    x = torch.randn(size=shape, dtype=dtype)
+    x = torch.randn(size=shape, dtype=dtype) + 1
     weight = torch.randn(size=(shape[-1],), dtype=dtype)
     bias = torch.randn(size=(shape[-1],), dtype=dtype)
     layernorm = MojoLayerNorm(eps=eps, norm_size=weight.size(0), dtype=weight.dtype, device=x.device)

Use pytest tests/accuracy/operators/test_normalization.py::test_layernorm to test the accuracy.

Origin op test result:

        else:
>           torch.testing.assert_close(norm.to(torch.float32), ref.to(torch.float32), atol=atol, rtol=rtol)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 15630674 / 145754836 (10.7%)
E           Greatest absolute difference: 0.6875 at index (4819, 16480) (up to 0.05 allowed)
E           Greatest relative difference: 160200.109375 at index (4799, 44) (up to 0.01 allowed)

utils/acc.py:61: AssertionError
====================================================================================================================== warnings summary ======================================================================================================================
experimental/operators/attention.py:698
  /root/XPU-Forces/mojo_opset/mojo_opset/experimental/operators/attention.py:698: DeprecationWarning: invalid escape sequence '\l'
    """

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== short test summary info ===================================================================================================================
FAILED tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape2] - AssertionError: Tensor-likes are not close!
FAILED tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape4] - AssertionError: Tensor-likes are not close!
FAILED tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape2] - AssertionError: Tensor-likes are not close!
FAILED tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape4] - AssertionError: Tensor-likes are not close!
========================================================================================================== 4 failed, 6 passed, 1 warning in 21.96s ===========================================================================================================

Fixed op test result:

tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape0] 
------------------------------------------------------------------------------------ live log call ------------------------------------------------------------------------------------
[INFO] 06/04/2026 06:59:51 >> get env TRITON_DEBUG = 1
[INFO] 06/04/2026 06:59:51 >> get env TRITON_ALWAYS_COMPILE = 1
PASSED                                                                                                                                                                          [ 10%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape1] PASSED                                                                                      [ 20%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape2] PASSED                                                                                      [ 30%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape3] PASSED                                                                                      [ 40%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype0-shape4] PASSED                                                                                      [ 50%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape0] PASSED                                                                                      [ 60%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape1] PASSED                                                                                      [ 70%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape2] PASSED                                                                                      [ 80%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape3] PASSED                                                                                      [ 90%]
tests/accuracy/operators/test_normalization.py::test_layernorm[1e-05-dtype1-shape4] PASSED                                                                                      [100%]

================================================================================= 10 passed in 25.85s =================================================================================

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the Triton layernorm forward kernel in mojo_opset/backends/ttx/kernels/npu/layernorm.py to filter out invalid columns by applying a block_mask to x_centered before calculating the variance accumulation. There are no review comments, and I have no additional feedback to provide.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant