[Benchmark]: Add model_config sweep mode and model registry (#1180)
noemotiovon wants to merge 2 commits into linkedin:main from a feature branch.
Conversation
- Add Qwen 2.5 models (7B / 14B / 72B) and DeepSeek models (V2 Lite / V3) to MODEL_REGISTRY - Add model_config sweep support to all 33 benchmark scripts, enabling benchmarks to sweep across different model architectures at a fixed sequence length - Refactor benchmark scripts by extracting helper functions: - _setup_* - _resolve_model_config_* to improve code reuse and keep implementations cleaner across sweep modes - Add grouped bar chart visualization in benchmarks_visualizer for model_config sweep results
Benchmark Framework Design

This document describes the overall design of the Liger-Kernel benchmark suite, including its two benchmark dimensions, the shared infrastructure, and the phased implementation plan.

1. Benchmark Dimensions

Every operator should ideally be benchmarked along two orthogonal dimensions:
D1: Non-model dimension sweep (implemented) — sweep non-model dimensions (e.g. sequence length, BT) with a fixed model config selected via `--model`.

D2: Model dimension sweep (implemented) — sweep model architecture dimensions (e.g. hidden_size, or discrete model configs from `MODEL_REGISTRY`).

2. D2 Design Choices

Following the maintainer discussion, we evaluated three approaches:
Decision: C as the primary approach, with A as optional enrichment for ops where single-parameter scaling is important. Rationale:
3. Universal Token Length for D2

For D2 benchmarks, we need a fixed token length that is safe (no OOM) across all model configs and all operators.

Strategy
Proposed CLI

# D1 (existing): token-length sweep with fixed model
python benchmark_geglu.py --model llama_3_8b
# D2 (new): model-config sweep with fixed token length
python benchmark_geglu.py --sweep-mode model_config --bt 2048

4. Infrastructure Changes

4.1 New config type

@dataclass(frozen=True)
class ModelConfigSweepConfig:
"""Config for D2 benchmarks that sweep across model configs."""
model_configs: List[ModelConfig] # models to benchmark
bt: int # fixed batch * seq_len
batch_size: int # safe batch size
    seq_len: int     # safe seq_len

4.2 New helper

def compute_model_config_sweep_config(
model_configs: List[ModelConfig],
probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]],
bt: int = 2048,
memory_utilization: float = 0.4,
) -> ModelConfigSweepConfig:
"""Find safe (batch_size, seq_len) that works across all model configs.
For each model config, runs probe_fn_factory(model_config, bt) to measure
peak memory, then picks the most conservative batch_size / seq_len.
"""
    ...

4.3 Script-level changes

Each benchmark script gains a model-config sweep code path gated by `--sweep-mode`:

if args.sweep_mode == "model_config":
configs = [MODEL_REGISTRY[name] for name in MODEL_REGISTRY]
sweep = compute_model_config_sweep_config(configs, probe_fn_factory=..., bt=args.bt)
# x_values = model config indices
# extra_benchmark_configs = contains all model configs
...
else:
# existing token-length sweep logic
    ...

4.4 Visualization

D2 results produce grouped bar charts (speedup or throughput) rather than line charts:
5. Phased Implementation Plan

Phase 1: Foundation (current PR) — Status: complete
Phase 2: Model-config sweep (D2) — Status: complete
Phase 3: Rollout and visualization — Status: in progress

Phase 3 Kernel Rollout Tracking

Already refactored (D1 + D2):
Norm-like kernels (input: BT × hidden_size):
Loss kernels (input involves vocab_size or similar):
RLHF/alignment loss kernels:
Positional encoding kernels:
Activation / misc kernels:
Attention kernels:
Other:
6. Directory Structure |
|
|
||
| This module re-computes forward in the backward, so forward occurs twice per iteration. | ||
| """ | ||
|
|
There was a problem hiding this comment.
maybe we could keep these comments
| dtype: torch.dtype, | ||
| device: str, | ||
| ): | ||
| def __init__(self, mhc_cls, *, hidden_size, hc, num_heads, intermediate_mult, tmax, dtype, device): |
| tmax: int, | ||
| dtype: torch.dtype, | ||
| device: str, | ||
| self, mhc_cls, *, vocab_size, hidden_size, hc, num_layers, num_heads, intermediate_mult, tmax, dtype, device |
| tmax: int, | ||
| dtype: torch.dtype, | ||
| ): | ||
| def _build_model(provider, *, hidden_size, hc, num_layers, num_heads, intermediate_mult, vocab_size, tmax, dtype): |
| Uses the DeepSpeed TiledMLP algorithm for memory-efficient MLP computation. | ||
| """ | ||
|
|
||
| def __init__(self, config, num_shards=None): |
| } | ||
| ], | ||
| "overwrite": args.overwrite, | ||
| } |
There was a problem hiding this comment.
We have built a general class BenchMiniMHCLM to test in this benchmark
| bias=bias, | ||
| dtype=dtype, | ||
| device=device, | ||
| ) |
| groups=groups, | ||
| bias=bias, | ||
| dtype=dtype, | ||
| device=device, |
| "bias": True, | ||
| "dtype": torch.bfloat16, | ||
| }, | ||
| ], |
There was a problem hiding this comment.
we have dropped too many extra configs here
| extra_benchmark_configs=[ | ||
| {"M": 2048, "dtype": torch.float32}, | ||
| {"M": 2048, "dtype": torch.bfloat16}, | ||
| ], |
| if args.sweep_mode == "model_config": | ||
| all_model_configs = list(MODEL_REGISTRY.values()) | ||
| T = 512 | ||
| BT = 2048 |
There was a problem hiding this comment.
BT is too small compared to the current one
| {"B": 32, "T": 512, "D": 768, "dtype": torch.float32}, | ||
| # Llama | ||
| {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32}, | ||
| ], |
There was a problem hiding this comment.
here we already have a bert-like model and a llama-like model
| else: | ||
| model = get_benchmark_model_config(args.model) | ||
| T = 512 | ||
| probe_bt = 2048 |
| torch.randn_like(q, device=device, dtype=dtype), | ||
| torch.randn_like(k, device=device), | ||
| ) | ||
| dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device) |
| torch.randn_like(q, device=device, dtype=dtype), | ||
| torch.randn_like(k, device=device, dtype=dtype), | ||
| ) | ||
| dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device, dtype=dtype) |
| rep=400, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) |
| "x_name": "T", | ||
| "x_label": "sequence length", | ||
| "x_values": [2**i for i in range(10, int(math.log2(max(1024, config.seq_len))) + 1)], | ||
| "kernel_providers": ["liger", "huggingface"], |
| ) | ||
| q = torch.randn((1, seq_len, num_q_heads, head_dim), device=device, requires_grad=True, dtype=dtype) | ||
| k = torch.randn((1, seq_len, num_kv_heads, head_dim), device=device, requires_grad=True, dtype=dtype) | ||
| dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(k, device=device) |
| rep=400, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd_fn, grad_to_none=[q, k], rep=400, quantiles=QUANTILES) |
| ignore_index: int = -100, | ||
| beta: float = 0.1, | ||
| ): | ||
| def __init__(self, H, V, dtype, use_bias=False, use_ref_bias=False, ignore_index=-100, beta=0.1): |
| beta=beta, | ||
| use_ref_model=True, | ||
| ).get_batch_loss_metrics | ||
| self.KTO_loss = HFKTOLoss(ignore_index=ignore_index, beta=beta, use_ref_model=True).get_batch_loss_metrics |
| ignore_index: int = -100, | ||
| beta: float = 0.1, | ||
| ): | ||
| def __init__(self, H, V, dtype, use_bias=False, use_ref_bias=False, ignore_index=-100, beta=0.1): |
| rep=100, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) |
| rep=100, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) |
| rep=100, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) |
| rep=100, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) |
| rep=100, | ||
| quantiles=QUANTILES, | ||
| ) | ||
| ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) |
Hardware Type: Atlas 800I A2
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence