Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .loop_dependency_checker import LoopDependencyChecker
from .matmul_utils import tensor_matmul_replacement
from .matmul_utils import torch_matmul_replacement
from .node_masking import defer_pallas_load_masks
from .node_masking import remove_unnecessary_masking
from .roll_reduction import ReductionRoller
from .source_location import current_location
Expand Down Expand Up @@ -2959,10 +2960,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
promote_cute_root_graph_host_tensors(device_ir.graphs, promotions)
for graph in device_ir.graphs:
prepare_graph_lowerings(graph.graph)
defer_load_masks = CompileEnvironment.current().backend.name == "pallas"
for graph in device_ir.graphs:
validate_host_tensor_usage(graph.graph)
add_tile_with_offset_metadata(graph)
remove_unnecessary_tile_index(graph.graph)
if defer_load_masks:
defer_pallas_load_masks(graph.graph)
remove_unnecessary_masking(graph.graph)

# TODO(hinriksnaer): extract into a separate step? everything below
Expand Down
190 changes: 190 additions & 0 deletions helion/_compiler/node_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@
ValueRangesAny = ValueRanges[Any]


# Relayout ops a load's out-of-bounds mask may be deferred *through*: the mask is
# re-materialized later, in the consumer's layout, by a downstream ``_mask_to``
# (see ``defer_pallas_load_masks``).
#
# Restricted to ops that permute tile axes WITHOUT regrouping the masked
# dimension's elements, so the masked dimension's set of valid/invalid lanes is
# preserved exactly (only its axis position changes). ``permute`` is the only
# such op needed today -- ``transpose``/``.T`` lower to it.
#
# Deliberately NOT included:
# * ``view``/``reshape``: even when the masked block id still appears exactly
# once in the output shape, a reshape can regroup elements so the new
# per-axis mask (``arange < extent``) selects different flat positions than
# the eager load mask did -- e.g. a ``[B, 2]`` tile with 3 valid rows
# reshaped to ``[2, B]`` (old invalid flat lanes 6,7 -> new mask zeroes 3,7,
# dropping valid data and leaking invalid data). "the dim survives once" is
# necessary but NOT sufficient; admitting these needs a stride/lane-set
# equivalence proof, not just a dim-count check.
# * ``expand``/``stack``/``gather``: pass the masked *value* through but can
# replicate or relocate padded lanes into valid ones.
# * ``squeeze``/``unsqueeze``/``alias``: safe in principle (no regrouping) but
# left out until they have their own deferral tests.
#
# IMPORTANT: every op here must be RANK-PRESERVING. ``defer_pallas_load_masks``
# relies on that: its profitability gate (masked axis is a major dim at the load
# but a last-two dim at the consumer) doubles as the old "a relayout actually
# moved the axis" check *only* because a same-shape direct ``_mask_to`` cannot put
# an axis in both positions at once. A rank-changing op (squeeze/unsqueeze/view/
# reshape) would break that, and an explicit relayout-crossed check would need to
# be reinstated.
_RELAYOUT_TARGETS = frozenset({torch.ops.aten.permute.default})


def mask_node_inputs(
node: torch.fx.Node,
other: float | bool = 0,
Expand Down Expand Up @@ -184,6 +217,163 @@ def recompute_masked_values(graph: torch.fx.Graph) -> None:
node.meta["masked_value"] = cached_masked_value(node)


def defer_pallas_load_masks(graph: torch.fx.Graph) -> None:
"""Defer a Pallas load's eager out-of-bounds mask to a downstream ``_mask_to``.

Pallas load codegen materializes a tile's out-of-bounds mask multiplicatively
in the load's own layout (``ref[idx] * mask``). When the loaded value is only
*relayouted* (an axis permutation; see ``_RELAYOUT_TARGETS``) and then consumed
by a dot or reduction, that consumer already inserts a ``_mask_to(x, 0)`` which
can re-materialize the same mask in the consumer's layout. The mask is dynamic
(``arange < extent``), so it cannot be elided even when logically all-true;
applying it in the pre-relayout layout therefore keeps a live op on the path
into the relayout.

For each load whose masked tile dim is provably re-masked downstream, we:

* record the deferred block ids on the load so Pallas load codegen skips the
eager mask for those dims, and
* mark the load's masked value unknown so ``remove_unnecessary_masking`` keeps
the downstream ``_mask_to`` (it is no longer redundant once the load is not
pre-masked).

Correctness rests on a single dataflow fact: *every* use of the load reaches a
``_mask_to(_, 0)`` crossing only ``_RELAYOUT_TARGETS`` ops, with the masked
dim still present as a standalone tile dim at each step. Because those ops
only permute axes (they do not regroup the masked dim's elements), the later
per-axis mask covers exactly the lanes the eager mask would have. The
standalone-dim check is necessary but not sufficient on its own -- the
correctness guarantee comes from restricting the crossed ops to pure axis
permutations (so e.g. ``reshape`` is excluded; see ``_RELAYOUT_TARGETS``). A
use that does not re-mask (store, elementwise, reduction without a mask) keeps
the eager load mask.

Profitability is a *positional* gate on top of that correctness proof. A mask
on an axis inside the last-two (sublane/lane) dims is a vectorized per-register
op, while a mask on a major (outer) axis is applied per outer row and is much
more work. So defer only when the masked axis is a major dim at the load and
the relayout carries it into the last-two dims at the consumer ``_mask_to``;
deferring in the reverse direction would move the mask onto the more expensive
axis, so it is not done. This gate also subsumes the "a relayout actually
moved the axis" requirement (see the loop body).

Pallas-only: Triton masks loads as real data (``tl.load(..., other=0)``), so
relayout never moves unmasked lanes and there is nothing to defer.
"""
from ..language.memory_ops import load as load_op
from .aten_lowering import passthrough_masked_value
from .compile_environment import CompileEnvironment

env = CompileEnvironment.current()

def dim_index(node: torch.fx.Node, block_id: int) -> int | None:
"""Index of ``block_id`` in ``node``'s value, or None unless it appears as
exactly one standalone dimension (this doubles as the survives-uniquely
check)."""
val = node.meta.get("val")
if not isinstance(val, torch.Tensor):
return None
hits = [
i
for i, size in enumerate(val.size())
if env.resolve_block_id(size) == block_id
]
return hits[0] if len(hits) == 1 else None

def is_major_dim(node: torch.fx.Node, block_id: int) -> bool:
# An outer dim, outside the last-two (sublane, lane) vreg tile, where a
# mask is applied per outer row rather than as a per-register op.
idx = dim_index(node, block_id)
return idx is not None and idx < node.meta["val"].ndim - 2

def is_last_two_dim(node: torch.fx.Node, block_id: int) -> bool:
# Inside the last-two (sublane/lane) vreg tile, where a mask is a
# vectorized per-register op.
idx = dim_index(node, block_id)
return idx is not None and idx >= node.meta["val"].ndim - 2

def is_relayout(node: torch.fx.Node) -> bool:
if node.op != "call_function" or node.target not in _RELAYOUT_TARGETS:
return False
lowering = node.meta.get("lowering")
return getattr(lowering, "masked_value_fn", None) is passthrough_masked_value

def is_remask(node: torch.fx.Node, src: torch.fx.Node, block_id: int) -> bool:
# A zero-fill ``_mask_to`` on ``src`` that re-masks ``block_id`` with that
# axis in the last-two dims (the profitable place to apply the mask).
# ``bool(...)``: ``node.args[1] == 0`` is typed as ``Argument`` (the fill is
# always a scalar here, but the static type is a union), so coerce to bool.
return bool(
node.op == "call_function"
and node.target is _mask_to
and node.args[0] is src
and node.args[1] == 0
and is_last_two_dim(node, block_id)
)

def all_uses_remask(
node: torch.fx.Node, block_id: int, memo: dict[torch.fx.Node, bool]
) -> bool:
cached = memo.get(node)
if cached is not None:
return cached
memo[node] = False # conservative guard against revisiting mid-walk
users = list(node.users)
result = bool(users)
for user in users:
if is_remask(user, node, block_id):
continue
if (
is_relayout(user)
and dim_index(user, block_id) is not None
and all_uses_remask(user, block_id, memo)
):
continue
result = False
break
memo[node] = result
return result

changed = False
for node in graph.find_nodes(op="call_function", target=load_op):
val = node.meta.get("val")
if not isinstance(val, torch.Tensor):
continue
candidates = {
block_id
for size in val.size()
if (block_id := env.resolve_block_id(size)) is not None
}
deferred: set[int] = set()
for block_id in candidates:
# Profitability gate: defer only when the masked axis is a major/outer
# dim at the load and a relayout carries it into the last-two
# (vreg-tile) dims at the consumer ``_mask_to``. A mask on a last-two
# axis is a per-register op; a mask on a major axis is applied per
# outer row, so this is the only direction that moves the mask onto a
# cheaper axis (the reverse would move it onto a more expensive one).
#
# This also subsumes the "must cross >=1 relayout" check: with a
# rank-preserving relayout set, a direct ``_mask_to`` shares the load's
# shape, so ``block_id`` cannot be both major at the load and last-two
# at the consumer (see the note on ``_RELAYOUT_TARGETS``).
if not is_major_dim(node, block_id):
continue
if not node.users:
continue
if all_uses_remask(node, block_id, {}):
deferred.add(block_id)
if deferred:
node.meta["pallas_deferred_mask_block_ids"] = frozenset(deferred)
node.meta["masked_value"] = None
changed = True

if changed:
# Drop stale masked-value caches that assumed the load was pre-masked, so
# the surviving ``_mask_to`` nodes are not wrongly judged redundant.
recompute_masked_values(graph)


def getitem_masked_value(
getitem_node: torch.fx.Node,
) -> float | bool | None:
Expand Down
9 changes: 7 additions & 2 deletions helion/_compiler/pallas/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _load_mask_expr(
indexing_patterns = _get_indexing_patterns(state, tensor)
env = CompileEnvironment.current()
output_sizes = [*output_val.size()]
# Dims whose mask has been deferred to a downstream ``_mask_to`` by
# ``defer_pallas_load_masks`` -- masked later in the consumer layout instead.
deferred = state.fx_node.meta.get("pallas_deferred_mask_block_ids") or frozenset()
mask_exprs: list[str] = []
dtype_str: str | None = None
out_dim = 0
Expand All @@ -100,8 +103,10 @@ def _load_mask_expr(
# always valid, and applying a block-sized mask would broadcast
# the dim from 1 to block_size, causing shape mismatches.
dim_size = tensor.shape[tensor_dim]
if (not isinstance(dim_size, int) or dim_size > 1) and _tile_needs_mask(
state, block_id, tensor, tensor_dim
if (
block_id not in deferred
and (not isinstance(dim_size, int) or dim_size > 1)
and _tile_needs_mask(state, block_id, tensor, tensor_dim)
):
mask_var = state.codegen.mask_var(block_id)
if mask_var is not None:
Expand Down
Loading
Loading