Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions helion/language/_tracing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .._compiler.ast_extension import expr_from_string
from .._compiler.ast_extension import statement_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.device_function import find_block_size_symbols
from .._compiler.dtype_utils import cast_ast
from .._compiler.host_function import HostFunction
from .._compiler.variable_origin import BlockSizeOrigin
Expand Down Expand Up @@ -520,6 +521,44 @@ def _scratch_write_stmt(state: CodegenState, sname: str, val: ast.AST) -> ast.AS
return statement_from_string(f"{sname}[{idx}] = {{val}}", val=val)


def _resolve_dim_size(
s: object,
env: CompileEnvironment,
config: Config,
) -> int | None:
"""Resolve a tensor-dim size to a concrete int from ``config``, else ``None``.

Handles a single tile dim via ``resolve_block_id`` and ``reshape``-merged
dims (a sympy product/sum/power of block symbols) by substituting each block
size. The ``int(s)`` fallback would otherwise return the full-extent size
hint and over-size loop-carried scratch.
"""
bid = env.resolve_block_id(s)
if bid is not None:
bs = env.block_sizes[bid].from_config(config)
return bs if isinstance(bs, int) else None

if isinstance(s, int):
return s
expr = s._sympy_() if isinstance(s, torch.SymInt) else s
if not isinstance(expr, sympy.Expr):
return None
if expr.is_Integer:
return int(expr)

block_mapping, non_block_symbols = find_block_size_symbols(expr)
if non_block_symbols:
return None
subs: dict[sympy.Symbol, sympy.Integer] = {}
for symbol, block_id in block_mapping.items():
bs = env.block_sizes[block_id].from_config(config)
if not isinstance(bs, int):
return None
subs[symbol] = sympy.Integer(bs)
resolved = expr.xreplace(subs)
return int(resolved) if resolved.is_Integer else None


def _resolve_shape(
proxy: torch.Tensor,
env: CompileEnvironment,
Expand All @@ -528,11 +567,9 @@ def _resolve_shape(
"""Resolve symbolic tile sizes to concrete block sizes from config."""
resolved = []
for s in proxy.shape:
bid = env.resolve_block_id(s)
if bid is not None:
bs = env.block_sizes[bid].from_config(config)
assert isinstance(bs, int)
resolved.append(bs)
size = _resolve_dim_size(s, env, config)
if size is not None:
resolved.append(size)
else:
resolved.append(int(s))
return tuple(resolved)
Expand Down
63 changes: 63 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2170,6 +2170,69 @@ def test_attention_unroll_fp32(self) -> None:
self.assertIn("torch.empty_like(q_view, device='meta')", _code)
self.assertIn("out = _launcher(", _code)

def test_attention_reshape_merge_scratch_size(self) -> None:
"""Reshape-merged tiled dims size loop-carried scratch by block product.

A ``reshape([-1, d])`` that merges several tiled dims gives a leading
size that is a *product* of block-size symbols. The scratch must resolve
to that product (here m_block=64), not the symbol's size hint (the full
2*2*64=262144 extent) which would over-size the buffer and crash.
"""

@helion.kernel(backend="pallas", static_shapes=True)
def attn_merge(
q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor
) -> torch.Tensor:
bq, hq, m, n = (
q_in.size(0),
q_in.size(1),
q_in.size(2),
k_in.size(2),
)
d = hl.specialize(q_in.size(3))
out = torch.empty_like(q_in)
scale = (1.0 / math.sqrt(d)) * 1.44269504
for tb, th, tm in hl.tile([bq, hq, m]):
qt = q_in[tb, th, tm, :].reshape([-1, d])
m_i = hl.full([qt.size(0)], float("-inf"), dtype=torch.float32)
l_i = hl.full([qt.size(0)], 1.0, dtype=torch.float32)
acc = hl.zeros([qt.size(0), d], dtype=torch.float32)
for tn in hl.tile(n):
kt = k_in[tb, th, tn, :].reshape([-1, d])
qk = hl.dot(qt * scale, kt.transpose(0, 1), out_dtype=torch.float32)
m_ij = torch.maximum(m_i, torch.amax(qk, -1))
p = torch.exp2(qk - m_ij[:, None])
l_i = l_i * torch.exp2(m_i - m_ij) + torch.sum(p, -1)
acc = acc * torch.exp2(m_i - m_ij)[:, None]
vt = v_in[tb, th, tn, :].reshape([-1, d])
acc = torch.addmm(acc, p.to(vt.dtype), vt)
m_i = m_ij
out[tb, th, tm, :] = (
(acc / l_i[:, None]).to(out.dtype).reshape([1, 1, -1, d])
)
return out

query = torch.randn(2, 2, 64, 32, dtype=torch.float32, device=DEVICE)
key = torch.randn(2, 2, 64, 32, dtype=torch.float32, device=DEVICE)
val = torch.randn(2, 2, 64, 32, dtype=torch.float32, device=DEVICE)
code, result = code_and_output(
attn_merge,
(query, key, val),
block_sizes=[1, 1, 64, 32],
pallas_loop_type="emit_pipeline",
)
self.assertIn(
"_scratch_shapes=["
"((64,), 'jnp.float32', 'vmem'), "
"((64,), 'jnp.float32', 'vmem'), "
"((64, 32), 'jnp.float32', 'vmem')]",
code,
)
ref = torch.nn.functional.scaled_dot_product_attention(
query.float().cpu(), key.float().cpu(), val.float().cpu()
).to(device=DEVICE)
torch.testing.assert_close(result, ref, rtol=1e-2, atol=1e-2)

def test_hl_zeros_outer_arithmetic_emit_pipeline(self) -> None:
"""``hl.zeros`` results must support arithmetic at outer (non-inner-loop) scope.

Expand Down
Loading