[Pallas] Test jagged carry with dynamic row counts #2722
stack-info: PR: #2722, branch: thcmbs/stack/4
77afab6 to a01f806
    if not (
        isinstance(block_row, int)
        and isinstance(block_col, int)
        and isinstance(n_cols, int)
I know this is just a draft, but I've been testing this PR as well. The restriction that n_cols be static prevented me from using the carry on grouped_gemm. It's not hard to lift, though; the only issue is that we must support printing scratch shapes that contain symbolic values. We can do that with:
diff --git a/helion/_compiler/backend.py b/helion/_compiler/backend.py
--- a/helion/_compiler/backend.py
+++ b/helion/_compiler/backend.py
@@ -2437,7 +2437,10 @@ class PallasBackend(Backend):
             for s in device_fn._scratch_args
         ]
         if scratch_shapes:
-            launcher_args.append(f"_scratch_shapes={scratch_shapes!r}")
+            from .host_function import HostFunction
+            launcher_args.append(
+                f"_scratch_shapes={HostFunction.current().literal_expr(scratch_shapes)}"
+            )
         # Identify which launcher arg positions correspond to pipeline-body
         # tensors (need HBM refs); all others get proper BlockSpecs.
diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py
--- a/helion/_compiler/host_function.py
+++ b/helion/_compiler/host_function.py
@@ -281,6 +281,8 @@ class HostFunction:
         if isinstance(expr, list):
             return "[" + ", ".join(self.literal_expr(x) for x in expr) + "]"
         if isinstance(expr, tuple):
+            if not expr:
+                return "()"
             return "(" + ", ".join(self.literal_expr(x) for x in expr) + ", )"
         return repr(expr)
(The first hunk is the important one; the second fixes some tests that break when printing empty tuples in the scratch shapes.)
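To illustrate why plain repr is not enough here, below is a minimal standalone sketch, not the Helion implementation. `Sym` is a hypothetical stand-in for a symbolic size, and this `literal_expr` is a free function that only mirrors the list/tuple recursion shown in the diff; the `Sym` branch is an assumption about how a symbolic value might be printed as a host-side name.

```python
class Sym:
    """Hypothetical stand-in for a symbolic size (not a Helion type)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        # Like many object reprs, this is not valid Python source.
        return f"<Sym {self.name}>"


def literal_expr(expr: object) -> str:
    """Render expr as Python source text, recursing into lists and tuples."""
    if isinstance(expr, Sym):
        # Assumption: a symbolic value prints as the name of a host variable.
        return expr.name
    if isinstance(expr, list):
        return "[" + ", ".join(literal_expr(x) for x in expr) + "]"
    if isinstance(expr, tuple):
        if not expr:
            # Without this guard the line below would produce "(, )".
            return "()"
        return "(" + ", ".join(literal_expr(x) for x in expr) + ", )"
    return repr(expr)


scratch_shapes = [(8, Sym("n_cols")), ()]
source = literal_expr(scratch_shapes)

# Evaluate the generated text with n_cols bound, as a call site would.
shapes = eval(source, {"n_cols": 128})
```

With `{scratch_shapes!r}` the same input would embed `<Sym n_cols>` in the generated text, which does not parse; routing every element through the recursive printer keeps the output evaluable once `n_cols` is in scope.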
Do you want to pick up this work? I'd just submit this example as is, to prevent regressions in what we currently support, but I'm more than happy if we extend the coverage.
Yes, happy to do this since I need it for grouped_gemm.
9fdb34c to 3c77116
stack-info: PR: #2719, branch: thcmbs/stack/3
As it is now (just a test), does it make more sense to land this as part of its parent PR?
Yep, fair point - done
Stacked PRs:
[Pallas] Test jagged carry with dynamic row counts