Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions helion/_compiler/cute/cute_fx_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,21 @@ def aux_tensor_load_kind(
row 0 of the aux to every output row. The splice site builds
the same stride-``(0, 1)`` 2-D view as the trailing-axis
rowvec form (row 0 reused across M).
- ``("broadcast", 2)``: a per-row column-vector broadcast aux
load (``scale_a[tile_m, tile_n]`` where ``scale_a`` is a full
``(M, N)`` view with **stride 0 on the trailing (N) axis** — an
``unsqueeze(1).expand(M, N)`` of a per-row ``(M,)`` vector). The
underlying tensor's rank is 2 with the carrier's global shape,
the load result shape equals the carrier tile shape, and the
load's two index symbols are exactly ``carrier_tile_index_nodes``
in order. Its value depends only on ``m``, so it is uniform over
each thread's N fragment in the T2R epilogue: the splice reads it
as a single **scalar** per subtile (``tTR_gAux[(0,0,0,s)]``)
rather than a vector ``.load()``, avoiding a redundant N-wide
read. Tried before ``("exact", None)`` because a stride-0-N aux
still has the full ``(M, N)`` underlying shape; a genuine 2-D
residual has a non-zero trailing stride and falls through to the
exact matcher.

The classifier returns ``None`` for everything else: 3-D
underlying tensors with a static collapse
Expand Down Expand Up @@ -310,6 +325,18 @@ def aux_tensor_load_kind(
if len(carrier_tile_shape) != 2:
return None

colvec = _matches_colvec_broadcast(
aux_shape=aux_shape,
aux_tensor_shape=aux_tensor_shape,
aux_tensor_val=aux_tensor_val,
index_list=index_list,
carrier_tile_shape=carrier_tile_shape,
carrier_tile_index_nodes=carrier_tile_index_nodes,
carrier_global_shape=carrier_global_shape,
)
if colvec is not None:
return colvec

if _matches_exact(
aux_shape=aux_shape,
aux_tensor_shape=aux_tensor_shape,
Expand Down Expand Up @@ -490,6 +517,61 @@ def _matches_leading_broadcast(
return ("broadcast", 0)


def _matches_colvec_broadcast(
*,
aux_shape: tuple[object, ...],
aux_tensor_shape: tuple[object, ...],
aux_tensor_val: torch.Tensor,
index_list: Sequence[object],
carrier_tile_shape: tuple[object, ...],
carrier_tile_index_nodes: tuple[torch.fx.Node, ...] | None,
carrier_global_shape: tuple[object, ...] | None,
) -> tuple[str, int] | None:
"""Check whether a load is a per-row column-vector broadcast and
return ``("broadcast", 2)`` on success.

The accepted form is ``scale_a[tile_m, tile_n]`` where ``scale_a`` is a
full ``(M, N)`` view with **stride 0 on the trailing (N) axis** — i.e. an
``unsqueeze(1).expand(M, N)`` of a per-row ``(M,)`` vector. Its value
depends only on ``m``, so it is *uniform over each thread's N fragment*
in the tcgen05 T2R epilogue. The splice therefore reads it as a single
**scalar** per subtile (matching CUTLASS's ``sa = tTR_gSA[(0,0,0,s)]``)
rather than a vector ``.load()``, avoiding a redundant N-wide read.

This must be tried *before* ``_matches_exact`` (a stride-0-N aux still
has the full ``(M, N)`` underlying shape, so the exact matcher would
otherwise claim it and emit a vector read). A genuine 2-D residual has a
non-zero trailing stride and falls through to ``_matches_exact``.
"""
if (
len(aux_tensor_shape) != 2
or len(aux_shape) != 2
or len(index_list) != 2
or len(carrier_tile_shape) != 2
):
return None
if carrier_tile_index_nodes is None or len(carrier_tile_index_nodes) != 2:
return None
for idx, expected in zip(index_list, carrier_tile_index_nodes, strict=True):
if idx is not expected:
return None
for aux_dim, carrier_dim in zip(aux_shape, carrier_tile_shape, strict=True):
if aux_dim != carrier_dim:
return None
if carrier_global_shape is not None:
if len(carrier_global_shape) != 2:
return None
if tuple(aux_tensor_shape) != tuple(carrier_global_shape):
return None
stride = aux_tensor_val.stride()
# Uniform over N (trailing stride 0) and varying over M (leading stride
# non-zero) — the column-vector broadcast. A real (M, N) tensor has
# trailing stride 1 and is left for the exact-shape matcher.
if len(stride) != 2 or stride[1] != 0 or stride[0] == 0:
return None
return ("broadcast", 2)


def _matches_exact(
*,
aux_shape: tuple[object, ...],
Expand Down
29 changes: 25 additions & 4 deletions helion/language/memory_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2647,9 +2647,12 @@ def _codegen_cute_store_tcgen05_tile(
# Broadcast aux steps need a fresh AST var for the 2-D view
# of the rank-1 underlying tensor (stride 0 on the orthogonal
# axis). Exact-shape aux steps leave ``aux_view2d`` as None.
# broadcast_axis 0/1 build a stride-0 2-D view of a rank-1 tensor;
# the colvec form (2) reuses the exact-shape pipeline over its own
# (M, N) stride-(1,0) view, so it needs no separate ``aux_view2d``.
aux_view2d = (
df.new_var(f"tcgen05_aux_view2d_{aux_idx}")
if aux_step.broadcast_axis is not None
if aux_step.broadcast_axis in (0, 1)
else None
)
aux_step_records.append(
Expand Down Expand Up @@ -2826,6 +2829,10 @@ def _codegen_cute_store_tcgen05_tile(
aux_pipeline_uses_tma_load = False
aux_ring_smem_names = tuple(None for _ in aux_step_records)

# Row-vector aux (``bias[n]`` / rowwise ``scale_b[n]``) reads stay
# per-subtile (the generic ``ttr_aux_subtile.load()`` path below, placed
# after the c_pipeline acquire / acc ``consumer_wait`` / T2R prefix per the
# cycle-69 placement).
rowvec_aux_stage_records: list[_RowvecAuxStageRecord | None] = []
for aux_idx, rec in enumerate(aux_step_records):
copy_bits = 128
Expand Down Expand Up @@ -3102,9 +3109,11 @@ def _aux_tile_setup_lines(
)
continue

if rec.broadcast_axis is None:
# Exact-shape rank-2 aux: slice the per-tile region
# of the underlying 2-D tensor directly.
if rec.broadcast_axis is None or rec.broadcast_axis == 2:
# Exact-shape rank-2 aux (or the colvec form, which is a full
# (M, N) stride-(1,0) view): slice the per-tile region of the
# underlying 2-D tensor directly. The colvec's per-subtile read
# is specialized to a scalar in ``_aux_subtile_load_source``.
source_for_local_tile = rec.aux_tensor_name
aux_tile_is_local = False
elif rowvec_stage is not None:
Expand Down Expand Up @@ -3407,6 +3416,18 @@ def _aux_subtile_load_source(
f"{prelude_indent}{rec.aux_loaded} = {rec.aux_rmem}.load()\n"
)
continue
if rec.broadcast_axis == 2:
# Column-vector (per-row) aux: uniform over each thread's N
# fragment, so read a single SCALAR per subtile (T2R index
# (0,0,0)) instead of a redundant N-wide vector ``.load()``.
# Matches CUTLASS's ``sa = tTR_gSA[(0,0,0,subtile)]``; the
# scalar broadcasts in the ``acc * aux`` chain multiply.
lines.append(
f"{prelude_indent}{rec.aux_loaded} = "
f"{rec.ttr_aux_grouped}"
f"[(0, 0, 0, cutlass.Int32(_tcgen05_subtile))]\n"
)
continue
lines.extend(
[
(
Expand Down
191 changes: 191 additions & 0 deletions test/test_cute_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13827,6 +13827,197 @@ def test_trace_mma_to_store_dtype_unknown_graph_returns_none(self) -> None:
)
self.assertIsNone(_trace_mma_to_store_dtype(mma_node, []))

def _build_aux_load_node(
self,
*,
aux_tensor: torch.Tensor,
load_shape: tuple[int, ...],
index_nodes: tuple[torch.fx.Node, ...],
extra_mask: object = None,
eviction_policy: object = None,
kwargs: dict[str, object] | None = None,
) -> tuple[torch.fx.Node, tuple[torch.fx.Node, ...]]:
"""Build a synthetic ``helion.language.memory_ops.load`` FX node for
the aux-tensor classifier (``aux_tensor_load_kind``).

``aux_tensor`` is the underlying tensor whose ``.shape`` / ``.stride()``
drive classification (use ``.expand(...)`` to get the stride-0
broadcast axes). ``load_shape`` is the per-tile load result shape.
``index_nodes`` is the index list passed to ``load`` — pass the same
carrier tile-id nodes to mimic ``aux[tile_m, tile_n]``. Returns the
load node and the carrier tile-id index nodes.
"""
from helion.language import memory_ops

graph = Graph()
tensor_node = graph.call_function(_tracing_ops._new_var, args=())
tensor_node.meta["val"] = aux_tensor
args: tuple[object, ...] = (
tensor_node,
list(index_nodes),
extra_mask,
eviction_policy,
)
load_node = graph.call_function(memory_ops.load, args=args, kwargs=kwargs or {})
load_node.meta["val"] = torch.empty(load_shape, dtype=aux_tensor.dtype)
return load_node, index_nodes

def _carrier_index_nodes(self) -> tuple[torch.fx.Node, ...]:
"""Two distinct FX nodes standing in for the carrier's
``(tile_m, tile_n)`` tile-id symbols."""
g = Graph()
m = g.call_function(_tracing_ops._new_var, args=())
n = g.call_function(_tracing_ops._new_var, args=())
return (m, n)

def test_aux_load_kind_colvec_stride_1_0_is_broadcast_2(self) -> None:
"""A full ``(M, N)`` aux with trailing stride 0 (the
``unsqueeze(1).expand(M, N)`` per-row column vector) classifies as
``("broadcast", 2)`` so the epilogue reads it as a scalar per
subtile instead of a redundant N-wide vector load."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
colvec = torch.empty(4, dtype=torch.float32).unsqueeze(1).expand(4, 4)
self.assertEqual(colvec.stride(), (1, 0))
load_node, _ = self._build_aux_load_node(
aux_tensor=colvec,
load_shape=(4, 4),
index_nodes=idx,
)
self.assertEqual(
aux_tensor_load_kind(
load_node,
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(4, 4),
),
("broadcast", 2),
)

def test_aux_load_kind_exact_tensor_is_not_colvec(self) -> None:
"""A genuine dense ``(M, N)`` aux (trailing stride 1) must fall
through the colvec matcher to ``("exact", None)`` — the colvec
path only claims trailing-stride-0 views."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
exact = torch.empty(4, 4, dtype=torch.float32)
self.assertEqual(exact.stride(), (4, 1))
load_node, _ = self._build_aux_load_node(
aux_tensor=exact,
load_shape=(4, 4),
index_nodes=idx,
)
self.assertEqual(
aux_tensor_load_kind(
load_node,
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(4, 4),
),
("exact", None),
)

def test_aux_load_kind_rowvec_1n_is_not_colvec(self) -> None:
"""The explicit ``(1, N)`` leading-broadcast row vector (unit
leading axis, contiguous trailing) is ``("broadcast", 0)``, never
the colvec form — colvec requires a full ``(M, N)`` view with
trailing stride 0 and non-zero leading stride."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
# Underlying ``(1, N)`` tensor; the per-tile load result is (M, N).
rowvec = torch.empty(1, 4, dtype=torch.float32)
self.assertEqual(rowvec.stride(), (4, 1))
load_node, _ = self._build_aux_load_node(
aux_tensor=rowvec,
load_shape=(4, 4),
index_nodes=idx,
)
self.assertEqual(
aux_tensor_load_kind(
load_node,
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(4, 4),
),
("broadcast", 0),
)

def test_aux_load_kind_colvec_rejected_when_index_order_mismatches(
self,
) -> None:
"""The colvec matcher requires the load index list to be the
carrier tile-id nodes in order; a swapped/foreign index must not
be claimed as ``("broadcast", 2)``."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
colvec = torch.empty(4, dtype=torch.float32).unsqueeze(1).expand(4, 4)
# Swap the index order so it no longer matches the carrier nodes.
load_node, _ = self._build_aux_load_node(
aux_tensor=colvec,
load_shape=(4, 4),
index_nodes=(idx[1], idx[0]),
)
self.assertNotEqual(
aux_tensor_load_kind(
load_node,
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(4, 4),
),
("broadcast", 2),
)

def test_aux_load_kind_colvec_rejected_when_global_shape_mismatches(
self,
) -> None:
"""When the underlying ``(M, N)`` view does not match the carrier's
global output shape, the colvec matcher bails (returns ``None``
rather than ``("broadcast", 2)``)."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
colvec = torch.empty(4, dtype=torch.float32).unsqueeze(1).expand(4, 4)
result = aux_tensor_load_kind(
self._build_aux_load_node(
aux_tensor=colvec,
load_shape=(4, 4),
index_nodes=idx,
)[0],
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(8, 8), # mismatched global shape
)
self.assertNotEqual(result, ("broadcast", 2))

def test_aux_load_kind_colvec_rejected_with_extra_mask(self) -> None:
"""A present ``extra_mask`` arg disqualifies the load entirely
(the splice emits a plain ``.load()`` with no mask), so even a
colvec-shaped aux returns ``None``."""
from helion._compiler.cute.cute_fx_walk import aux_tensor_load_kind

idx = self._carrier_index_nodes()
colvec = torch.empty(4, dtype=torch.float32).unsqueeze(1).expand(4, 4)
mask_graph = Graph()
mask_node = mask_graph.call_function(_tracing_ops._new_var, args=())
load_node, _ = self._build_aux_load_node(
aux_tensor=colvec,
load_shape=(4, 4),
index_nodes=idx,
extra_mask=mask_node,
)
self.assertIsNone(
aux_tensor_load_kind(
load_node,
carrier_tile_shape=(4, 4),
carrier_tile_index_nodes=idx,
carrier_global_shape=(4, 4),
)
)

def test_emit_sched_pipeline_setup_round_trips_pipeline_async(self) -> None:
"""``_emit_sched_pipeline_setup`` emits the
``cutlass.pipeline.PipelineAsync.create`` wrapper used to
Expand Down
Loading
Loading