Skip to content

Commit dfd55e4

Browse files
committed
[MetaXGPU] Add compiler-path C500 hgemm route
Keep GemmOp auto/default dispatch on the TileLang GemmKernel, and reject direct maca_hgemm/maca_auto backend overrides for new hgemm work. Add the MACA BSM compiler path used by the packed-B split-K result, including prepared-B packing, prepared-B caching, split-K reduction, and C500 defaults.
Validation:
- git diff --cached --check
- ./.venv/bin/python -m py_compile tileops/ops/gemm.py tileops/kernels/gemm/gemm.py tileops/kernels/gemm/maca_auto.py tests/ops/test_gemm_auto_dispatch.py
1 parent 05d2d47 commit dfd55e4

5 files changed

Lines changed: 749 additions & 79 deletions

File tree

docs/perf/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ All conclusions are scoped to this configuration. Re-validate when any component
2222
| Category | Checklist | Evidence |
2323
| ----------- | -------------------------------- | -------------------------------------------------- |
2424
| Elementwise | [elementwise.md](elementwise.md) | [elementwise-evidence.md](elementwise-evidence.md) |
25+
| HGEMM | [hgemm-codegen-delta-loop.md](hgemm-codegen-delta-loop.md) | Compiler-path MetaX C500 HGEMM codegen workflow |

tests/ops/test_gemm_auto_dispatch.py

Lines changed: 168 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
import torch
77

8-
from tileops.kernels.gemm.maca_auto import MacaAutoGemmKernel
8+
from tileops.kernels.gemm.gemm import GemmKernel
99
from tileops.kernels.gemm import maca_hgemm as maca_hgemm_module
1010
from tileops.kernels.gemm.maca_hgemm import (
1111
MacaHGemmKernel,
@@ -130,11 +130,23 @@ def test_is_metax_c500(monkeypatch: pytest.MonkeyPatch) -> None:
130130

131131

132132
@pytest.mark.smoke
133-
def test_gemm_default_selector_uses_c500_auto_backend(monkeypatch: pytest.MonkeyPatch) -> None:
133+
def test_gemm_default_selector_uses_tilelang_compiler_backend_on_c500(
134+
monkeypatch: pytest.MonkeyPatch) -> None:
134135
_patch_metax_c500(monkeypatch)
135136
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
136137

137-
assert _select_gemm_kernel().__name__ == "MacaAutoGemmKernel"
138+
assert _select_gemm_kernel() is GemmKernel
139+
140+
141+
@pytest.mark.smoke
142+
@pytest.mark.parametrize("backend", ["maca_hgemm", "maca_auto"])
143+
def test_gemm_selector_rejects_direct_hpp_backends(
144+
monkeypatch: pytest.MonkeyPatch, backend: str) -> None:
145+
_patch_metax_c500(monkeypatch)
146+
monkeypatch.setenv("TILEOPS_GEMM_BACKEND", backend)
147+
148+
with pytest.raises(RuntimeError, match="TileLang DSL/compiler"):
149+
_select_gemm_kernel()
138150

139151

140152
@pytest.mark.smoke
@@ -227,39 +239,157 @@ def test_maca_hgemm_experimental_rowa_layout_b_gate_is_disabled(
227239

228240

229241
@pytest.mark.smoke
230-
def test_gemm_op_auto_routes_to_maca_hgemm_on_c500_fp16(
242+
def test_gemm_op_auto_routes_to_tilelang_compiler_backend_on_c500_fp16(
231243
monkeypatch: pytest.MonkeyPatch) -> None:
232244
_patch_metax_c500(monkeypatch)
233245
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
234-
_install_fake_backend_module(monkeypatch, "tileops.kernels.gemm.maca_hgemm",
235-
"MacaHGemmKernel")
236246

237247
op = GemmOp(128, 128, 128, dtype=torch.float16, tune=False)
238248

239-
assert isinstance(op.kernel, MacaAutoGemmKernel)
240-
assert op.kernel.inner.__class__.__name__ == "MacaHGemmKernel"
241-
assert op.kernel.config["selected_backend"] == "MacaHGemmKernel"
249+
assert isinstance(op.kernel, GemmKernel)
242250

243251

244252
@pytest.mark.smoke
245-
def test_gemm_op_prepacked_b_path_is_exposed_through_auto_dispatch(
253+
def test_gemm_op_auto_dispatch_exposes_compiler_prepared_b_path(
246254
monkeypatch: pytest.MonkeyPatch) -> None:
247255
_patch_metax_c500(monkeypatch)
248256
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
249-
_install_fake_backend_module_with_prepacked_api(monkeypatch, "tileops.kernels.gemm.maca_hgemm",
250-
"MacaHGemmKernel")
251257

252-
op = GemmOp(128, 128, 128, dtype=torch.float16, tune=False)
253-
a = torch.ones((128, 128), dtype=torch.float16)
254-
b = torch.ones((128, 128), dtype=torch.float16)
258+
op = GemmOp(2, 3, 4, dtype=torch.float16, tune=False)
259+
b = torch.arange(12, dtype=torch.float16).reshape(4, 3)
260+
261+
prepared_b = op.prepare_b(b)
262+
263+
assert isinstance(op.kernel, GemmKernel)
264+
assert prepared_b.shape == (3, 4)
265+
assert torch.equal(prepared_b, b.transpose(0, 1).contiguous())
266+
267+
268+
@pytest.mark.smoke
269+
def test_gemm_kernel_prefers_maca_bsm_path_on_aligned_c500_fp16(
270+
monkeypatch: pytest.MonkeyPatch) -> None:
271+
_patch_metax_c500(monkeypatch)
272+
kernel = GemmKernel(128, 128, 128, dtype=torch.float16, tune=False)
273+
274+
assert kernel._use_maca_bsm_path is True
275+
assert kernel._use_col_major_output is False
276+
assert kernel.config == {
277+
"block_m": 128,
278+
"block_n": 128,
279+
"block_k": 128,
280+
"num_stages": 0,
281+
"threads": 256,
282+
"enable_rasteration": True,
283+
}
284+
285+
286+
@pytest.mark.smoke
287+
def test_gemm_kernel_prepare_b_reuses_native_cache(
288+
monkeypatch: pytest.MonkeyPatch) -> None:
289+
_patch_metax_c500(monkeypatch)
290+
kernel = GemmKernel(2, 3, 4, dtype=torch.float16, tune=False)
291+
b = torch.arange(12, dtype=torch.float16).reshape(4, 3)
292+
293+
prepared_first = kernel.prepare_b(b)
294+
prepared_second = kernel.prepare_b(b)
295+
296+
assert prepared_first is prepared_second
297+
298+
299+
@pytest.mark.smoke
300+
def test_gemm_kernel_prepare_b_can_pack_bsm_tile_layout(
301+
monkeypatch: pytest.MonkeyPatch) -> None:
302+
_patch_metax_c500(monkeypatch)
303+
monkeypatch.setenv("TILEOPS_GEMM_PACKED_B_TILE", "1")
304+
monkeypatch.delenv("TILEOPS_GEMM_SPLIT_K", raising=False)
305+
kernel = GemmKernel(128, 256, 128, dtype=torch.float16, tune=False)
306+
b = torch.arange(128 * 256, dtype=torch.float16).reshape(128, 256)
307+
308+
prepared_b = kernel.prepare_b(b)
309+
expected = b.transpose(0, 1).contiguous().view(
310+
2,
311+
128,
312+
1,
313+
128,
314+
).permute(0, 2, 1, 3).contiguous()
315+
316+
assert kernel._use_maca_bsm_path is True
317+
assert kernel._use_packed_b_tile_path is True
318+
assert prepared_b.shape == (2, 1, 128, 128)
319+
assert torch.equal(prepared_b, expected)
320+
321+
322+
@pytest.mark.smoke
323+
def test_gemm_kernel_prepare_b_can_pack_splitk_bsm_tile_layout(
324+
monkeypatch: pytest.MonkeyPatch) -> None:
325+
_patch_metax_c500(monkeypatch)
326+
monkeypatch.setenv("TILEOPS_GEMM_PACKED_B_TILE", "1")
327+
monkeypatch.setenv("TILEOPS_GEMM_SPLIT_K", "2")
328+
kernel = GemmKernel(128, 128, 256, dtype=torch.float16, tune=False)
329+
b = torch.arange(256 * 128, dtype=torch.float16).reshape(256, 128)
330+
331+
prepared_b = kernel.prepare_b(b)
332+
expected = b.transpose(0, 1).contiguous().view(
333+
1,
334+
128,
335+
4,
336+
64,
337+
).permute(0, 2, 1, 3).contiguous()
338+
339+
assert kernel._use_split_k_path is True
340+
assert kernel._use_packed_b_tile_path is True
341+
assert prepared_b.shape == (1, 4, 128, 64)
342+
assert torch.equal(prepared_b, expected)
343+
344+
345+
@pytest.mark.smoke
346+
def test_gemm_kernel_can_select_packed_b_async_pipeline(
347+
monkeypatch: pytest.MonkeyPatch) -> None:
348+
_patch_metax_c500(monkeypatch)
349+
monkeypatch.setenv("TILEOPS_GEMM_PACKED_B_TILE", "1")
350+
monkeypatch.setenv("TILEOPS_GEMM_PACKED_B_ASYNC_PIPELINE", "1")
351+
monkeypatch.setenv("TILEOPS_GEMM_SPLIT_K", "2")
255352

256-
prepared = op.prepare_b(b)
257-
out = op.forward_with_prepared_b(a, prepared)
353+
kernel = GemmKernel(128, 128, 256, dtype=torch.float16, tune=False)
258354

259-
assert torch.equal(prepared, b + 1)
260-
assert torch.equal(out, b)
261-
assert op.kernel.inner.prepared_b is b
262-
assert op.kernel.inner.forward_args == (a, prepared)
355+
assert kernel._use_split_k_path is True
356+
assert kernel._use_packed_b_tile_path is True
357+
assert kernel._use_packed_b_async_pipeline_path is True
358+
359+
360+
@pytest.mark.smoke
361+
def test_gemm_kernel_rejects_incompatible_splitk_block_k_configuration(
362+
monkeypatch: pytest.MonkeyPatch) -> None:
363+
_patch_metax_c500(monkeypatch)
364+
monkeypatch.setenv("TILEOPS_GEMM_SPLIT_K", "2")
365+
366+
with pytest.raises(RuntimeError, match="block_k divisible by split_k"):
367+
GemmKernel(128, 128, 256, dtype=torch.float16, tune=False, config={"block_k": 33})
368+
369+
370+
@pytest.mark.smoke
371+
def test_gemm_kernel_packed_b_tile_leaves_transposed_b_layout_alone(
372+
monkeypatch: pytest.MonkeyPatch) -> None:
373+
_patch_metax_c500(monkeypatch)
374+
monkeypatch.setenv("TILEOPS_GEMM_PACKED_B_TILE", "1")
375+
kernel = GemmKernel(128, 128, 128, dtype=torch.float16, tune=False, trans_b=True)
376+
b = torch.arange(128 * 128, dtype=torch.float16).reshape(128, 128)
377+
378+
prepared_b = kernel.prepare_b(b)
379+
380+
assert kernel._use_maca_bsm_path is True
381+
assert kernel._use_packed_b_tile_path is False
382+
assert prepared_b is b
383+
384+
385+
@pytest.mark.smoke
386+
def test_gemm_kernel_prepare_a_is_identity(
387+
monkeypatch: pytest.MonkeyPatch) -> None:
388+
_patch_metax_c500(monkeypatch)
389+
kernel = GemmKernel(2, 3, 4, dtype=torch.float16, tune=False)
390+
a = torch.arange(8, dtype=torch.float16).reshape(2, 4)
391+
392+
assert torch.equal(kernel.prepare_a(a), a)
263393

264394

265395
@pytest.mark.smoke
@@ -291,21 +421,20 @@ def test_maca_hgemm_explicit_launch_order_env_disables_auto_selection(
291421

292422

293423
@pytest.mark.smoke
294-
def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entrypoint(
424+
def test_maca_hgemm_reference_layout_ab_continuous_c_routes_through_external_entrypoint(
295425
monkeypatch: pytest.MonkeyPatch) -> None:
296426
_patch_metax_c500(monkeypatch)
297427
monkeypatch.setenv("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_AB_CONTINUOUS_C", "1")
298-
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
299428
fake_reference = _install_fake_reference_layout_ab_module(monkeypatch)
300429
_reference_muxi_layout_kernels.cache_clear()
301430

302-
op = GemmOp(128, 16, 5120, dtype=torch.float16, tune=False)
431+
kernel = MacaHGemmKernel(128, 16, 5120, dtype=torch.float16, tune=False)
303432
a = torch.ones((128, 5120), dtype=torch.float16)
304433
b = torch.ones((5120, 16), dtype=torch.float16)
305434

306-
prepared_a = op.prepare_a(a)
307-
prepared_b = op.prepare_b(b)
308-
out = op.forward_with_prepared_a_and_b(prepared_a, prepared_b)
435+
prepared_a = kernel.prepare_a(a)
436+
prepared_b = kernel.prepare_b(b)
437+
out = kernel.forward_with_prepared_a_and_b(prepared_a, prepared_b)
309438
expected_prepared_a = a.view(128 // 16, 16, 5120 // 8, 8).permute(0, 2, 1, 3).contiguous()
310439
expected_prepared_b = b.transpose(0, 1).contiguous().view(
311440
16 // 16,
@@ -315,8 +444,7 @@ def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entryp
315444
8,
316445
).permute(2, 0, 3, 1, 4).contiguous()
317446

318-
assert isinstance(op.kernel.inner, MacaHGemmKernel)
319-
assert op.kernel.inner.use_reference_layout_ab_continuous_c is True
447+
assert kernel.use_reference_layout_ab_continuous_c is True
320448
assert prepared_a.shape == (8, 640, 16, 8)
321449
assert prepared_b.shape == (160, 1, 4, 16, 8)
322450
assert torch.equal(prepared_a, expected_prepared_a)
@@ -326,27 +454,25 @@ def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entryp
326454

327455

328456
@pytest.mark.smoke
329-
def test_gemm_op_reference_layout_a_routes_through_external_entrypoint(
457+
def test_maca_hgemm_reference_layout_a_routes_through_external_entrypoint(
330458
monkeypatch: pytest.MonkeyPatch) -> None:
331459
_patch_metax_c500(monkeypatch)
332460
monkeypatch.setenv("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_A_BODY", "1")
333-
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
334461
fake_reference = _install_fake_reference_layout_a_module(monkeypatch)
335462
_reference_muxi_layout_kernels.cache_clear()
336463

337-
op = GemmOp(128, 64, 128, dtype=torch.float16, tune=False)
464+
kernel = MacaHGemmKernel(128, 64, 128, dtype=torch.float16, tune=False)
338465
a = torch.arange(128 * 128, dtype=torch.float16).reshape(128, 128)
339466
b = torch.arange(128 * 64, dtype=torch.float16).reshape(128, 64)
340467

341-
prepared_a = op.prepare_a(a)
342-
prepared_b = op.prepare_b(b)
343-
out = op.forward_with_prepared_a_and_b(prepared_a, prepared_b)
468+
prepared_a = kernel.prepare_a(a)
469+
prepared_b = kernel.prepare_b(b)
470+
out = kernel.forward_with_prepared_a_and_b(prepared_a, prepared_b)
344471
expected_prepared_a = a.view(128 // 16, 16, 128 // 8, 8).permute(0, 2, 1, 3).contiguous()
345472
expected_prepared_b = b.transpose(0, 1).contiguous()
346473

347-
assert isinstance(op.kernel.inner, MacaHGemmKernel)
348-
assert op.kernel.inner.use_reference_layout_a is True
349-
assert op.kernel.inner.config["backend"] == "maca_hgemm_reference_layout_a"
474+
assert kernel.use_reference_layout_a is True
475+
assert kernel.config["backend"] == "maca_hgemm_reference_layout_a"
350476
assert prepared_a.shape == (8, 16, 16, 8)
351477
assert prepared_b.shape == (64, 128)
352478
assert torch.equal(prepared_a, expected_prepared_a)
@@ -367,11 +493,12 @@ def test_maca_hgemm_rowa_layout_b_body_stays_disabled_after_failed_smoke(
367493

368494

369495
@pytest.mark.smoke
370-
def test_gemm_op_reference_layout_ab_continuous_c_rejects_unsupported_long_k_shape(
496+
def test_gemm_op_ignores_hpp_reference_layout_env_on_auto_dispatch(
371497
monkeypatch: pytest.MonkeyPatch) -> None:
372498
_patch_metax_c500(monkeypatch)
373499
monkeypatch.setenv("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_AB_CONTINUOUS_C", "1")
374500
monkeypatch.delenv("TILEOPS_GEMM_BACKEND", raising=False)
375501

376-
with pytest.raises(RuntimeError, match="only supports shapes listed"):
377-
GemmOp(1664, 1024, 16384, dtype=torch.float16, tune=False)
502+
op = GemmOp(1664, 1024, 16384, dtype=torch.float16, tune=False)
503+
504+
assert isinstance(op.kernel, GemmKernel)

0 commit comments

Comments
 (0)