Changes from all commits (21 commits)
- `1b8aeb8` MoE Kernel (#2465) (jeromeku, May 3, 2025)
- `223eb67` Llama4 MoE Grouped GEMM (#2639) (jeromeku, May 28, 2025)
- `4804886` Create LICENSE (danielhanchen, May 28, 2025)
- `8931a39` Fix Typos in Documentation and Comments (#2721) (leopardracer, Jun 17, 2025)
- `04250a4` Docs: Fix typo and improve MoE docstrings (#2784) (kilavvy, Jun 23, 2025)
- `74ad30e` MoE kernels AGPLv3 (danielhanchen, Jul 7, 2025)
- `e33f89e` Formatting & bug fixes (#3563) (danielhanchen, Nov 7, 2025)
- `896e43a` Vllm guided decoding params (#3662) (Datta0, Dec 1, 2025)
- `8f5bb27` Revert "[FIX] Vllm guided decoding params (#3662)" (danielhanchen, Dec 1, 2025)
- `702a3c4` auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2025)
- `dd2f397` Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" (danielhanchen, Dec 1, 2025)
- `5847b78` auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2025)
- `5ea7023` Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" (danielhanchen, Dec 1, 2025)
- `5218ad1` auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2025)
- `24f72b4` Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" (danielhanchen, Dec 1, 2025)
- `89cfdd2` Improve moe kernels for unsloth fine tuning (#3812) (Datta0, Feb 5, 2026)
- `b699bcc` fix for tma (#4023) (leizhenyuan, Feb 11, 2026)
- `5d1960e` fix(ROCm): prevent false TMA support detection on AMD GPUs (#4126) (GoldenGrapeGentleman, Mar 1, 2026)
- `62aa4db` Restructure imports for moe stuff and fixup qwen3moe extractor (Datta0, Mar 3, 2026)
- `c5b008c` Update headers (Datta0, Mar 3, 2026)
- `3efe043` Merge remote-tracking branch 'origin/main' into moe_kernels_refactor (Datta0, Mar 3, 2026)
1 change: 1 addition & 0 deletions unsloth_zoo/kernels/__init__.py
@@ -0,0 +1 @@
# Unsloth Zoo kernel package namespace.
661 changes: 661 additions & 0 deletions unsloth_zoo/kernels/moe/LICENSE

Large diffs are not rendered by default.

86 changes: 86 additions & 0 deletions unsloth_zoo/kernels/moe/README.md
@@ -0,0 +1,86 @@
## MoE Grouped GEMM

Optimized implementation of `MoE MLP Block`.
Licensed under AGPLv3.

### Background

`MoE MLP` requires the following steps:
- Calculate `topk_weights` and `topk_indices`
- If using a grouped gemm implementation, calculate permutation indices needed to rearrange tokens grouped by expert
- For each expert:
- `expert_tokens`: gather the tokens assigned to the expert
- `first_gemm`: `gate / up proj` @ `expert_tokens`
- `silu_and_mul`: `silu` and `mul` of `first_gemm`
- `second_gemm`: `silu_and_mul` @ `down proj`
- `scatter_second_gemm`: scatter the `second_gemm` to the original token order
- `topk_weight_mul`: `second_gemm` @ `topk_weights`
- `final_output`: if `topk > 1`, `topk_weight_mul.view(num_tokens, topk, -1).sum(dim=1)` else `topk_weight_mul`
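The per-expert loop above can be sketched as a naive PyTorch reference (tensor names and shapes are illustrative, not this repo's API):

```python
import torch
import torch.nn.functional as F

def moe_mlp_reference(x, router_logits, gate_up, down, topk):
    # x: (num_tokens, hidden)
    # gate_up: (num_experts, hidden, 2 * intermediate)  -- fused gate / up proj
    # down: (num_experts, intermediate, hidden)
    topk_weights, topk_indices = torch.topk(
        router_logits.softmax(dim=-1), topk, dim=-1
    )  # both (num_tokens, topk)
    num_tokens, hidden = x.shape
    # One output row per (token, topk slot); row-major slot order
    x_rep = x.repeat_interleave(topk, dim=0)
    flat_idx = topk_indices.reshape(-1)  # expert id per routed slot
    out = torch.zeros_like(x_rep)
    for e in range(gate_up.shape[0]):
        mask = flat_idx == e
        if not mask.any():
            continue
        expert_tokens = x_rep[mask]                 # gather tokens for expert e
        first_gemm = expert_tokens @ gate_up[e]     # first_gemm
        g, u = first_gemm.chunk(2, dim=-1)
        h = F.silu(g) * u                           # silu_and_mul
        out[mask] = h @ down[e]                     # second_gemm + scatter
    out = out * topk_weights.reshape(-1, 1)         # topk_weight_mul
    return out.view(num_tokens, topk, hidden).sum(dim=1)  # final_output
```

This is the loop-over-experts baseline that the grouped GEMM kernels below replace.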

One way to eliminate the loop is a grouped GEMM, in which all expert GEMMs are computed within a single kernel that iterates over their tiles as individual GEMMs. For each expert GEMM, the `A` matrix is `M' x K` and the `B` matrix is `K x N`, where `M'` is the number of tokens assigned to that expert and `B` is that expert's weight matrix.

This requires an additional permute (and subsequent copy) of the hidden states such that the tokens assigned to each expert are contiguous in memory before running the first grouped GEMM within the Expert MLP.
Additionally, after the second grouped GEMM, the hidden states must be permuted back to the original token order and multiplied by `topk_weights` to get the final output.
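The permutation described above can be derived from a stable argsort over the flattened expert assignments (a minimal sketch; function and variable names are illustrative):

```python
import torch

def build_permutation(topk_indices):
    # topk_indices: (num_tokens, topk) expert id per routed token slot
    flat = topk_indices.reshape(-1)
    # Stable sort so slots routed to the same expert keep their original order
    gather_idx = torch.argsort(flat, stable=True)   # token order -> expert order
    scatter_idx = torch.argsort(gather_idx)         # expert order -> token order
    tokens_per_expert = torch.bincount(flat, minlength=int(flat.max()) + 1)
    return gather_idx, scatter_idx, tokens_per_expert
```

Indexing the (topk-replicated) hidden states with `gather_idx` makes each expert's tokens contiguous before the first grouped GEMM; indexing the second GEMM's output with `scatter_idx` restores the original token order.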

### Optimizations
This repo implements a grouped GEMM-based MoE MLP with the following optimizations:
- Eliminates the loop over experts by computing all expert GEMMs as one grouped GEMM within a single fused triton kernel
- Fuses the permutation of hidden states from token order (original input order) to expert order (tokens grouped by expert) into the prologue of the first grouped GEMM
- Fuses the (un)permutation of hidden states from expert order back to token order into the second grouped GEMM
- Fuses the multiplication of hidden states by the expert weights into the epilogue of the second GEMM (only implemented for inference, not for training)

### Structure
- `grouped_gemm/interface.py`: wrappers for the individual forward / backward kernels as well as the `torch.autograd.Function`
- `grouped_gemm/kernels/forward.py`: forward kernel
- `grouped_gemm/kernels/backward.py`: backward dX and dW kernels
- `grouped_gemm/kernels/tuning.py`: manual tuning utils
- `grouped_gemm/kernels/autotuning.py`: autotuning utils
- `grouped_gemm/reference/moe_block.py`: contains `Qwen3MoeFusedGroupedGEMMBlock`, a reference implementation of the Hugging Face `Qwen3SparseMOEBlock` with the fused triton kernel in place of the original HF expert computation
- `grouped_gemm/reference/moe_ops.py`: supporting ops (routing, token sorting, etc.) and reference MoE block using a torch-native grouped gemm approach.

### Tests
- `grouped_gemm/tests/test_grouped_gemm.py`: unit tests for the forward and backward grouped gemm kernels as well as the wrapped grouped gemm `autograd.Function`. Best not to run this entire test suite at once due to the large number of parametrized unit tests; instead, use filters to run specific sets of tests. E.g., to run forward tests with autotune turned on: `pytest -sv -k "forward and autotune" --tb=short tests/test_grouped_gemm.py`. Use the test function names and parameter ids as words to filter on.
- `grouped_gemm/tests/test_qwen3_moe.py`: end-to-end test for the Qwen3 MoE block. IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as the notes in the test itself for complications when running parametrized pytest suites with triton / autotune. TLDR: use the test script, NOT pytest, to run the tests.

### Benchmarks
- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3SparseMOEBlock` or `Llama4TextMoe` against the fused implementation


Run with the following flags on an `H100` to bench the forward pass (pass `--help` to see all available flags):

For `Qwen3-30B-A3B`:
```
python benchmark/benchmark_fused_moe.py --model qwen3 --mode forward --seqlen 1024 --permute_x --permute_y --autotune
```

For the backward bench:
```
python benchmark/benchmark_fused_moe.py --model qwen3 --mode backward --seqlen 1024 --permute_x --permute_y --autotune
```

For `Llama-4-Scout-17B-16E`:
```
python benchmark/benchmark_fused_moe.py --model llama4 --autotune --mode=forward --permute_y
```
Ditto for the backward pass.

### Notes
- Tested and benched on `H100`; it should also run on Ampere and possibly earlier GPU generations, though the autotuning configs will need to be adjusted.
- The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`.
- The kernels can be run either autotuned (see `autotuning.py`) or with a manually specified config (see `tuning.py`). Running with the autotuner is recommended, since the MoE block requires 2 configs for the forward pass (2 grouped gemms) and 4 for the backward pass (dX and dW per grouped gemm, 2 grouped gemms).
- Running with autotuning turned off and the default manual kernel config will result in **highly** sub-optimal performance, as it is only meant for testing / debugging purposes.
- I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads.
- The Llama4 reference layer is still highly under-optimized as there are many low-hanging opportunities for further speedups around routing and shared expert calculation.

TODO:
- TMA store: implemented but not currently enabled due to non-determinism arising from a triton pipelining bug.
- Warp specialization: Hopper support for WS is not yet enabled on the triton 3.3.x branch that ships with the latest pytorch 2.7.
- Additional optimizations:
- Fused / optimized implementations of routing, token sorting, etc.
- Better software pipelining within grouped gemm
- Threadblock swizzling for better L2 caching
- Llama4
- Fused gather / topk weight merging
- Custom topk, gather indices kernel
- Shared expert fusion with experts calculation
15 changes: 15 additions & 0 deletions unsloth_zoo/kernels/moe/__init__.py
@@ -0,0 +1,15 @@
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.