55import pytest
66import torch
77
8- from tileops .kernels .gemm .maca_auto import MacaAutoGemmKernel
8+ from tileops .kernels .gemm .gemm import GemmKernel
99from tileops .kernels .gemm import maca_hgemm as maca_hgemm_module
1010from tileops .kernels .gemm .maca_hgemm import (
1111 MacaHGemmKernel ,
@@ -130,11 +130,23 @@ def test_is_metax_c500(monkeypatch: pytest.MonkeyPatch) -> None:
130130
131131
132132@pytest .mark .smoke
133- def test_gemm_default_selector_uses_c500_auto_backend (monkeypatch : pytest .MonkeyPatch ) -> None :
133+ def test_gemm_default_selector_uses_tilelang_compiler_backend_on_c500 (
134+ monkeypatch : pytest .MonkeyPatch ) -> None :
134135 _patch_metax_c500 (monkeypatch )
135136 monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
136137
137- assert _select_gemm_kernel ().__name__ == "MacaAutoGemmKernel"
138+ assert _select_gemm_kernel () is GemmKernel
139+
140+
141+ @pytest .mark .smoke
142+ @pytest .mark .parametrize ("backend" , ["maca_hgemm" , "maca_auto" ])
143+ def test_gemm_selector_rejects_direct_hpp_backends (
144+ monkeypatch : pytest .MonkeyPatch , backend : str ) -> None :
145+ _patch_metax_c500 (monkeypatch )
146+ monkeypatch .setenv ("TILEOPS_GEMM_BACKEND" , backend )
147+
148+ with pytest .raises (RuntimeError , match = "TileLang DSL/compiler" ):
149+ _select_gemm_kernel ()
138150
139151
140152@pytest .mark .smoke
@@ -227,39 +239,157 @@ def test_maca_hgemm_experimental_rowa_layout_b_gate_is_disabled(
227239
228240
229241@pytest .mark .smoke
230- def test_gemm_op_auto_routes_to_maca_hgemm_on_c500_fp16 (
242+ def test_gemm_op_auto_routes_to_tilelang_compiler_backend_on_c500_fp16 (
231243 monkeypatch : pytest .MonkeyPatch ) -> None :
232244 _patch_metax_c500 (monkeypatch )
233245 monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
234- _install_fake_backend_module (monkeypatch , "tileops.kernels.gemm.maca_hgemm" ,
235- "MacaHGemmKernel" )
236246
237247 op = GemmOp (128 , 128 , 128 , dtype = torch .float16 , tune = False )
238248
239- assert isinstance (op .kernel , MacaAutoGemmKernel )
240- assert op .kernel .inner .__class__ .__name__ == "MacaHGemmKernel"
241- assert op .kernel .config ["selected_backend" ] == "MacaHGemmKernel"
249+ assert isinstance (op .kernel , GemmKernel )
242250
243251
244252@pytest .mark .smoke
245- def test_gemm_op_prepacked_b_path_is_exposed_through_auto_dispatch (
253+ def test_gemm_op_auto_dispatch_exposes_compiler_prepared_b_path (
246254 monkeypatch : pytest .MonkeyPatch ) -> None :
247255 _patch_metax_c500 (monkeypatch )
248256 monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
249- _install_fake_backend_module_with_prepacked_api (monkeypatch , "tileops.kernels.gemm.maca_hgemm" ,
250- "MacaHGemmKernel" )
251257
252- op = GemmOp (128 , 128 , 128 , dtype = torch .float16 , tune = False )
253- a = torch .ones ((128 , 128 ), dtype = torch .float16 )
254- b = torch .ones ((128 , 128 ), dtype = torch .float16 )
258+ op = GemmOp (2 , 3 , 4 , dtype = torch .float16 , tune = False )
259+ b = torch .arange (12 , dtype = torch .float16 ).reshape (4 , 3 )
260+
261+ prepared_b = op .prepare_b (b )
262+
263+ assert isinstance (op .kernel , GemmKernel )
264+ assert prepared_b .shape == (3 , 4 )
265+ assert torch .equal (prepared_b , b .transpose (0 , 1 ).contiguous ())
266+
267+
268+ @pytest .mark .smoke
269+ def test_gemm_kernel_prefers_maca_bsm_path_on_aligned_c500_fp16 (
270+ monkeypatch : pytest .MonkeyPatch ) -> None :
271+ _patch_metax_c500 (monkeypatch )
272+ kernel = GemmKernel (128 , 128 , 128 , dtype = torch .float16 , tune = False )
273+
274+ assert kernel ._use_maca_bsm_path is True
275+ assert kernel ._use_col_major_output is False
276+ assert kernel .config == {
277+ "block_m" : 128 ,
278+ "block_n" : 128 ,
279+ "block_k" : 128 ,
280+ "num_stages" : 0 ,
281+ "threads" : 256 ,
282+ "enable_rasteration" : True ,
283+ }
284+
285+
286+ @pytest .mark .smoke
287+ def test_gemm_kernel_prepare_b_reuses_native_cache (
288+ monkeypatch : pytest .MonkeyPatch ) -> None :
289+ _patch_metax_c500 (monkeypatch )
290+ kernel = GemmKernel (2 , 3 , 4 , dtype = torch .float16 , tune = False )
291+ b = torch .arange (12 , dtype = torch .float16 ).reshape (4 , 3 )
292+
293+ prepared_first = kernel .prepare_b (b )
294+ prepared_second = kernel .prepare_b (b )
295+
296+ assert prepared_first is prepared_second
297+
298+
299+ @pytest .mark .smoke
300+ def test_gemm_kernel_prepare_b_can_pack_bsm_tile_layout (
301+ monkeypatch : pytest .MonkeyPatch ) -> None :
302+ _patch_metax_c500 (monkeypatch )
303+ monkeypatch .setenv ("TILEOPS_GEMM_PACKED_B_TILE" , "1" )
304+ monkeypatch .delenv ("TILEOPS_GEMM_SPLIT_K" , raising = False )
305+ kernel = GemmKernel (128 , 256 , 128 , dtype = torch .float16 , tune = False )
306+ b = torch .arange (128 * 256 , dtype = torch .float16 ).reshape (128 , 256 )
307+
308+ prepared_b = kernel .prepare_b (b )
309+ expected = b .transpose (0 , 1 ).contiguous ().view (
310+ 2 ,
311+ 128 ,
312+ 1 ,
313+ 128 ,
314+ ).permute (0 , 2 , 1 , 3 ).contiguous ()
315+
316+ assert kernel ._use_maca_bsm_path is True
317+ assert kernel ._use_packed_b_tile_path is True
318+ assert prepared_b .shape == (2 , 1 , 128 , 128 )
319+ assert torch .equal (prepared_b , expected )
320+
321+
322+ @pytest .mark .smoke
323+ def test_gemm_kernel_prepare_b_can_pack_splitk_bsm_tile_layout (
324+ monkeypatch : pytest .MonkeyPatch ) -> None :
325+ _patch_metax_c500 (monkeypatch )
326+ monkeypatch .setenv ("TILEOPS_GEMM_PACKED_B_TILE" , "1" )
327+ monkeypatch .setenv ("TILEOPS_GEMM_SPLIT_K" , "2" )
328+ kernel = GemmKernel (128 , 128 , 256 , dtype = torch .float16 , tune = False )
329+ b = torch .arange (256 * 128 , dtype = torch .float16 ).reshape (256 , 128 )
330+
331+ prepared_b = kernel .prepare_b (b )
332+ expected = b .transpose (0 , 1 ).contiguous ().view (
333+ 1 ,
334+ 128 ,
335+ 4 ,
336+ 64 ,
337+ ).permute (0 , 2 , 1 , 3 ).contiguous ()
338+
339+ assert kernel ._use_split_k_path is True
340+ assert kernel ._use_packed_b_tile_path is True
341+ assert prepared_b .shape == (1 , 4 , 128 , 64 )
342+ assert torch .equal (prepared_b , expected )
343+
344+
345+ @pytest .mark .smoke
346+ def test_gemm_kernel_can_select_packed_b_async_pipeline (
347+ monkeypatch : pytest .MonkeyPatch ) -> None :
348+ _patch_metax_c500 (monkeypatch )
349+ monkeypatch .setenv ("TILEOPS_GEMM_PACKED_B_TILE" , "1" )
350+ monkeypatch .setenv ("TILEOPS_GEMM_PACKED_B_ASYNC_PIPELINE" , "1" )
351+ monkeypatch .setenv ("TILEOPS_GEMM_SPLIT_K" , "2" )
255352
256- prepared = op .prepare_b (b )
257- out = op .forward_with_prepared_b (a , prepared )
353+ kernel = GemmKernel (128 , 128 , 256 , dtype = torch .float16 , tune = False )
258354
259- assert torch .equal (prepared , b + 1 )
260- assert torch .equal (out , b )
261- assert op .kernel .inner .prepared_b is b
262- assert op .kernel .inner .forward_args == (a , prepared )
355+ assert kernel ._use_split_k_path is True
356+ assert kernel ._use_packed_b_tile_path is True
357+ assert kernel ._use_packed_b_async_pipeline_path is True
358+
359+
360+ @pytest .mark .smoke
361+ def test_gemm_kernel_rejects_incompatible_splitk_block_k_configuration (
362+ monkeypatch : pytest .MonkeyPatch ) -> None :
363+ _patch_metax_c500 (monkeypatch )
364+ monkeypatch .setenv ("TILEOPS_GEMM_SPLIT_K" , "2" )
365+
366+ with pytest .raises (RuntimeError , match = "block_k divisible by split_k" ):
367+ GemmKernel (128 , 128 , 256 , dtype = torch .float16 , tune = False , config = {"block_k" : 33 })
368+
369+
370+ @pytest .mark .smoke
371+ def test_gemm_kernel_packed_b_tile_leaves_transposed_b_layout_alone (
372+ monkeypatch : pytest .MonkeyPatch ) -> None :
373+ _patch_metax_c500 (monkeypatch )
374+ monkeypatch .setenv ("TILEOPS_GEMM_PACKED_B_TILE" , "1" )
375+ kernel = GemmKernel (128 , 128 , 128 , dtype = torch .float16 , tune = False , trans_b = True )
376+ b = torch .arange (128 * 128 , dtype = torch .float16 ).reshape (128 , 128 )
377+
378+ prepared_b = kernel .prepare_b (b )
379+
380+ assert kernel ._use_maca_bsm_path is True
381+ assert kernel ._use_packed_b_tile_path is False
382+ assert prepared_b is b
383+
384+
385+ @pytest .mark .smoke
386+ def test_gemm_kernel_prepare_a_is_identity (
387+ monkeypatch : pytest .MonkeyPatch ) -> None :
388+ _patch_metax_c500 (monkeypatch )
389+ kernel = GemmKernel (2 , 3 , 4 , dtype = torch .float16 , tune = False )
390+ a = torch .arange (8 , dtype = torch .float16 ).reshape (2 , 4 )
391+
392+ assert torch .equal (kernel .prepare_a (a ), a )
263393
264394
265395@pytest .mark .smoke
@@ -291,21 +421,20 @@ def test_maca_hgemm_explicit_launch_order_env_disables_auto_selection(
291421
292422
293423@pytest .mark .smoke
294- def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entrypoint (
424+ def test_maca_hgemm_reference_layout_ab_continuous_c_routes_through_external_entrypoint (
295425 monkeypatch : pytest .MonkeyPatch ) -> None :
296426 _patch_metax_c500 (monkeypatch )
297427 monkeypatch .setenv ("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_AB_CONTINUOUS_C" , "1" )
298- monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
299428 fake_reference = _install_fake_reference_layout_ab_module (monkeypatch )
300429 _reference_muxi_layout_kernels .cache_clear ()
301430
302- op = GemmOp (128 , 16 , 5120 , dtype = torch .float16 , tune = False )
431+ kernel = MacaHGemmKernel (128 , 16 , 5120 , dtype = torch .float16 , tune = False )
303432 a = torch .ones ((128 , 5120 ), dtype = torch .float16 )
304433 b = torch .ones ((5120 , 16 ), dtype = torch .float16 )
305434
306- prepared_a = op .prepare_a (a )
307- prepared_b = op .prepare_b (b )
308- out = op .forward_with_prepared_a_and_b (prepared_a , prepared_b )
435+ prepared_a = kernel .prepare_a (a )
436+ prepared_b = kernel .prepare_b (b )
437+ out = kernel .forward_with_prepared_a_and_b (prepared_a , prepared_b )
309438 expected_prepared_a = a .view (128 // 16 , 16 , 5120 // 8 , 8 ).permute (0 , 2 , 1 , 3 ).contiguous ()
310439 expected_prepared_b = b .transpose (0 , 1 ).contiguous ().view (
311440 16 // 16 ,
@@ -315,8 +444,7 @@ def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entryp
315444 8 ,
316445 ).permute (2 , 0 , 3 , 1 , 4 ).contiguous ()
317446
318- assert isinstance (op .kernel .inner , MacaHGemmKernel )
319- assert op .kernel .inner .use_reference_layout_ab_continuous_c is True
447+ assert kernel .use_reference_layout_ab_continuous_c is True
320448 assert prepared_a .shape == (8 , 640 , 16 , 8 )
321449 assert prepared_b .shape == (160 , 1 , 4 , 16 , 8 )
322450 assert torch .equal (prepared_a , expected_prepared_a )
@@ -326,27 +454,25 @@ def test_gemm_op_reference_layout_ab_continuous_c_routes_through_external_entryp
326454
327455
328456@pytest .mark .smoke
329- def test_gemm_op_reference_layout_a_routes_through_external_entrypoint (
457+ def test_maca_hgemm_reference_layout_a_routes_through_external_entrypoint (
330458 monkeypatch : pytest .MonkeyPatch ) -> None :
331459 _patch_metax_c500 (monkeypatch )
332460 monkeypatch .setenv ("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_A_BODY" , "1" )
333- monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
334461 fake_reference = _install_fake_reference_layout_a_module (monkeypatch )
335462 _reference_muxi_layout_kernels .cache_clear ()
336463
337- op = GemmOp (128 , 64 , 128 , dtype = torch .float16 , tune = False )
464+ kernel = MacaHGemmKernel (128 , 64 , 128 , dtype = torch .float16 , tune = False )
338465 a = torch .arange (128 * 128 , dtype = torch .float16 ).reshape (128 , 128 )
339466 b = torch .arange (128 * 64 , dtype = torch .float16 ).reshape (128 , 64 )
340467
341- prepared_a = op .prepare_a (a )
342- prepared_b = op .prepare_b (b )
343- out = op .forward_with_prepared_a_and_b (prepared_a , prepared_b )
468+ prepared_a = kernel .prepare_a (a )
469+ prepared_b = kernel .prepare_b (b )
470+ out = kernel .forward_with_prepared_a_and_b (prepared_a , prepared_b )
344471 expected_prepared_a = a .view (128 // 16 , 16 , 128 // 8 , 8 ).permute (0 , 2 , 1 , 3 ).contiguous ()
345472 expected_prepared_b = b .transpose (0 , 1 ).contiguous ()
346473
347- assert isinstance (op .kernel .inner , MacaHGemmKernel )
348- assert op .kernel .inner .use_reference_layout_a is True
349- assert op .kernel .inner .config ["backend" ] == "maca_hgemm_reference_layout_a"
474+ assert kernel .use_reference_layout_a is True
475+ assert kernel .config ["backend" ] == "maca_hgemm_reference_layout_a"
350476 assert prepared_a .shape == (8 , 16 , 16 , 8 )
351477 assert prepared_b .shape == (64 , 128 )
352478 assert torch .equal (prepared_a , expected_prepared_a )
@@ -367,11 +493,12 @@ def test_maca_hgemm_rowa_layout_b_body_stays_disabled_after_failed_smoke(
367493
368494
369495@pytest .mark .smoke
370- def test_gemm_op_reference_layout_ab_continuous_c_rejects_unsupported_long_k_shape (
496+ def test_gemm_op_ignores_hpp_reference_layout_env_on_auto_dispatch (
371497 monkeypatch : pytest .MonkeyPatch ) -> None :
372498 _patch_metax_c500 (monkeypatch )
373499 monkeypatch .setenv ("TILEOPS_MACA_HGEMM_USE_REFERENCE_LAYOUT_AB_CONTINUOUS_C" , "1" )
374500 monkeypatch .delenv ("TILEOPS_GEMM_BACKEND" , raising = False )
375501
376- with pytest .raises (RuntimeError , match = "only supports shapes listed" ):
377- GemmOp (1664 , 1024 , 16384 , dtype = torch .float16 , tune = False )
502+ op = GemmOp (1664 , 1024 , 16384 , dtype = torch .float16 , tune = False )
503+
504+ assert isinstance (op .kernel , GemmKernel )
0 commit comments