Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions benchmark/bench_flash_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,27 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
"flash_mla_triton",
]

# Benchmark sweep: one configuration per (batch, seqlen, head) combination.
shape_configs = [
    {
        "b": bsz,
        "s_q": 1,
        # Sequence idx of the batch is given length base_len + 2 * idx.
        "cache_seqlens": torch.tensor(
            [base_len + 2 * idx for idx in range(bsz)],
            dtype=torch.int32,
            device="cuda",
        ),
        "h_q": n_heads,
        "h_kv": 1,
        "d": 512 + 64,
        "dv": 512,
        "causal": True,
        "dtype": torch.bfloat16,
    }
    for bsz in [128]
    for base_len in [1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4]
    for n_heads in [128]
]
def build_shape_configs(
    device="cuda",
    *,
    batches=(128,),
    seqlens=(1024, 2048, 4096, 8192, 8192 * 2, 8192 * 4),
    heads=(128,),
):
    """Build the list of shape configurations to sweep.

    One configuration dict is produced for every (batch, seqlen, head)
    combination, iterating batches outermost and heads innermost.

    Args:
        device: Device on which each ``cache_seqlens`` tensor is allocated.
        batches: Batch sizes to sweep (keyword-only).
        seqlens: Base cache sequence lengths to sweep (keyword-only).
        heads: Values used for ``h_q`` (keyword-only).

    Returns:
        A list of dicts with the keys ``b``, ``s_q``, ``cache_seqlens``,
        ``h_q``, ``h_kv``, ``d``, ``dv``, ``causal`` and ``dtype``.
        ``cache_seqlens`` is an int32 tensor of length ``b`` whose i-th
        entry is ``seqlen + 2 * i``.
    """
    return [
        {
            "b": batch,
            "s_q": 1,
            # Stagger the per-sequence lengths so they are not all equal.
            "cache_seqlens": torch.tensor(
                [seqlen + 2 * i for i in range(batch)],
                dtype=torch.int32,
                device=device,
            ),
            "h_q": head,
            "h_kv": 1,
            "d": 512 + 64,
            "dv": 512,
            "causal": True,
            "dtype": torch.bfloat16,
        }
        for batch in batches
        for seqlen in seqlens
        for head in heads
    ]


def get_args():
Expand All @@ -504,6 +521,7 @@ def get_args():
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
shape_configs = build_shape_configs()
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
Expand All @@ -517,4 +535,4 @@ def get_args():
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n')
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"])
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n')
57 changes: 57 additions & 0 deletions tests/test_benchmark_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import ast
import unittest
from pathlib import Path


# Path of the benchmark script under test, located relative to this file.
# Resolve first so ``parents[1]`` exists even when ``__file__`` is a bare
# relative filename (e.g. the script is run from inside its own directory).
BENCHMARK = Path(__file__).resolve().parents[1] / "benchmark" / "bench_flash_mla.py"


class BenchmarkShapeConfigTest(unittest.TestCase):
    """Static, AST-only checks on how the benchmark script builds its shapes.

    The benchmark source is parsed with ``ast`` and never imported or
    executed, so the checks only look at the structure of the module.
    """

    @classmethod
    def setUpClass(cls):
        # Parse once; every test below only inspects this tree.
        cls.tree = ast.parse(BENCHMARK.read_text(encoding="utf-8"))

    def _first_top_level(self, predicate, description):
        """Return the first module-level node for which *predicate* is true.

        Fails the test with a readable message when nothing matches, rather
        than letting a bare ``next()`` leak ``StopIteration``.
        """
        for node in self.tree.body:
            if predicate(node):
                return node
        self.fail(f"{description} not found at module level in {BENCHMARK.name}")

    @staticmethod
    def _is_main_guard(node):
        """Tell whether *node* is an ``if`` whose test compares ``__name__``."""
        return (
            isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
        )

    def test_shape_configs_are_not_created_at_import_time(self):
        offending_lines = []
        for node in self.tree.body:
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, ast.AnnAssign) and node.value is not None:
                # ``shape_configs: list = [...]`` also binds at import time.
                targets = [node.target]
            else:
                continue
            if any(
                isinstance(target, ast.Name) and target.id == "shape_configs"
                for target in targets
            ):
                offending_lines.append(node.lineno)

        self.assertEqual(
            offending_lines,
            [],
            "shape_configs is assigned at module level on these lines",
        )

    def test_shape_builder_accepts_device_argument(self):
        builder = self._first_top_level(
            lambda node: isinstance(node, ast.FunctionDef)
            and node.name == "build_shape_configs",
            "build_shape_configs()",
        )

        positional = builder.args.args
        self.assertTrue(
            positional, "build_shape_configs() has no positional parameters"
        )
        self.assertEqual(positional[0].arg, "device")

        # ``defaults`` lines up with the *tail* of the positional parameters,
        # so the default of the first one sits at this offset, not always 0.
        defaults = builder.args.defaults
        offset = len(defaults) - len(positional)
        self.assertGreaterEqual(offset, 0, "'device' has no default value")
        self.assertEqual(ast.literal_eval(defaults[offset]), "cuda")

    def test_main_creates_shape_configs_lazily(self):
        main_block = self._first_top_level(
            self._is_main_guard, "the __name__ guard block"
        )
        # Names of every plain-function call anywhere inside the guard block.
        called_names = {
            node.func.id
            for node in ast.walk(main_block)
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
        }

        self.assertIn("build_shape_configs", called_names)


# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()