Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions examples/jagged_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,54 @@ def jagged_layer_norm_kernel(
return out.reshape(total_L, M)


@helion.kernel(
backend="pallas",
config=helion.Config(
block_sizes=[8, 16, 8, 16, 8, 16], pallas_loop_type="emit_pipeline"
),
)
def jagged_layer_norm_kernel_pallas(
x_values: torch.Tensor,
x_offsets: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
"""Pallas/TPU variant of :func:`jagged_layer_norm_kernel`.

Norms jointly over the jagged rows and features via ``hl.tile(s, e)``. The
joint reduce is split into two single-axis reductions so it lowers on TPU.
"""
M = x_values.size(1)
num_rows = x_offsets.size(0) - 1
out = torch.empty_like(x_values)

for g in hl.grid(num_rows):
s = x_offsets[g]
e = x_offsets[g + 1]
count = ((e - s) * M).to(torch.float32)

sum_acc = hl.zeros([1], dtype=torch.float32)
for tile_m in hl.tile(M):
for st in hl.tile(s, e):
part = x_values[st, tile_m].to(torch.float32).sum(dim=0)
sum_acc = sum_acc + part.sum(dim=0)
mean = sum_acc / count

var_acc = hl.zeros([1], dtype=torch.float32)
for tile_m in hl.tile(M):
for st in hl.tile(s, e):
centered = x_values[st, tile_m].to(torch.float32) - mean
part = (centered * centered).sum(dim=0)
var_acc = var_acc + part.sum(dim=0)
rstd = torch.rsqrt(var_acc / count + eps)

for tile_m in hl.tile(M):
for st in hl.tile(s, e):
normalized = (x_values[st, tile_m].to(torch.float32) - mean) * rstd
out[st, tile_m] = normalized.to(x_values.dtype)

return out


# %%
# Reference Implementation
# ------------------------
Expand Down
36 changes: 36 additions & 0 deletions examples/jagged_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,42 @@ def jagged_mean_kernel(
return out


@helion.kernel(
backend="pallas",
config=helion.Config(block_sizes=[8, 16], pallas_loop_type="emit_pipeline"),
)
def jagged_mean_kernel_pallas(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
x_feature_counts: torch.Tensor,
max_M: int,
) -> torch.Tensor:
"""Pallas/TPU variant of :func:`jagged_mean_kernel`.

Reduces the jagged row range ``hl.tile(s, e)`` on the 2-D tensor; the
feature count masks the output. The lane extent comes from ``x_data.size(1)``
so the load stays in the pipelined DMA.
"""
num_rows = x_offsets.size(0) - 1
M = x_data.size(1)
out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)

for g in hl.grid(num_rows):
s = x_offsets[g]
e = x_offsets[g + 1]
nnz = (e - s).to(torch.float32)
fcount = x_feature_counts[g]
for tile_m in hl.tile(M):
acc = hl.zeros([tile_m], dtype=torch.float32)
for st in hl.tile(s, e):
acc = acc + x_data[st, tile_m].to(torch.float32).sum(dim=0)
mean = torch.where(nnz > 0, acc / nnz, 0.0)
valid = tile_m.index < fcount
out[g, tile_m] = torch.where(valid, mean, 0.0).to(x_data.dtype)

return out


# %%
# Reference Implementation
# ------------------------
Expand Down
38 changes: 38 additions & 0 deletions examples/jagged_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,44 @@ def jagged_softmax_kernel(
return out.reshape(N, M)


@helion.kernel(
backend="pallas",
config=helion.Config(block_sizes=[8, 16, 16, 16], pallas_loop_type="emit_pipeline"),
)
def jagged_softmax_kernel_pallas(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""Pallas/TPU variant of :func:`jagged_softmax_kernel`.

Softmax over the jagged row axis per column, in three passes over
``hl.tile(s, e)`` (max, sum-exp, normalized write) instead of the flattened
gather.
"""
M = x_data.size(1)
num_rows = x_offsets.size(0) - 1
out = torch.empty_like(x_data)

for g in hl.grid(num_rows):
s = x_offsets[g]
e = x_offsets[g + 1]
for tile_m in hl.tile(M):
row_max = hl.full([tile_m], float("-inf"), dtype=torch.float32)
for st in hl.tile(s, e):
row_max = torch.maximum(
row_max, x_data[st, tile_m].to(torch.float32).amax(dim=0)
)
denom = hl.zeros([tile_m], dtype=torch.float32)
for st in hl.tile(s, e):
shifted = x_data[st, tile_m].to(torch.float32) - row_max[None, :]
denom = denom + torch.exp(shifted).sum(dim=0)
for st in hl.tile(s, e):
shifted = x_data[st, tile_m].to(torch.float32) - row_max[None, :]
out[st, tile_m] = (torch.exp(shifted) / denom[None, :]).to(x_data.dtype)

return out


# %%
# Benchmark Wrapper
# -----------------
Expand Down
29 changes: 29 additions & 0 deletions examples/jagged_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,35 @@ def jagged_sum_kernel(
return out


@helion.kernel(
backend="pallas",
config=helion.Config(block_sizes=[8, 16], pallas_loop_type="emit_pipeline"),
)
def jagged_sum_kernel_pallas(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""Pallas/TPU variant of :func:`jagged_sum_kernel`.

Reduces the jagged row range ``hl.tile(s, e)`` on the 2-D tensor instead of
the flattened gather, which TPU has no cheap lowering for.
"""
num_rows = x_offsets.size(0) - 1
M = x_data.size(1)
out = torch.zeros([num_rows, M], dtype=x_data.dtype, device=x_data.device)

for g in hl.grid(num_rows):
s = x_offsets[g]
e = x_offsets[g + 1]
for tile_m in hl.tile(M):
acc = hl.zeros([tile_m], dtype=torch.float32)
for st in hl.tile(s, e):
acc = acc + x_data[st, tile_m].to(torch.float32).sum(dim=0)
out[g, tile_m] = acc.to(x_data.dtype)

return out


# %%
# Reference Implementation
# ------------------------
Expand Down
104 changes: 64 additions & 40 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,16 +1101,16 @@ def test_long_sum_manual_non_divisible(self):
block_sizes=[32768, 1],
)

@xfailIfPallas("JAX tracer error with dynamic shapes")
@xfailIfPallasInterpret("currently failing with the interpreter")
@skipIfRefEager("hl.jagged_tile does not support ref mode yet")
def test_jagged_mean(self):
num_rows, max_cols = 32, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE),
torch.cumsum(lengths, dim=0).to(LONG_INT_TYPE),
]
)
nnz = int(x_offsets[-1])
Expand All @@ -1125,13 +1125,18 @@ def test_jagged_mean(self):
x_data, x_offsets, feature_counts, M
)

check_example(
"jagged_mean",
args,
expected,
fn_name="jagged_mean_kernel",
block_sizes=[16, 8, 16],
)
if _get_backend() == "pallas":
check_example(
"jagged_mean", args, expected, fn_name="jagged_mean_kernel_pallas"
)
else:
check_example(
"jagged_mean",
args,
expected,
fn_name="jagged_mean_kernel",
block_sizes=[16, 8, 16],
)

@xfailIfPallas("requires triton module")
@skipIfRefEager(
Expand Down Expand Up @@ -1387,16 +1392,16 @@ def test_layernorm_without_bias(self):
num_stages=3,
)

@xfailIfPallas("JAX tracer error with dynamic shapes")
@xfailIfPallasInterpret("currently failing with the interpreter")
@skipIfRefEager("hl.jagged_tile does not support ref mode yet")
def test_jagged_softmax(self):
num_rows, max_cols = 128, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE),
torch.cumsum(lengths, dim=0).to(LONG_INT_TYPE),
]
)
nnz = int(x_offsets[-1])
Expand All @@ -1407,13 +1412,18 @@ def test_jagged_softmax(self):
mod = import_path(EXAMPLES_DIR / "jagged_softmax.py")
expected = mod.reference_jagged_softmax_pytorch(x_data, x_offsets)

check_example(
"jagged_softmax",
args,
expected,
fn_name="jagged_softmax_kernel",
block_sizes=[16, 8, 16, 16],
)
if _get_backend() == "pallas":
check_example(
"jagged_softmax", args, expected, fn_name="jagged_softmax_kernel_pallas"
)
else:
check_example(
"jagged_softmax",
args,
expected,
fn_name="jagged_softmax_kernel",
block_sizes=[16, 8, 16, 16],
)

@xfailIfPallas("tensor-derived if-predicates not supported")
@skipIfXPU("Jagged tensor operations not fully supported on XPU")
Expand Down Expand Up @@ -1888,16 +1898,16 @@ def test_nvfp4_gemv(self):
rtol=2e-1,
)

@xfailIfPallas("JAX tracer error")
@xfailIfPallasInterpret("currently failing with the interpreter")
@skipIfRefEager("hl.jagged_tile does not support ref mode yet")
def test_jagged_sum(self):
num_rows, max_cols = 128, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE),
torch.cumsum(lengths, dim=0).to(LONG_INT_TYPE),
]
)
nnz = int(x_offsets[-1])
Expand All @@ -1908,13 +1918,19 @@ def test_jagged_sum(self):
mod = import_path(EXAMPLES_DIR / "jagged_sum.py")
expected = mod.reference_jagged_sum_kernel_pytorch(x_data, x_offsets)

check_example(
"jagged_sum",
args,
expected,
fn_name="jagged_sum_kernel",
block_sizes=[16, 8, 16],
)
if _get_backend() == "pallas":
# Structured hl.tile(s, e) rewrite; config pinned on the kernel.
check_example(
"jagged_sum", args, expected, fn_name="jagged_sum_kernel_pallas"
)
else:
check_example(
"jagged_sum",
args,
expected,
fn_name="jagged_sum_kernel",
block_sizes=[16, 8, 16],
)

@skipIfXPU("Timeout on XPU")
def test_fused_linear_jsd(self):
Expand Down Expand Up @@ -2005,16 +2021,16 @@ def test_fused_linear_jsd_fwd(self):
)
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)

@xfailIfPallas("JAX tracer error")
@xfailIfPallasInterpret("currently failing with the interpreter")
@skipIfRefEager("hl.jagged_tile does not support ref mode yet")
def test_jagged_layer_norm(self):
num_rows, max_cols = 128, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE),
torch.cumsum(lengths, dim=0).to(LONG_INT_TYPE),
]
)
nnz = int(x_offsets[-1])
Expand All @@ -2026,13 +2042,21 @@ def test_jagged_layer_norm(self):
mod = import_path(EXAMPLES_DIR / "jagged_layer_norm.py")
expected = mod.reference_jagged_layer_norm_pytorch(x_data, x_offsets, eps)

check_example(
"jagged_layer_norm",
args,
expected,
fn_name="jagged_layer_norm_kernel",
block_sizes=[4, 8, 8, 8, 8, 8, 8],
)
if _get_backend() == "pallas":
check_example(
"jagged_layer_norm",
args,
expected,
fn_name="jagged_layer_norm_kernel_pallas",
)
else:
check_example(
"jagged_layer_norm",
args,
expected,
fn_name="jagged_layer_norm_kernel",
block_sizes=[4, 8, 8, 8, 8, 8, 8],
)

def test_exp_fwd(self):
x = torch.randn([1024], device=DEVICE, dtype=torch.bfloat16)
Expand Down