Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions examples/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""
Tensor Concatenation Example
============================
Tensor Concatenation Examples
=============================

This example demonstrates how to implement a tensor concatenation operation using Helion.
It shows two approaches: a simple version that tiles only the rows and uses slices for
each source tensor, and a masked version that tiles both dimensions using ``hl.load``
with ``extra_mask``.
"""

# %%
Expand All @@ -20,22 +22,64 @@
import helion.language as hl

# %%
# Concatenation Kernel
# --------------------
# Simple Concatenation Kernel
# ---------------------------
# Tiles only the row dimension and uses slices to copy each source tensor
# into the corresponding region of the output. This produces two separate
# load/store pairs with no manual masking.


# %%
@helion.kernel()
def concat2d_dim1_simple(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Join two 2D tensors side by side (along dimension 1) using two sliced copies.

    Only the row dimension is tiled. For every row tile, ``x`` is written to
    the leading ``N1`` columns of the output and ``y`` to the columns after
    them, so the kernel body contains no hand-written mask.

    Args:
        x: Left input tensor of shape [M, N1]
        y: Right input tensor of shape [M, N2]; must have as many rows as x

    Returns:
        Newly allocated tensor of shape [M, N1+N2] with the dtype and device of x
    """
    num_rows = x.size(0)
    assert num_rows == y.size(0)
    x_cols = x.size(1)
    total_cols = x_cols + y.size(1)
    result = torch.empty([num_rows, total_cols], dtype=x.dtype, device=x.device)
    for row_tile in hl.tile(num_rows):
        # Left region: columns [0, x_cols) are copied from x.
        result[row_tile, :x_cols] = x[row_tile, :]
        # Right region: columns [x_cols, total_cols) are copied from y.
        result[row_tile, x_cols:] = y[row_tile, :]
    return result


# %%
# Masked Concatenation Kernel
# ----------------------------
# Tiles both dimensions of the output. Because a single tile along the
# column dimension can span both source tensors, ``hl.load`` with
# ``extra_mask`` is used to selectively load from each source, and
# ``torch.where`` merges the results.


# %%
@helion.kernel()
def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Concatenates two 2D tensors along dimension 1 (columns).
Concatenates two 2D tensors along dimension 1 using masked loads.

Tiles both dimensions and uses ``hl.load`` with ``extra_mask`` to
handle tiles that straddle the boundary between the two source tensors.

Args:
x: First input tensor of shape [M, N1]
y: Second input tensor of shape [M, N2] with same first dimension as x

Returns:
Output tensor of shape [M, N1+N2] containing the concatenation of x and y along dimension 1
Output tensor of shape [M, N1+N2]
"""
assert x.size(0) == y.size(0)
out = torch.empty(
Expand Down Expand Up @@ -70,7 +114,11 @@ def main() -> None:
"""
x = torch.randn([1500, 400], device=DEVICE)
y = torch.randn([1500, 600], device=DEVICE)
run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y))
kernels = {
"simple": concat2d_dim1_simple,
"masked": concat2d_dim1,
}
run_example(kernels, lambda x, y: torch.cat([x, y], dim=1), (x, y))


if __name__ == "__main__":
Expand Down
12 changes: 12 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,18 @@ def test_xsa_near_zero_v(self):
block_sizes=[1, 64, 32],
)

def test_concat_simple(self):
    """Run the ``concat2d_dim1_simple`` kernel of the ``concatenate`` example.

    The two inputs share 512 rows but have different widths (500 and 512).
    The ``torch.cat`` result along dim 1 is handed to ``check_example``
    alongside the inputs (presumably as the reference output; confirm
    against the ``check_example`` signature).
    """
    lhs = torch.randn(512, 500, device=DEVICE)
    rhs = torch.randn(512, 512, device=DEVICE)
    inputs = (lhs, rhs)
    reference = torch.cat(inputs, dim=1)
    check_example(
        "concatenate",
        inputs,
        reference,
        fn_name="concat2d_dim1_simple",
    )

def test_concat(self):
args = (
torch.randn(512, 500, device=DEVICE),
Expand Down
Loading