Commit f1b7e47
Claude skill for automatically monkey patching HF transformers models (liger-autopatch) (#1167)

## Summary

This PR adds a **Claude Code skill** (`liger-autopatch`) that automates adding Liger Kernel support for new HuggingFace Transformers models. Instead of manually writing 10+ files per model, a contributor can now say *"Add Liger Kernel support for nemotron"* and the skill handles analysis, code generation, testing, and validation through a 3-stage agentic pipeline.

### How it works

1. **Analyze** — A Model Analyzer agent reads the HF `modeling_*.py` source and produces a structured model profile, resolving 12 architectural decisions (norm type, MLP activation, MoE structure, RoPE variant, casting mode, etc.)
2. **Generate** — A Code Generator agent takes the profile and creates/modifies up to 13 files: `lce_forward`, monkey-patch function, `__init__.py` exports, output classes, instance patching tests, convergence tests (bf16 + fp32, FLCE + with_logits + multimodal), revert utilities, and README entry
3. **Validate** — A Validator agent runs the instance patching test, all applicable convergence tests, and `make checkstyle`, with up to 3 retry attempts per step

### Key design decisions

- **Progressive disclosure**: SKILL.md is a concise entrypoint (~55 lines); detailed instructions live in separate agent files and templates
- **Version compatibility**: Uses `try/except ImportError` for availability guards — CI on transformers 4.52.0 naturally skips newer models, no explicit version strings needed
- **Placement enforcement**: Templates explicitly instruct alphabetical ordering for all insertions into existing files (imports, availability checkers, test functions, MINI_MODEL_SETUPS entries)
- **Full convergence coverage**: Generates test entries in all 6 convergence test files (bf16/fp32 x FLCE/with_logits/multimodal) as appropriate for the model type

## Testing Done

- Tested end-to-end on **nemotron** (dense, RMSNorm + SwiGLU + RoPE) → #1165
- Tested on **ministral** (dense, Mistral-family) → #1166
- Both PRs generated fully working code, including the monkey-patch, lce_forward, instance patching tests, and convergence tests across bf16/fp32 FLCE and with_logits variants
- Iteratively refined the skill based on issues found during testing (alphabetical placement, convergence test coverage, checkstyle integration, version compatibility)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
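The version-compatibility guard described above can be sketched as a plain `try/except ImportError` (hedged sketch; `nemotron` is just the example model from this PR, and the flag name is illustrative):

```python
# Availability guard sketch: on a transformers release that lacks the module,
# the flag is simply False and guarded tests are skipped rather than failing.
try:
    from transformers.models.nemotron import modeling_nemotron  # noqa: F401

    NEMOTRON_AVAILABLE = True
except ImportError:
    # e.g. transformers 4.52.0 in CI predates this model
    NEMOTRON_AVAILABLE = False
```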
1 parent 39a1f45 commit f1b7e47

File tree

12 files changed (+934, -0 lines)

Lines changed: 55 additions & 0 deletions
---
name: liger-autopatch
description: "Adds Liger Kernel support for a new HuggingFace Transformers model. Generates lce_forward, monkey-patch function, tests, and README entry. Use when adding a new model to Liger Kernel, when a user asks to patch an unsupported model, or when extending MODEL_TYPE_TO_APPLY_LIGER_FN."
---

# Liger Auto-Patch

Adds Liger Kernel optimization support for a new HuggingFace model through a 3-stage pipeline with human review between stages.

## Pipeline

### Stage 1: Analyze

Spawn a **Model Analyzer** agent (read [model-analyzer.md](model-analyzer.md)).

The agent reads the HF `modeling_*.py` source and produces a **model profile** answering 12 architectural questions from [decision-matrix.md](decision-matrix.md).

**Human checkpoint:** Present the profile. Confirm before proceeding.

### Stage 2: Generate

Spawn a **Code Generator** agent (read [code-generator.md](code-generator.md)).

Generates/modifies up to 13 files:

1. `src/liger_kernel/transformers/model/{model}.py` — NEW lce_forward
2. `src/liger_kernel/transformers/monkey_patch.py` — MODIFY
3. `src/liger_kernel/transformers/__init__.py` — MODIFY
4. `src/liger_kernel/transformers/model/output_classes.py` — MODIFY if needed
5. `test/transformers/test_monkey_patch.py` — MODIFY
6. `test/convergence/bf16/test_mini_models.py` — MODIFY (FLCE path)
7. `test/convergence/bf16/test_mini_models_with_logits.py` — MODIFY (non-FLCE path)
8. `test/convergence/fp32/test_mini_models.py` — MODIFY (FLCE path)
9. `test/convergence/fp32/test_mini_models_with_logits.py` — MODIFY (non-FLCE path)
10. `test/convergence/bf16/test_mini_models_multimodal.py` — MODIFY if VL model
11. `test/convergence/fp32/test_mini_models_multimodal.py` — MODIFY if VL model
12. `test/utils.py` — MODIFY
13. `README.md` — MODIFY

**Human checkpoint:** Present changes for review.

### Stage 3: Validate

Spawn a **Validator** agent (read [validator.md](validator.md)).

Runs the instance patching test, convergence tests, and lint check. Retries up to 3 times on failure.

**Human checkpoint:** Report final test results.

## Reference Files

- [decision-matrix.md](decision-matrix.md) — 12 architectural decisions to resolve per model
- [examples/llama-profile.md](examples/llama-profile.md) — Reference profile for a standard dense model
- [examples/gemma-profile.md](examples/gemma-profile.md) — Reference profile showing the GeGLU + offset variant
- Templates in [templates/](templates/) — Code generation patterns for each file type
Lines changed: 86 additions & 0 deletions
# Code Generator Agent

Takes a confirmed model profile and generates all files to add Liger Kernel support.

## Pre-Requisites

Before generating, read the reference implementation closest to this model:

- Dense → `src/liger_kernel/transformers/model/llama.py`
- MoE → `src/liger_kernel/transformers/model/mixtral.py`
- Vision-Language → `src/liger_kernel/transformers/model/qwen2_vl.py`
- Gemma-family → `src/liger_kernel/transformers/model/gemma.py`

Also read the corresponding patching function in `monkey_patch.py` and the templates in [templates/](templates/).

## Files to Generate

### 1. `src/liger_kernel/transformers/model/{model_type}.py` (NEW)

The `lce_forward` function. See [templates/lce-forward-dense.md](templates/lce-forward-dense.md) or [templates/lce-forward-moe.md](templates/lce-forward-moe.md).

Key rules:

- Match the exact forward signature from HF's `ForCausalLM.forward`
- Use `lce_maybe_trainable_lm_head` from `llama.py` (shared PEFT/FSDP utility)
- If the model needs custom loss args (e.g., softcapping), write a local helper instead

### 2. `src/liger_kernel/transformers/monkey_patch.py` (MODIFY)

Three changes — see [templates/monkey-patch-fn.md](templates/monkey-patch-fn.md):

**A.** Add the lce_forward import (~line 18-28):

```python
from liger_kernel.transformers.model.{model_type} import lce_forward as {model_type}_lce_forward
```

**B.** Add an `apply_liger_kernel_to_{model_type}` function with both class-level and instance-level patching paths.

**C.** Add an entry to the `MODEL_TYPE_TO_APPLY_LIGER_FN` dict (~line 3067).
### 3. `src/liger_kernel/transformers/__init__.py` (MODIFY)

Add the function in three locations (maintain alphabetical order):

- `TYPE_CHECKING` block
- `__getattr__` monkey_patch_symbols set
- `__all__` list extension

### 4. `src/liger_kernel/transformers/model/output_classes.py` (MODIFY if needed)

Only for models needing custom output (MoE with `aux_loss`, VL with `rope_deltas`). Follow the existing guarded-import pattern in the file.

### 5. `test/transformers/test_monkey_patch.py` (MODIFY)

See [templates/test-instance-patch.md](templates/test-instance-patch.md). Add an availability checker plus a skipif-decorated test function using `inspect.getsource()` assertions.
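The idea behind the instance patching assertion can be sketched with toy stand-ins (the class and function names here are illustrative, not the real ones; identity checks replace `inspect.getsource()` so the sketch runs anywhere):

```python
import types

class DummyForCausalLM:
    def forward(self):
        return "original forward"

def lce_forward(self):
    return "liger lce forward"

model = DummyForCausalLM()
# Instance-level patching: bind the replacement to this object only,
# leaving the class (and other instances) untouched.
model.forward = types.MethodType(lce_forward, model)

# The real tests assert via inspect.getsource(model.forward); checking the
# bound function's identity is an equivalent, file-independent variant.
assert model.forward.__func__ is lce_forward
assert model.forward() == "liger lce forward"
```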
### 6. Convergence tests (MODIFY multiple files)

See [templates/test-convergence.md](templates/test-convergence.md). Every model needs entries in multiple convergence test files.

**All text models (dense + MoE)** — add to these 4 files:

- `test/convergence/bf16/test_mini_models.py` — FLCE path, bf16
- `test/convergence/bf16/test_mini_models_with_logits.py` — non-FLCE path (tests RMSNorm/SwiGLU/RoPE only), bf16
- `test/convergence/fp32/test_mini_models.py` — FLCE path, fp32
- `test/convergence/fp32/test_mini_models_with_logits.py` — non-FLCE path, fp32

**Vision-language models** — also add to these 2:

- `test/convergence/bf16/test_mini_models_multimodal.py`
- `test/convergence/fp32/test_mini_models_multimodal.py`

Each file needs: imports, an availability guard, a `MiniModelConfig` entry in the `MINI_MODEL_SETUPS` dict, and a `pytest.param` entry in the parametrize block. The `MiniModelConfig` entry is identical across all files for the same model. The `pytest.param` tolerances differ — use bf16 tolerances (looser) for bf16 files and fp32 tolerances (tighter) for fp32 files. Copy tolerance values from a similar existing model (e.g., Llama for dense, Mixtral for MoE).

### 7. `test/utils.py` (MODIFY)

Add a `revert_liger_kernel_to_{model_type}` function that reloads the modeling module.
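The reload-based revert pattern can be sketched like this (hypothetical helper name; the stdlib `json` module stands in for the HF modeling module so the sketch is self-contained):

```python
import importlib
import json
import types

def revert_liger_kernel_to_my_model(modeling_module: types.ModuleType) -> types.ModuleType:
    """Reload the module so any monkey-patched attributes are restored."""
    return importlib.reload(modeling_module)

# Simulate a monkey patch on the stand-in module, then revert it.
json.dumps = lambda *args, **kwargs: "patched"
assert json.dumps({}) == "patched"

reverted = revert_liger_kernel_to_my_model(json)
assert reverted.dumps({"a": 1}) == '{"a": 1}'  # original behavior restored
```

Reloading re-executes the module body, which is why it cleanly undoes attribute-level patches without tracking what was changed.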
### 8. `README.md` (MODIFY)

Add a row to the Patching table under "### Patching":

```
| {ModelName} | `liger_kernel.transformers.apply_liger_kernel_to_{model_type}` | {Supported Operations} |
```

## Code Style

- Line length 120, double quotes, single imports sorted with isort
- Follow exact patterns from existing code — do not innovate on style
- When modifying existing files, insert new entries in **alphabetical order** alongside similar existing entries. Never append to the end of a section — find the correct alphabetical position.
- After generating all files, run `make checkstyle` to verify formatting. If it fails, run `ruff check . --fix && ruff format .` to auto-fix, then verify with `make checkstyle` again.
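The alphabetical-placement rule amounts to computing an insertion index rather than appending; a minimal sketch with `bisect` (export names are illustrative):

```python
import bisect

# Existing alphabetically sorted section, e.g. part of an __all__ list.
exports = [
    "apply_liger_kernel_to_gemma",
    "apply_liger_kernel_to_llama",
    "apply_liger_kernel_to_qwen2",
]
new_export = "apply_liger_kernel_to_mixtral"

# Insert at the correct alphabetical position, never at the end.
exports.insert(bisect.bisect_left(exports, new_export), new_export)

assert exports.index("apply_liger_kernel_to_mixtral") == 2  # between llama and qwen2
```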
Lines changed: 130 additions & 0 deletions
# Decision Matrix

When analyzing a HuggingFace model for Liger Kernel support, you must resolve these 12 architectural decisions by reading the model's `modeling_*.py` source code.

## 1. Norm Type

**Question:** Does the model use RMSNorm, LayerNorm, or both?

**How to detect:**

- Search for `class *RMSNorm` in the modeling file → RMSNorm
- Search for `nn.LayerNorm` usage → LayerNorm
- Multimodal models often use both (RMSNorm for text, LayerNorm for vision)

**Liger mapping:** `LigerRMSNorm` or `LigerLayerNorm`

## 2. RMSNorm Casting Mode

**Question:** How does the model handle dtype casting during normalization?

**How to detect:** Read the RMSNorm forward method:

- Casts input to fp32, computes variance, casts back → `"gemma"`
- Computes variance in fp32 only (input stays in its original dtype) → `"llama"`
- No casting at all → `"none"`

**Default:** `"llama"` (most common)

## 3. RMSNorm Offset

**Question:** Does the weight have a +1.0 offset?

**How to detect:** In the RMSNorm forward, look for `(1 + self.weight)` or `self.weight + 1`:

- Present → `offset=1.0` (Gemma family)
- Absent → `offset=0.0` (most models)
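Decisions 2 and 3 both come down to how the reference RMSNorm computes its output; a dtype-free pure-Python sketch (lists stand in for tensors, and the fp32 casting decision is omitted) makes the offset visible:

```python
import math

def rms_norm(xs, weight, eps=1e-6, offset=0.0):
    # Reference RMSNorm on plain lists: scale by the reciprocal RMS,
    # then by (offset + weight). offset=1.0 reproduces Gemma's
    # `1 + self.weight` pattern; offset=0.0 is the Llama-style default.
    variance = sum(x * x for x in xs) / len(xs)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [(offset + w) * x * inv_rms for x, w in zip(xs, weight)]

# A zero-initialized weight with offset=1.0 behaves like a ones weight with
# offset=0.0 — which is why the offset must be detected, not guessed.
assert rms_norm([3.0, 4.0], [0.0, 0.0], offset=1.0) == rms_norm([3.0, 4.0], [1.0, 1.0], offset=0.0)
```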
## 4. RMSNorm In-Place

**Question:** Can the backward pass modify dY in-place?

**How to detect:** Check if the model has two sequential norm layers with a residual connection between them (like Gemma2's `pre_feedforward_layernorm` + `post_feedforward_layernorm`):

- Sequential norms with residual → `in_place=False`
- Otherwise → `in_place=True`

## 5. MLP Activation Type

**Question:** What activation function does the gated MLP use?

**How to detect:** Read the MLP class forward method:

- `silu` or `F.silu` → SwiGLU → `LigerSwiGLUMLP`
- `gelu`, `gelu_new`, or `gelu_fast` → GeGLU → `LigerGEGLUMLP`
- Phi3-style (single gate+up projection split) → `LigerPhi3SwiGLUMLP`

**Also check:** The config's `hidden_act` field.
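The gated-MLP shape the detector looks for can be sketched without tensors (scalars stand in for the three projections; helper names are illustrative):

```python
import math

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def gelu_tanh(x: float) -> float:
    # tanh approximation of GELU, as in hidden_act="gelu_pytorch_tanh"
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def gated_mlp(x, gate_w, up_w, down_w, act):
    # The pattern to spot in the MLP forward:
    #   down_proj(act(gate_proj(x)) * up_proj(x))
    # SwiGLU when act is silu, GeGLU when act is a gelu variant.
    return down_w * (act(gate_w * x) * (up_w * x))
```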
## 6. Dense vs MoE

**Question:** Is the model dense, MoE, or hybrid (some layers dense, some MoE)?

**How to detect:**

- Search for `Expert`, `MoE`, `SparseMoe`, `TopK` routing classes
- Check if decoder layers have a `block_sparse_moe` or `experts` attribute
- Hybrid: check for `is_moe_layer` or conditional MoE per layer

**Liger mapping:**

- Dense → standard patching
- MoE (transformers v5) → `LigerExperts`
- MoE (transformers v4) → `LigerBlockSparseTop2MLP`
- Qwen3-style MoE → `LigerQwen3MoeSwiGLUMLP`

## 7. Vision Components

**Question:** Does the model have a vision encoder?

**How to detect:**

- Check for `pixel_values` in the `forward` signature
- Look for a separate vision model class (e.g., `*VisionModel`)
- Check the config for `vision_config` or `text_config` sub-configs

**If yes:**

- Vision encoder norms are usually `nn.LayerNorm` → patch with `LigerLayerNorm`
- Text and vision must be patched separately

## 8. RoPE Variant

**Question:** What type of positional embedding does the model use?

**How to detect:** Search for the `apply_rotary_pos_emb` function:

- Standard (q, k, cos, sin) → `liger_rotary_pos_emb` (rope=True)
- Llama4-style → `liger_llama4_text_rotary_pos_emb`
- Qwen2VL MRoPE → `liger_multimodal_rotary_pos_emb`
- No rotary embedding or a custom variant → `rope=False`
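A dtype-free sketch of the standard rotate-half signature the detector matches (lists as single-head stand-ins; illustrative only, not the HF implementation):

```python
def rotate_half(x):
    # [x1, x2] -> [-x2, x1]: the rotate-half trick used by standard RoPE
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rotary_pos_emb(q, k, cos, sin):
    # The standard (q, k, cos, sin) signature: q*cos + rotate_half(q)*sin
    q_rot = [qi * c + ri * s for qi, ri, c, s in zip(q, rotate_half(q), cos, sin)]
    k_rot = [ki * c + ri * s for ki, ri, c, s in zip(k, rotate_half(k), cos, sin)]
    return q_rot, k_rot

# Zero angle (cos=1, sin=0) leaves q and k unchanged.
assert apply_rotary_pos_emb([1.0, 2.0], [3.0, 4.0], [1.0, 1.0], [0.0, 0.0]) == ([1.0, 2.0], [3.0, 4.0])
```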
## 9. Output Class
93+
94+
**Question:** What return type does the model's ForCausalLM.forward use?
95+
96+
**How to detect:** Read the return statement and type annotation:
97+
- Standard → `LigerCausalLMOutputWithPast`
98+
- MoE (has `aux_loss`) → `LigerMoeCausalLMOutputWithPast`
99+
- Custom VL output → create model-specific output class in `output_classes.py`
100+
101+
## 10. Hidden State Access
102+
103+
**Question:** How does the model access hidden states from base model output?
104+
105+
**How to detect:** In the ForCausalLM.forward, after calling `self.model(...)`:
106+
- `outputs[0]` → most models (Llama, Mistral, Gemma, etc.)
107+
- `outputs.last_hidden_state` → Phi3, Qwen3.5 MoE, some newer models
108+
109+
## 11. Logit Softcapping
110+
111+
**Question:** Does the model apply softcapping to logits before loss?
112+
113+
**How to detect:** Check config for `final_logit_softcapping`:
114+
- Present → pass `final_logit_softcapping=self.config.final_logit_softcapping` to `LigerForCausalLMLoss`
115+
- Absent → no softcapping (most models)
116+
- **VL models:** Config path may be `self.config.text_config.final_logit_softcapping` instead of `self.config.final_logit_softcapping`. Check whether the model uses a composite config with `text_config` sub-config.
117+
118+
**Models with softcapping:** Gemma2, Gemma3
119+
120+
## 12. Decoder Layer Norm Names
121+
122+
**Question:** What are the attribute names of norm layers in each decoder layer?
123+
124+
**How to detect:** Read the decoder layer class `__init__`:
125+
- Standard: `input_layernorm`, `post_attention_layernorm`
126+
- Gemma2 extra: `pre_feedforward_layernorm`, `post_feedforward_layernorm`
127+
- GLM4: `post_self_attn_layernorm`, `post_mlp_layernorm`
128+
- Some models: `q_norm`, `k_norm` on self_attn
129+
130+
Also check the final norm on the base model (usually `model.norm` or `model.final_layernorm`).
Lines changed: 80 additions & 0 deletions
# Model Profile: Gemma

This profile demonstrates the key differences from Llama: GeGLU activation, RMSNorm offset, and Gemma casting mode.

## Identity

- model_type: gemma
- causal_lm_class: GemmaForCausalLM
- base_model_class: GemmaModel
- base_model_prefix: "model"
- modeling_module: transformers.models.gemma.modeling_gemma
- config_module: transformers.models.gemma.configuration_gemma

## Normalization

- norm_class: GemmaRMSNorm
- norm_type: RMSNorm
- casting_mode: gemma (everything cast to fp32, then computed, then cast back)
- offset: 1.0 (weight uses the `1 + self.weight` pattern)
- in_place: true
- final_norm_attr: model.norm
- decoder_norm_attrs:
  - input_layernorm
  - post_attention_layernorm
- attn_norm_attrs: none

## MLP

- mlp_class: GemmaMLP
- activation: gelu (uses GELU activation, not SiLU)
- liger_mlp_class: LigerGEGLUMLP
- gate_proj_attr: gate_proj
- up_proj_attr: up_proj
- down_proj_attr: down_proj

## Structure

- type: dense
- moe_expert_class: n/a
- moe_router_class: n/a
- shared_expert: false

## Positional Embedding

- rope_type: standard
- rope_function: apply_rotary_pos_emb

## Output

- output_class: LigerCausalLMOutputWithPast
- hidden_state_access: outputs[0]
- has_logit_softcapping: false
- softcapping_config_attr: none

## Vision

- has_vision: false

## Forward Signature

Same as Llama — no extra parameters.

## Mini Model Config

```python
GemmaConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
    vocab_size=1024,
    rms_norm_eps=1e-6,
    hidden_activation="gelu_pytorch_tanh",
)
```

## Key Differences from Llama

1. **Activation**: Uses the `geglu` parameter (not `swiglu`) in the patch function
2. **RMSNorm**: Requires `offset=1.0` and `casting_mode="gemma"`
3. **MLP class**: `LigerGEGLUMLP` instead of `LigerSwiGLUMLP`
4. **Patching uses partial**: `_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)`
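The `partial` specialization in item 4 can be sketched with a toy stand-in for `_patch_rms_norm_module` (a dict records the settings instead of a real module being patched):

```python
from functools import partial

def _patch_rms_norm_module(module, casting_mode="llama", offset=0.0, in_place=True):
    # Toy stand-in: record the settings rather than swapping the forward method.
    module.update(casting_mode=casting_mode, offset=offset, in_place=in_place)
    return module

# Pre-bind the Gemma-specific settings once, reuse everywhere:
_patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

patched = _patch_rms_norm_module_for_gemma({})
assert patched == {"casting_mode": "gemma", "offset": 1.0, "in_place": True}
```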
## Gemma2 Additional Differences
77+
- `in_place: false` (residual between sequential norms)
78+
- Extra norm layers: `pre_feedforward_layernorm`, `post_feedforward_layernorm`
79+
- Has `final_logit_softcapping` in config
80+
- Uses `LigerRMSNormForGemma2` variant
