Commit 0169c4d
Claude skill for Triton kernel development (liger-kernel-dev) (#1170)
## Summary

This PR adds a **Claude Code skill** (`liger-kernel-dev`) that automates the full lifecycle of Triton kernel development for Liger Kernel. Instead of manually writing 8+ files per kernel (ops, transformers wrapper, functional API, exports, tests, benchmarks), a contributor can now describe a PyTorch operation and the skill handles analysis, code generation, and validation through a 3-stage agentic pipeline.

### How it works

1. **Analyze** — An Analyzer agent reads the PyTorch operation (from a local file, URL, code snippet, natural-language description, or model component reference), writes a standalone PyTorch reference implementation, and produces a structured kernel profile classifying the operation into one of three complexity tiers.
2. **Generate** — A Generator agent takes the profile and creates/modifies up to 8 files: Triton forward+backward kernels with `torch.autograd.Function`, an `nn.Module` wrapper, a functional API, `__init__.py` exports, unit tests (parametrized over shapes and dtypes), and benchmark scripts.
3. **Validate** — A Validator agent runs `make checkstyle`, unit tests (a hard gate with 3 retries; it stops on persistent failure), and benchmarks (speed + memory for fwd/bwd/full), and generates plots. Optionally runs `ncu` profiling.

### Key design decisions

- **Progressive disclosure**: SKILL.md is a concise entrypoint (~63 lines); detailed instructions live in separate agent files and templates that are loaded only when needed.
- **Three complexity tiers**: Templates cover element-wise ops (SwiGLU-like), reduction ops (RMSNorm-like), and fused/complex ops (CrossEntropy-like), with tier-specific patterns.
- **Correctness is a hard gate**: If tests fail after 3 retries, the pipeline stops and reports to the user rather than pushing forward with broken code.
- **Single-dtype benchmarks**: Benchmarks use `model.dtype` (typically bfloat16); multi-dtype coverage is handled by unit tests.
- **Supports both create and modify modes**: Detects intent automatically — creating new kernels goes through the full pipeline; modifying existing kernels skips analysis.

### Skill file structure

```
.claude/skills/liger-kernel-dev/
├── SKILL.md                     # Main entry point (63 lines)
├── analyzer.md                  # Stage 1: understand operation, produce profile
├── generator.md                 # Stage 2: generate all 8 files
├── validator.md                 # Stage 3: checkstyle → tests → benchmarks → plots
├── kernel-profile-format.md     # Kernel profile schema + naming conventions
├── templates/
│   ├── ops-kernel.md            # Triton kernel patterns by tier
│   ├── module-wrapper.md        # nn.Module wrapper pattern
│   ├── functional-api.md        # functional.py modification pattern
│   ├── unit-test.md             # Test file pattern with testing rules
│   └── benchmark.md             # Benchmark script pattern
└── examples/
    ├── swiglu-profile.md        # Tier 1 (element-wise) reference
    ├── rms-norm-profile.md      # Tier 2 (reduction) reference
    └── cross-entropy-profile.md # Tier 3 (fused/complex) reference
```

## Testing Done

- Tested end-to-end on **ReLU Squared** (Tier 1, an element-wise activation used in Nemotron models) → #1171
- Skill generated all 8 files: Triton kernel, module wrapper, functional API, exports, 18 unit tests, and benchmarks
- All tests pass on H100 (float32 + bfloat16)
- Benchmarks show **1.9-3.3x speedup** and **37.5% memory savings** vs PyTorch
- Full PR including benchmark plots was generated end-to-end by the skill
- Iteratively refined the skill based on issues found during testing (single-dtype benchmarks, template accuracy)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
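The retry-gated control flow of the three stages can be sketched in Python (hypothetical names throughout; the actual skill drives Claude Code agents from markdown instructions rather than running Python):

```python
# Hypothetical sketch of the 3-stage pipeline with its test-retry hard gate.
from dataclasses import dataclass

MAX_TEST_RETRIES = 3  # correctness is a hard gate: stop after 3 failed attempts

@dataclass
class KernelProfile:
    operation_name: str
    tier: int  # 1 = element-wise, 2 = reduction, 3 = fused/complex

def run_pipeline(operation_description, analyze, generate, validate):
    profile = analyze(operation_description)        # Stage 1: Analyze
    files = generate(profile)                       # Stage 2: Generate
    for attempt in range(1, MAX_TEST_RETRIES + 1):  # Stage 3: Validate
        ok, report = validate(files)
        if ok:
            return files, report
        # retry: regenerate with the failure report as feedback
        files = generate(profile, feedback=report)
    raise RuntimeError("tests still failing after retries; stopping")
```

In the real skill each callable is an agent with its own instruction file, and a human checkpoint sits between stages.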
1 parent 07dd9be commit 0169c4d

File tree: 13 files changed, +1094 −0 lines
`.claude/skills/liger-kernel-dev/SKILL.md`: 63 additions & 0 deletions
---
name: liger-kernel-dev
description: "Develops production-ready Triton kernels for Liger Kernel. Creates new kernels from PyTorch operations (local files, URLs, code snippets, or natural language) with ops, module wrappers, functional APIs, unit tests, benchmarks, and plots. Also modifies existing Liger kernels. Use when adding a new Triton kernel, converting a PyTorch operation to Triton, or updating an existing Liger kernel."
---

# Liger Kernel Dev

Develops Triton kernels for Liger Kernel through a 3-stage pipeline with human review between stages. Supports creating new kernels and modifying existing ones. NVIDIA GPUs only.

## Mode Detection

- **Create mode**: User asks to create/add/generate/write/build a new kernel → full pipeline
- **Modify mode**: User asks to update/fix/change/extend an existing kernel → skip Analyze, modify files, then Validate

## Pipeline (Create Mode)

### Stage 1: Analyze

Spawn an **Analyzer** agent (read [analyzer.md](analyzer.md)).

Accepts any input: local file, URL, code snippet, natural language description, or model component reference. Produces a standalone PyTorch reference implementation and a kernel profile.

**Human checkpoint:** Present PyTorch reference + kernel profile. Confirm before proceeding.

### Stage 2: Generate

Spawn a **Generator** agent (read [generator.md](generator.md)).

Generates/modifies up to 8 files:

1. `src/liger_kernel/ops/{kernel}.py` — NEW Triton kernels + autograd Function
2. `src/liger_kernel/transformers/{kernel}.py` — NEW nn.Module wrapper
3. `src/liger_kernel/transformers/functional.py` — MODIFY: add functional API
4. `src/liger_kernel/ops/__init__.py` — MODIFY: export Function class
5. `src/liger_kernel/transformers/__init__.py` — MODIFY: export Module + `__all__`
6. `test/transformers/test_{kernel}.py` — NEW unit tests
7. `benchmark/scripts/benchmark_{kernel}.py` — NEW benchmark script
8. `benchmark/data/all_benchmark_data.csv` — MODIFY (after benchmarks run)

**Human checkpoint:** Present changes for review.

### Stage 3: Validate

Spawn a **Validator** agent (read [validator.md](validator.md)).

Runs checkstyle, unit tests (hard gate — stops on persistent failure), benchmarks, and generates plots. Optionally runs ncu profiling.

**Human checkpoint:** Report final results with benchmark numbers and plots.

## Pipeline (Modify Mode)

1. Read existing kernel files to understand the current implementation
2. Understand the requested modification
3. Make targeted changes (Generator handles this)
4. Run the full Validate stage (same as create mode)

## Reference Files

- [kernel-profile-format.md](kernel-profile-format.md) — Kernel profile schema and field descriptions
- [examples/swiglu-profile.md](examples/swiglu-profile.md) — Tier 1 (element-wise) reference
- [examples/rms-norm-profile.md](examples/rms-norm-profile.md) — Tier 2 (reduction) reference
- [examples/cross-entropy-profile.md](examples/cross-entropy-profile.md) — Tier 3 (fused/complex) reference
- Templates in [templates/](templates/) — Code generation patterns for each file type
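The wiring between the generated ops file, module wrapper, and functional API follows PyTorch's standard custom-autograd pattern. A minimal stand-in (plain PyTorch in place of a Triton launch; illustrative names, not Liger's actual classes):

```python
import torch

class MySquaredFunction(torch.autograd.Function):
    """Stand-in for the Triton-backed Function in src/liger_kernel/ops/."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x  # a real kernel would launch a Triton grid here

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x  # d(x^2)/dx = 2x

class MySquared(torch.nn.Module):
    """Stand-in for the nn.Module wrapper in src/liger_kernel/transformers/."""
    def forward(self, x):
        return MySquaredFunction.apply(x)

def my_squared(x):
    """Stand-in for the functional API entry in transformers/functional.py."""
    return MySquaredFunction.apply(x)
```

The `__init__.py` edits then just re-export `MySquaredFunction` and `MySquared` so both spellings are importable.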
`.claude/skills/liger-kernel-dev/analyzer.md`: 62 additions & 0 deletions
# Analyzer Agent

Understands a PyTorch operation from any input form and produces a standalone PyTorch reference implementation + kernel profile.

## Input Handling

The user may provide the operation in any form:

1. **Local file path** → Read the file directly
2. **URL** (GitHub, HuggingFace, etc.) → Fetch via WebFetch tool
3. **Code snippet** → Pasted in the conversation
4. **Natural language** → Mathematical description (e.g., "element-wise SiLU(x) * y")
5. **Model component** → e.g., "the MLP in Phi-4" — locate in transformers source and extract

## Steps

### 1. Understand the Operation

- Read/fetch the source code from whatever input was provided
- Identify the mathematical operation (forward pass)
- Derive the backward pass (gradient computation)
- Identify all inputs, outputs, and their expected shapes/dtypes
- Note any precision-sensitive operations that need float32 upcasting (sigmoid, rsqrt, exp, log, tanh)

### 2. Write PyTorch Reference

Create a standalone implementation that:

- Depends only on `torch` (no external libraries)
- Implements both forward and backward behavior (either as an `nn.Module` or a plain function)
- Will serve as the correctness baseline for testing
- Is clean, readable, and well-named

### 3. Classify Into Tier

Read [kernel-profile-format.md](kernel-profile-format.md) for the full schema.

**Tier 1 — Element-wise**: No reductions across dimensions. One row per program. Examples: SwiGLU, GeGLU, DyT.
- Read reference: `src/liger_kernel/ops/swiglu.py`

**Tier 2 — Reduction**: Cross-column reductions (tl.sum, tl.max). May need to save intermediate state for backward. May need SM-based parallelism for weight gradient reduction. Examples: RMSNorm, LayerNorm, Softmax, Sparsemax.
- Read reference: `src/liger_kernel/ops/rms_norm.py`

**Tier 3 — Fused/Complex**: Multi-pass algorithms, gradient-in-forward tricks, multiple outputs. Examples: CrossEntropy, FusedLinearCrossEntropy.
- Read reference: `src/liger_kernel/ops/cross_entropy.py`

Also read the closest example profile:

- Tier 1 → [examples/swiglu-profile.md](examples/swiglu-profile.md)
- Tier 2 → [examples/rms-norm-profile.md](examples/rms-norm-profile.md)
- Tier 3 → [examples/cross-entropy-profile.md](examples/cross-entropy-profile.md)

### 4. Produce Kernel Profile

Fill in all fields from [kernel-profile-format.md](kernel-profile-format.md).

### 5. Present to User

Show:

1. The PyTorch reference implementation (full code)
2. The kernel profile (all fields)
3. Which existing kernel is closest (for the Generator to use as reference)

Wait for user confirmation before proceeding to Stage 2.
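As a concrete illustration of step 2, a standalone reference for SwiGLU (the Tier 1 example above) could be as small as this; `swiglu_reference` is an illustrative name, not necessarily what the Analyzer emits:

```python
import torch

def swiglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Standalone SwiGLU reference: silu(a) * b, where silu(x) = x * sigmoid(x).
    Depends only on torch; autograd supplies the backward, so this doubles as
    the correctness baseline for the generated kernel's unit tests."""
    return torch.nn.functional.silu(a) * b
```

Tests for the Triton kernel then compare outputs and gradients against this function.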
`.claude/skills/liger-kernel-dev/examples/cross-entropy-profile.md`: 67 additions & 0 deletions
# Kernel Profile: CrossEntropy (Tier 3 — Fused/Complex)

## Identity
- operation_name: cross_entropy
- function_class_name: LigerCrossEntropyFunction
- module_class_name: LigerCrossEntropyLoss
- functional_name: liger_cross_entropy

## Classification
- tier: 3
- tier_description: fused/complex
- closest_existing_kernel: fused_linear_cross_entropy (extends this with linear layer fusion)

## Forward Pass
- forward_inputs:
  - _input: shape (B*T, V), logits
  - target: shape (B*T,), label indices
  - weight: shape (V,) or None, class weights
  - ignore_index: int, label to ignore
  - label_smoothing: float
  - reduction: str, "mean" | "sum" | "none"
  - softcap: float or None
- forward_outputs:
  - loss: scalar, or shape (B*T,) if reduction="none"
  - z_loss: optional auxiliary loss
  - token_accuracy: optional accuracy metric
  - predicted_tokens: optional argmax tokens
- forward_computation: Two-pass online softmax + cross entropy loss with optional smoothing and softcapping
- precision_sensitive_ops: [exp, log]

## Backward Pass
- backward_saved_tensors: [_input] — gradient is computed during forward and stored in-place
- backward_recompute: none (gradient-in-forward trick)
- gradient_formulas:
  - d_input: already computed in the forward pass and stored in the _input tensor. Backward just scales by grad_output.

## Tiling Strategy
- grid_dimensions: 1D
- grid_description: one program per row (one row = one token's logits over the vocab)
- block_size_source: custom — iterates over the vocab in chunks of BLOCK_SIZE
- needs_sm_parallelism: false

## Module Parameters
- module_init_params:
  - weight: optional class weights
  - ignore_index: int = -100
  - lse_square_scale: float = 0.0
  - label_smoothing: float = 0.0
  - reduction: str = "mean"
  - softcap: float or None = None
  - return_z_loss: bool = False
- learnable_params: none

## Benchmarking
- benchmark_variable: vocab_size (V)
- benchmark_x_label: "V"
- benchmark_x_values_suggestion: [4096, 8192, 16384, 32768, 65536, 131072]
- benchmark_providers: ["liger", "huggingface"]
- benchmark_fixed_config: {B: 8, T: 512, dtype: torch.bfloat16}

## Key Patterns

- **Online softmax**: Two-pass algorithm. Pass 1: compute the running max and logsumexp. Pass 2: compute softmax and gradients. Avoids materializing the full softmax vector.
- **Gradient-in-forward trick**: The forward kernel computes the gradient and stores it directly in `_input` (overwriting the logits). The backward pass just retrieves this and multiplies by `grad_output`. This saves having to recompute softmax in backward.
- **Constexpr flags for code elimination**: `HAS_WEIGHT`, `HAS_SOFTCAPPING`, `HAS_GRADIENTS`, `RETURN_Z_LOSS`, `RETURN_TOKEN_ACCURACY` — each is `tl.constexpr`, so the compiler removes unused code paths entirely.
- **Chunked vocab iteration**: The kernel loops over the vocabulary in `BLOCK_SIZE` chunks: `for i in range(0, n_cols, BLOCK_SIZE)`. This handles arbitrarily large vocabularies without requiring BLOCK_SIZE >= n_cols.
- **Multiple loss components**: Combines the original CE loss, label smoothing loss, and z-loss (for training stability). Each component contributes to both the loss and the gradient.
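The gradient-in-forward trick can be illustrated in plain PyTorch (a sketch with mean reduction only and none of the profile's options; the real kernel is a chunked Triton implementation that stores the gradient in `_input` in-place):

```python
import torch

class GradInForwardCE(torch.autograd.Function):
    """Sketch of the gradient-in-forward trick: forward computes
    (softmax - one_hot) / n and stashes it, so backward only scales by
    grad_output instead of recomputing softmax."""

    @staticmethod
    def forward(ctx, logits, target):
        n = logits.shape[0]
        rows = torch.arange(n)
        probs = torch.softmax(logits, dim=-1)
        loss = -torch.log(probs[rows, target]).mean()
        grad = probs                 # reuse the softmax buffer, as the kernel
        grad[rows, target] -= 1.0    # overwrites the logits in-place
        ctx.save_for_backward(grad / n)  # /n for the "mean" reduction
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        return grad_output * grad, None  # no gradient for integer targets
```

The payoff is that backward does no softmax work at all, at the cost of overwriting the logits.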
`.claude/skills/liger-kernel-dev/examples/rms-norm-profile.md`: 64 additions & 0 deletions
# Kernel Profile: RMSNorm (Tier 2 — Reduction)

## Identity
- operation_name: rms_norm
- function_class_name: LigerRMSNormFunction
- module_class_name: LigerRMSNorm
- functional_name: liger_rms_norm

## Classification
- tier: 2
- tier_description: reduction
- closest_existing_kernel: layer_norm (similar pattern with different normalization)

## Forward Pass
- forward_inputs:
  - X: shape (B, T, H), input tensor
  - W: shape (H,), weight tensor
  - eps: float, epsilon for numerical stability
  - offset: float, weight offset (0.0 for Llama, 1.0 for Gemma)
  - casting_mode: str, "llama" | "gemma" | "none"
- forward_outputs:
  - Y: shape (B, T, H), normalized output
- forward_computation: Y = (X / RMS(X)) * (W + offset), RMS = sqrt(mean(X^2) + eps)
- precision_sensitive_ops: [rsqrt]

## Backward Pass
- backward_saved_tensors: [X, W, RSTD] — RSTD cached from forward to avoid recomputation
- backward_recompute: none (RSTD is expensive to recompute)
- gradient_formulas:
  - dX: rstd * (dY*(W+offset) - (1/N) * rstd^2 * dot(dY*(W+offset), X) * X)
  - dW: sum over (B,T) of dY * (X * rstd)

## Tiling Strategy
- grid_dimensions: 1D
- grid_description: forward uses one program per row `(n_rows,)`; backward uses SM-based partitioning `(sm_count,)` with `rows_per_program`
- block_size_source: calculate_settings(n_cols)
- needs_sm_parallelism: true (for the dW reduction — each SM accumulates a partial dW, then summed)

## Module Parameters
- module_init_params:
  - hidden_size: int
  - eps: float = 1e-6
  - offset: float = 0.0
  - casting_mode: str = "llama"
  - init_fn: str = "ones"
  - in_place: bool = True
  - elementwise_affine: bool = True
- learnable_params:
  - weight: shape (hidden_size,), init ones or zeros

## Benchmarking
- benchmark_variable: hidden_size
- benchmark_x_label: "hidden_size"
- benchmark_x_values_suggestion: [1024, 2048, 4096, 8192, 16384]
- benchmark_providers: ["liger", "huggingface"]
- benchmark_fixed_config: {M: 4096, eps: 1e-6, dtype: torch.float32}

## Key Patterns

- **RSTD caching**: Forward computes and stores `rstd = rsqrt(mean(X^2) + eps)` — 1 value per row, tiny memory cost, saves 4 ops in backward
- **SM-based backward**: The weight gradient needs a reduction across all rows. Each SM processes `rows_per_program` rows and accumulates into `_dW[sm_id, :]`. Final `dW = _dW.sum(dim=0)`
- **Casting modes as constexpr**: `casting_mode` is `tl.constexpr`, so the compiler eliminates dead branches
- **In-place backward option**: `in_place=True` writes dX into the dY tensor to save memory. Set `False` when dY is needed elsewhere (e.g., the Gemma2 residual)
- **Two kernel variants**: Row-wise for `BLOCK_SIZE > 256 or n_rows < 4096*8`, block-wise otherwise (processes `BLOCK_ROW=16` rows per program for better GPU utilization)
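The dX formula in the profile can be checked numerically against autograd. A plain-PyTorch sketch (`rms_norm_backward_dx` is an illustrative name, not a Liger function):

```python
import torch

def rms_norm_backward_dx(dY, X, W, eps=1e-6, offset=0.0):
    """Evaluates the profile's dX formula:
    dX = rstd * (dY*(W+offset) - (1/N) * rstd^2 * dot(dY*(W+offset), X) * X)
    where rstd = rsqrt(mean(X^2) + eps) and N is the hidden size."""
    n = X.shape[-1]
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    g = dY * (W + offset)  # gradient w.r.t. the normalized activations
    return rstd * (g - (rstd.pow(2) / n) * (g * X).sum(dim=-1, keepdim=True) * X)
```

The second term is where the reduction lives: it couples every element of a row to every other element through `mean(X^2)`, which is why RMSNorm is Tier 2 rather than element-wise.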
`.claude/skills/liger-kernel-dev/examples/swiglu-profile.md`: 56 additions & 0 deletions
# Kernel Profile: SwiGLU (Tier 1 — Element-wise)

## Identity
- operation_name: swiglu
- function_class_name: LigerSiLUMulFunction
- module_class_name: LigerSwiGLUMLP
- functional_name: liger_swiglu

## Classification
- tier: 1
- tier_description: element-wise
- closest_existing_kernel: geglu (same structure, different activation)

## Forward Pass
- forward_inputs:
  - a: shape (B, T, H), gate projection output
  - b: shape (B, T, H), up projection output
- forward_outputs:
  - c: shape (B, T, H), silu(a) * b
- forward_computation: c = silu(a) * b, where silu(x) = x * sigmoid(x)
- precision_sensitive_ops: [sigmoid]

## Backward Pass
- backward_saved_tensors: [a, b] (reshaped to 2D in the forward wrapper)
- backward_recompute: recompute silu(a) and sigmoid(a) in backward
- gradient_formulas:
  - da: dc * (silu(a) * (1 - sigmoid(a)) + sigmoid(a)) * b
  - db: dc * silu(a)

## Tiling Strategy
- grid_dimensions: 1D
- grid_description: one program per row, `(n_rows,)`
- block_size_source: calculate_settings(n_cols)
- needs_sm_parallelism: false

## Module Parameters
- module_init_params:
  - config: HuggingFace model config object
- learnable_params:
  - gate_proj: Linear(hidden_size, intermediate_size, bias=False)
  - up_proj: Linear(hidden_size, intermediate_size, bias=False)
  - down_proj: Linear(intermediate_size, hidden_size, bias=False)

## Benchmarking
- benchmark_variable: hidden_size
- benchmark_x_label: "hidden_size"
- benchmark_x_values_suggestion: [1024, 2048, 4096, 8192, 16384]
- benchmark_providers: ["liger", "torch", "torch_compile"]
- benchmark_fixed_config: {BT: 4096, dtype: torch.bfloat16}

## Key Patterns

- **Recomputation over saving**: Forward saves `a, b` but backward recomputes `sigmoid(a)` and `silu(a)` — saves memory, and sigmoid is cheap
- **In-place backward**: Writes gradients directly to `a_ptr` and `b_ptr` (the saved tensors) — saves allocation
- **Float32 for sigmoid**: `a_row` is cast to `tl.float32` before `tl.sigmoid`, and the result is cast back via `.cast(b_row.dtype)`
- **No intermediate allocations**: The forward kernel writes directly to output `c`; the backward kernel overwrites the saved `a, b`
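The da/db formulas and the recompute-in-backward pattern can be verified against autograd. A plain-PyTorch sketch (`swiglu_backward` is an illustrative name; the float32 upcast mirrors the kernel's precision handling):

```python
import torch

def swiglu_backward(dc, a, b):
    """Backward per the profile: recompute sigmoid(a) and silu(a) rather than
    saving them, upcasting to float32 for the precision-sensitive sigmoid.
    da = dc * (silu(a) * (1 - sigmoid(a)) + sigmoid(a)) * b
    db = dc * silu(a)"""
    sig = torch.sigmoid(a.float())  # float32 upcast, as the kernel does
    silu = a.float() * sig
    da = dc * (silu * (1 - sig) + sig).to(a.dtype) * b
    db = dc * silu.to(a.dtype)
    return da, db
```

Only `a` and `b` need to be saved for this backward; `sigmoid(a)` costs one transcendental per element to recompute, which is cheaper than the memory traffic of saving it.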
