Skip to content

[docs] update rms kwargs#3

Merged
patrick-toulme merged 1 commit into
patrick-toulme:mainfrom
knightron0:main
Apr 29, 2026
Merged

[docs] update rms kwargs#3
patrick-toulme merged 1 commit into
patrick-toulme:mainfrom
knightron0:main

Conversation

@knightron0

Copy link
Copy Markdown
Contributor

I'm not too sure if this is intended, but the only way I got the RMS introduction kernel from the homepage working was to edit out the N keyword argument and extract it from the X's shape.

The default snippet threw a TypeError on H100:

from pyptx import kernel, reg, smem, ptx, Tile
from pyptx.types import f32, u32

@kernel(
    in_specs=(Tile("B", "N", f32), Tile("N", f32)),   # X[B, N], W[N]
    out_specs=(Tile("B", "N", f32),),                 # Y[B, N]
    grid=lambda B, N: (B, 1, 1),
    block=(128, 1, 1),
    arch="sm_90a",
)
def rms_norm(X, W, Y, *, N: int, eps: float = 1e-6):
    partials = smem.alloc(f32, (4, 1))                # warp-partial sums
    px, pw, py = ptx.global_ptrs(X, W, Y)             # three param ptrs at once
    tid = reg.scalar(u32); ptx.inst.mov.u32(tid, ptx.special.tid.x())
    row = reg.scalar(u32); ptx.inst.mov.u32(row, ptx.special.ctaid.x())
    px += row * (N * 4); py += row * (N * 4)

    # Pass 1: v4 loads, accumulate sum-of-squares per thread.
    sum_sq = reg.scalar(f32, init=0.0)
    x_vals = reg.array(f32, N // 128)
    for j in range(N // 512):
        off = (tid << 4) + j * (128 * 16)             # 4 elems * 4 bytes per thread
        ptx.inst.ld.global_.v4.f32(
            [x_vals[j*4+k] for k in range(4)],
            ptx.addr(px + off),
        )
        for k in range(4):
            ptx.inst.fma.rn.f32(sum_sq, x_vals[j*4+k], x_vals[j*4+k], sum_sq)

    ptx.warp.reduce_sum(sum_sq)                       # canonical shfl.bfly reduce
    # ... block reduce via SMEM, rsqrt, scale by W, v4-store Y ...
    ptx.ret()

import torch

x = torch.randn(256, 4096, device="cuda")
w = torch.randn(4096, device="cuda")
y = rms_norm(x, w)
Full Error
File /usr/local/lib/python3.13/site-packages/pyptx/kernel.py:983, in Kernel.__call__(self, *args, **kwargs)
    977     raw_param_values = ()
    979 if cache_key not in self._cubin_handles:
    980     # Trace → PTX → cubin → register
    981     # Pass the shape env so TensorSpec.shape is resolved inside
    982     # the kernel body (A.shape[1] works at trace time).
--> 983     module = self._trace(_shape_env=shape_env, **template_kwargs)
    984     ptx_source = emit(module)
    985     grid_tuple = self._resolve_grid(shape_env)

File /usr/local/lib/python3.13/site-packages/pyptx/kernel.py:622, in Kernel._trace(self, _shape_env, **kwargs)
    611 """Trace the kernel function and return an IR Module.
    612 
    613 kwargs may include BOTH template parameters (keyword-only params
   (...)    619 it seeds the shape variables and any caller kwargs override it.
    620 """
    621 # Split kwargs into template params vs shape vars
--> 622 template_kwargs, caller_shape_env, _ = self._split_kwargs(kwargs)
    623 resolved = dict(self._template_defaults)
    624 resolved.update(template_kwargs)

File /usr/local/lib/python3.13/site-packages/pyptx/kernel.py:479, in Kernel._split_kwargs(self, kwargs)
    477 if k in shape_names:
    478     if not isinstance(v, int):
--> 479         raise TypeError(
    480             f"Shape variable {k!r} must be an int, got {type(v).__name__}"
    481         )
    482     shape_env[k] = v
    483 elif k in self._template_names:

TypeError: Shape variable 'N' must be an int, got NoneType

I think it's something to do with the conflict between the symbolic "N" and keyword N, but maybe it's more intuitive to just extract it from X for the intro page!

@patrick-toulme

Copy link
Copy Markdown
Owner

appreciated thanks!

@patrick-toulme patrick-toulme merged commit 8137021 into patrick-toulme:main Apr 29, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants