Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions benchmarks/bench_final.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,44 @@ def ref(g, u):
print(f"{M:5d} {F:6d} {t_p:8.2f}us {t_t:8.2f}us {bw(bytes_, t_p):11.1f} {bw(bytes_, t_t):11.1f} {t_t/t_p:8.2f}x")


def flash_norm():
    """Benchmark the Hopper flash-norm kernel against the pyptx RMS-norm kernel.

    For every (B, N) shape both kernels are built, each is called once and
    the device is synchronised before it is timed with ``_time_events``.
    One table row per shape reports both timings, the bandwidth figure for
    flash-norm, and the rms/flash time ratio.
    """
    from examples.hopper.flash_norm import build_flash_norm
    from examples.hopper.rms_norm import build_rms_norm

    hdr("FLASH NORM (fp32) — norm sans weight vs RMS norm")
    print(f"{'B':>5} {'N':>6} {'flash_norm':>11} {'rms_norm':>11} {'pyptx GB/s':>11} {'speedup':>8}")

    shapes = [(32, 1024), (256, 4096), (1024, 8192), (2048, 8192)]
    for B, N in shapes:
        _, flash_kernel = build_flash_norm(B, N, N)
        rms_kernel = build_rms_norm(B, N)
        inp = torch.randn(B, N, device="cuda")
        weight = torch.randn(N, device="cuda")

        # Warm-up call + sync so the timed region starts from an idle device.
        flash_kernel(inp)
        torch.cuda.synchronize()
        # NOTE(review): the `* 1e3` together with the "us" suffix suggests
        # `_time_events` returns milliseconds — confirm against its definition.
        flash_us = _time_events(lambda a: flash_kernel(a), inp) * 1e3

        rms_kernel(inp, weight)
        torch.cuda.synchronize()
        rms_us = _time_events(lambda a, b: rms_kernel(a, b), inp, weight) * 1e3

        # Presumably one 4-byte read plus one 4-byte write per element.
        traffic = 2 * B * N * 4
        row = (
            f"{B:5d} {N:6d} {flash_us:10.2f}us {rms_us:10.2f}us "
            f"{bw(traffic, flash_us):11.1f} {rms_us/flash_us:8.2f}x"
        )
        print(row)


def softmax():
    """Benchmark the Hopper pyptx row-wise softmax against ``torch.softmax``.

    For every (B, N) shape the pyptx kernel and the torch reference are each
    called once, synchronised, and then timed with ``_time_events``. One
    table row per shape reports both timings, both bandwidth figures, and
    the torch/pyptx time ratio.

    Only the Hopper-labelled banner is printed; the earlier unlabelled
    ``"SOFTMAX (fp32, row-wise)"`` banner was a stale duplicate that produced
    two headers for a single table.
    """
    from examples.hopper.softmax import build_softmax

    def ref(x):
        # Torch baseline; defined once because it does not depend on the shape.
        return torch.softmax(x, dim=-1)

    hdr("SOFTMAX (fp32, row-wise, Hopper sm_90a)")
    print(f"{'B':>5} {'N':>6} {'pyptx':>9} {'torch':>9} {'pyptx GB/s':>11} {'torch GB/s':>11} {'speedup':>8}")
    for B, N in [(32, 1024), (256, 4096), (1024, 8192), (2048, 8192)]:
        k = build_softmax(B, N)
        x = torch.randn(B, N, device="cuda")

        # Warm-up call + sync before timing the pyptx kernel.
        k(x)
        torch.cuda.synchronize()
        # NOTE(review): `* 1e3` with the "us" suffix suggests `_time_events`
        # returns milliseconds — confirm against its definition.
        t_p = _time_events(lambda x: k(x), x) * 1e3

        # Same warm-up treatment for the torch reference.
        ref(x)
        torch.cuda.synchronize()
        t_t = _time_events(ref, x) * 1e3

        # Presumably one 4-byte read plus one 4-byte write per element.
        bytes_ = 2 * B * N * 4
        print(f"{B:5d} {N:6d} {t_p:8.2f}us {t_t:8.2f}us {bw(bytes_, t_p):11.1f} {bw(bytes_, t_t):11.1f} {t_t/t_p:8.2f}x")


def softmax_blackwell():
from examples.blackwell.softmax import build_softmax
hdr("SOFTMAX (fp32, row-wise, Blackwell sm_100a)")
print(f"{'B':>5} {'N':>6} {'pyptx':>9} {'torch':>9} {'pyptx GB/s':>11} {'torch GB/s':>11} {'speedup':>8}")
for B, N in [(32, 1024), (256, 4096), (1024, 8192), (2048, 8192)]:
k = build_softmax(B, N)
Expand Down Expand Up @@ -126,6 +161,21 @@ def grouped_gemm():
print(f"{G:3d} {M:5d} {N:4d} {K:5d} {t_p:8.2f}us {tflops:8.1f}")


def flash_norm_ampere():
    """Benchmark the Ampere flash-norm kernel against the Ampere RMS-norm kernel.

    For every (B, N) shape both kernels are built, warmed up, and timed with
    ``_time_events``. One table row per shape reports both timings and the
    rms/flash time ratio.

    Each kernel is called once and the device synchronised before timing,
    matching the other benchmarks in this file, so that any one-time
    first-launch cost stays out of the measured region.
    """
    from examples.ampere.flash_norm import build_flash_norm
    from examples.ampere.rms_norm import build_rms_norm

    hdr("FLASH NORM Ampere (fp32, sm_80) — norm sans weight vs RMS norm")
    print(f"{'B':>5} {'N':>6} {'flash_norm':>11} {'rms_norm':>11} {'speedup':>8}")
    for B, N in [(32, 1024), (256, 4096), (1024, 8192), (2048, 8192)]:
        _, flash = build_flash_norm(B, N, N)
        rms = build_rms_norm(B, N)
        x = torch.randn(B, N, device="cuda")
        w = torch.randn(N, device="cuda")

        # Warm-up call + sync before timing flash-norm.
        flash(x)
        torch.cuda.synchronize()
        # NOTE(review): `* 1e3` with the "us" suffix suggests `_time_events`
        # returns milliseconds — confirm against its definition.
        t_f = _time_events(lambda x: flash(x), x) * 1e3

        # Warm-up call + sync before timing rms-norm.
        rms(x, w)
        torch.cuda.synchronize()
        t_r = _time_events(lambda x, w: rms(x, w), x, w) * 1e3

        print(f"{B:5d} {N:6d} {t_f:10.2f}us {t_r:10.2f}us {t_r/t_f:8.2f}x")


def flash_attn():
from examples.hopper.experimental.flash_attention_hopper import build_flash_attention_hopper
hdr("FLASH ATTENTION (Hopper BN=64 multi-k, bf16)")
Expand Down Expand Up @@ -166,6 +216,7 @@ def hopper_gemm():
def main():
rms_norm()
layer_norm()
flash_norm()
swiglu()
softmax()
grouped_gemm()
Expand All @@ -176,9 +227,10 @@ def main():
if __name__ == "__main__":
import sys
dispatch = {
"rms": rms_norm, "layer": layer_norm, "silu": swiglu,
"softmax": softmax,
"grouped": grouped_gemm, "attn": flash_attn, "gemm": hopper_gemm,
"rms": rms_norm, "layer": layer_norm, "flash": flash_norm,
"silu": swiglu, "softmax": softmax, "softmax_bw": softmax_blackwell,
"grouped": grouped_gemm, "attn": flash_attn,
"gemm": hopper_gemm,
}
targets = sys.argv[1:]
if not targets:
Expand Down
22 changes: 11 additions & 11 deletions docs/api/ptx.md
Original file line number Diff line number Diff line change
Expand Up @@ -1821,15 +1821,15 @@ dict(**kwargs) -> new dictionary initialized with the name=value pairs

- Kind: `attribute`

- Value: `<built-in method get of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method get of dict object at 0x7f0f5e9a7cc0>`

Return the value for key if key is in the dictionary, else default.

#### `setdefault`

- Kind: `attribute`

- Value: `<built-in method setdefault of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method setdefault of dict object at 0x7f0f5e9a7cc0>`

Insert key with a value of default if key is not in the dictionary.

Expand All @@ -1839,7 +1839,7 @@ Return the value for key if key is in the dictionary, else default.

- Kind: `attribute`

- Value: `<built-in method pop of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method pop of dict object at 0x7f0f5e9a7cc0>`

D.pop(k[,d]) -> v, remove specified key and return the corresponding value.

Expand All @@ -1850,7 +1850,7 @@ raise a KeyError.

- Kind: `attribute`

- Value: `<built-in method popitem of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method popitem of dict object at 0x7f0f5e9a7cc0>`

Remove and return a (key, value) pair as a 2-tuple.

Expand All @@ -1861,31 +1861,31 @@ Raises KeyError if the dict is empty.

- Kind: `attribute`

- Value: `<built-in method keys of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method keys of dict object at 0x7f0f5e9a7cc0>`

D.keys() -> a set-like object providing a view on D's keys

#### `items`

- Kind: `attribute`

- Value: `<built-in method items of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method items of dict object at 0x7f0f5e9a7cc0>`

D.items() -> a set-like object providing a view on D's items

#### `values`

- Kind: `attribute`

- Value: `<built-in method values of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method values of dict object at 0x7f0f5e9a7cc0>`

D.values() -> an object providing a view on D's values

#### `update`

- Kind: `attribute`

- Value: `<built-in method update of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method update of dict object at 0x7f0f5e9a7cc0>`

D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
If E is present and has a .keys() method, then does: for k in E.keys(): D[k] = E[k]
Expand All @@ -1896,22 +1896,22 @@ In either case, this is followed by: for k in F: D[k] = F[k]

- Kind: `attribute`

- Value: `<built-in method fromkeys of type object at 0x7f6bf7131f60>`
- Value: `<built-in method fromkeys of type object at 0x7f0f5fb31f60>`

Create a new dictionary with keys from iterable and values set to value.

#### `clear`

- Kind: `attribute`

- Value: `<built-in method clear of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method clear of dict object at 0x7f0f5e9a7cc0>`

D.clear() -> None. Remove all items from D.

#### `copy`

- Kind: `attribute`

- Value: `<built-in method copy of dict object at 0x7f6bf5e4bcc0>`
- Value: `<built-in method copy of dict object at 0x7f0f5e9a7cc0>`

D.copy() -> a shallow copy of D
4 changes: 2 additions & 2 deletions docs/api/pyptx.md
Original file line number Diff line number Diff line change
Expand Up @@ -516,14 +516,14 @@ Raises:

- Kind: `attribute`

- Value: `<built-in method cache_info of functools._lru_cache_wrapper object at 0x7f6bf6175bc0>`
- Value: `<built-in method cache_info of functools._lru_cache_wrapper object at 0x7f0f5ec75bc0>`

Report cache statistics

#### `cache_clear`

- Kind: `attribute`

- Value: `<built-in method cache_clear of functools._lru_cache_wrapper object at 0x7f6bf6175bc0>`
- Value: `<built-in method cache_clear of functools._lru_cache_wrapper object at 0x7f0f5ec75bc0>`

Clear the cache and cache statistics
115 changes: 115 additions & 0 deletions docs/examples/ampere/flash_norm.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Ampere / Flash Norm

[:material-github: View on GitHub](https://github.com/patrick-toulme/pyptx/blob/dev/examples/ampere/flash_norm.py){ .md-button }
[:material-file-code: `examples/ampere/flash_norm.py`](https://github.com/patrick-toulme/pyptx/blob/dev/examples/ampere/flash_norm.py){ .md-button }

## Overview

Fused FlashNorm for Ampere (sm_80), written in pyptx, callable from JAX
and PyTorch.

Reaches **1.3 TB/s** at B=2048 N=8192 f32 on A100, **1.02x** faster than
the equivalent pyptx rms_norm by eliminating the per-element weight load
and multiply.

``Y[b, i] = X[b, i] / sqrt(mean(X[b, :]^2) + eps)``

Thin arch wrapper around ``examples/hopper/flash_norm.py`` — same kernel,
compiled for ``sm_80``. All instructions it relies on
(``ld.global.v4.f32``, ``fma.rn.f32``, ``rsqrt.approx.f32``,
``shfl.sync.bfly.b32``, ``bar.sync``) are supported on ``sm_80``.

Run ``python examples/ampere/flash_norm.py`` to execute both a ``jax.jit``
path and a PyTorch eager path.

## Source

??? example "Full source"

```python
"""Fused FlashNorm for Ampere (sm_80), written in pyptx, callable from JAX
and PyTorch.

Reaches **1.3 TB/s** at B=2048 N=8192 f32 on A100, **1.02x** faster than
the equivalent pyptx rms_norm by eliminating the per-element weight load
and multiply.

``Y[b, i] = X[b, i] / sqrt(mean(X[b, :]^2) + eps)``

Thin arch wrapper around ``examples/hopper/flash_norm.py`` — same kernel,
compiled for ``sm_80``. All instructions (``ld.global.v4.f32``,
``fma.rn.f32``, ``rsqrt.approx.f32``, ``shfl.sync.bfly.b32``,
``bar.sync``) are available since sm_80.

Run ``python examples/ampere/flash_norm.py`` to execute both a ``jax.jit``
path and a PyTorch eager path.
"""
from __future__ import annotations

import os

os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import jax
import jax.numpy as jnp
import numpy as np

try:
from pyptx.examples.hopper.flash_norm import build_flash_norm as _build_flash_norm
from pyptx.examples.hopper.flash_norm import flash_norm_ref
except ImportError:
from examples.hopper.flash_norm import build_flash_norm as _build_flash_norm
from examples.hopper.flash_norm import flash_norm_ref


def build_flash_norm(B: int, N: int, D: int, *, eps: float = 1e-5):
    """Build the flash-norm kernel targeting ``sm_80``.

    Forwards ``B``, ``N``, ``D`` and ``eps`` unchanged to the Hopper
    ``build_flash_norm`` builder and pins ``arch`` to ``"sm_80"``; whatever
    that builder returns is returned as-is.
    """
    target_arch = "sm_80"
    return _build_flash_norm(B, N, D, eps=eps, arch=target_arch)


def _run_jax_case(B: int, N: int) -> None:
    """Run one (B, N) flash-norm case under ``jax.jit`` and print the result.

    The input is a reproducible float32 array, seeded from the shape and
    scaled by 0.3. The jitted kernel output is compared with
    ``flash_norm_ref`` and a single status line is printed with the maximum
    absolute difference.
    """
    _, flash = build_flash_norm(B, N, N)

    # Shape-derived seed keeps every case reproducible.
    np.random.seed(B * 7919 + N)
    host_input = np.random.randn(B, N).astype(np.float32) * 0.3
    x = jnp.asarray(host_input)

    jitted = jax.jit(lambda arr: flash(arr))

    got = np.asarray(jitted(x))
    want = np.asarray(flash_norm_ref(x))

    max_abs = float(np.abs(got - want).max())
    if bool(np.allclose(got, want, atol=1e-4, rtol=1e-3)):
        status = "OK "
    else:
        status = "FAIL"
    print(f"[JAX {status}] B={B:4d} N={N:5d} max_abs={max_abs:.3e}")


def _run_torch_case(B: int, N: int) -> None:
    """Run one (B, N) flash-norm case through PyTorch and print the result.

    The input is a reproducible float32 array, seeded from the shape and
    scaled by 0.3, placed on the CUDA device. The kernel output is compared
    with an inline reference, ``x * rsqrt(mean(x*x, dim=-1) + 1e-5)``, and a
    single status line is printed with the maximum absolute difference.
    """
    import torch

    _, flash = build_flash_norm(B, N, N)

    # Shape-derived seed keeps every case reproducible.
    np.random.seed(B * 7919 + N)
    host_input = np.random.randn(B, N).astype(np.float32) * 0.3
    x = torch.tensor(host_input, device="cuda")

    got = flash(x)
    torch.cuda.synchronize()

    # Inline reference; 1e-5 equals the default `eps` of build_flash_norm.
    mean_sq = (x * x).mean(dim=-1, keepdim=True)
    want = x * torch.rsqrt(mean_sq + 1e-5)

    max_abs = float((got - want).abs().max())
    if bool(torch.allclose(got, want, atol=1e-4, rtol=1e-3)):
        status = "OK "
    else:
        status = "FAIL"
    print(f"[Torch{status}] B={B:4d} N={N:5d} max_abs={max_abs:.3e}")


def main() -> None:
    """Run the JAX and PyTorch correctness cases over a fixed set of shapes."""
    # Run one tiny JAX op to completion first — presumably to initialise the
    # JAX backend before any kernel is launched (TODO confirm intent).
    warmup = jnp.ones((4,), dtype=jnp.float32) + 1
    _ = warmup.block_until_ready()

    shapes = [(4, 64), (16, 512), (32, 1024), (128, 2048), (256, 4096)]
    for batch, width in shapes:
        _run_jax_case(batch, width)
        _run_torch_case(batch, width)


if __name__ == "__main__":
main()
```
Loading