Skip to content

Commit fdfd4ba

Browse files
Merge pull request #5 from aymuos15/fix-blackwell-target
Drop a-suffix from Blackwell bw example arch targets
2 parents: bc3b0ca + f531bbf; commit: fdfd4ba

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

examples/blackwell/layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""Blackwell LayerNorm example using the maintained pyptx kernel path.
22
33
Run ``python examples/blackwell/layer_norm.py`` to execute both a ``jax.jit``
4-
path and a PyTorch eager path on ``sm_100a``.
4+
path and a PyTorch eager path on ``sm_100``.
55
"""
66
from __future__ import annotations
77

@@ -32,7 +32,7 @@ def _pick_rows_per_cta(B: int) -> int:
3232
def build_layer_norm(B: int, N: int, *, eps: float = 1e-5, rows_per_cta: int | None = None):
3333
if rows_per_cta is None:
3434
rows_per_cta = _pick_rows_per_cta(B)
35-
return _build_layer_norm(B, N, eps=eps, rows_per_cta=rows_per_cta, arch="sm_100a")
35+
return _build_layer_norm(B, N, eps=eps, rows_per_cta=rows_per_cta, arch="sm_100")
3636

3737

3838
def _run_jax_case(B: int, N: int) -> None:

examples/blackwell/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""Blackwell RMSNorm example using the maintained pyptx kernel path.
22
33
Run ``python examples/blackwell/rms_norm.py`` to execute both a ``jax.jit``
4-
path and a PyTorch eager path on ``sm_100a``.
4+
path and a PyTorch eager path on ``sm_100``.
55
"""
66
from __future__ import annotations
77

@@ -24,7 +24,7 @@
2424

2525

2626
def build_rms_norm(B: int, N: int, *, eps: float = 1e-6):
27-
return _build_rms_norm(B, N, eps=eps, arch="sm_100a")
27+
return _build_rms_norm(B, N, eps=eps, arch="sm_100")
2828

2929

3030
def _run_jax_case(B: int, N: int) -> None:

examples/blackwell/swiglu.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""Blackwell SwiGLU example using the maintained pyptx kernel path.
22
33
Run ``python examples/blackwell/swiglu.py`` to execute both a ``jax.jit``
4-
path and a PyTorch eager path on ``sm_100a``.
4+
path and a PyTorch eager path on ``sm_100``.
55
"""
66
from __future__ import annotations
77

@@ -32,7 +32,7 @@ def _pick_rows_per_cta(M: int) -> int:
3232
def build_fused_silu_mul(M: int, F: int, *, rows_per_cta: int | None = None):
3333
if rows_per_cta is None:
3434
rows_per_cta = _pick_rows_per_cta(M)
35-
return _build_fused_silu_mul(M, F, rows_per_cta=rows_per_cta, arch="sm_100a")
35+
return _build_fused_silu_mul(M, F, rows_per_cta=rows_per_cta, arch="sm_100")
3636

3737

3838
def _run_jax_case(M: int, F: int) -> None:

0 commit comments

Comments
 (0)