Commit 3f85b72
Add Kimi AttentionResiduals (AttnRes) kernel (#1161)
## Summary
Implements **Attention Residuals (AttnRes)** from Kimi/Moonshot AI
([arxiv.org/abs/2603.15031](https://arxiv.org/abs/2603.15031)).
This PR addresses issue #1158.
### What is AttnRes?
AttnRes replaces the standard residual connection with softmax attention
over depth blocks, addressing the **PreNorm dilution** problem, in which
the contributions of deeper layers are progressively diluted by the
growing residual stream:
```python
V = stack(blocks) # [N, B, T, D]
K = RMSNorm(V) # per-block normalize
scores = einsum(w, K) # [N, B, T] — w is [D] learned query
alpha = softmax(scores, 0) # over block dim
h = einsum(alpha, V) # [B, T, D] — weighted sum
```
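The pseudocode above maps directly onto a plain-PyTorch reference (a minimal sketch of the operation, not the fused kernel; the RMSNorm epsilon and function name are illustrative assumptions):

```python
import torch

def attn_res_reference(v: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Reference AttnRes: v is [N, B, T, D] stacked block outputs, w is a [D] learned query."""
    # Per-block RMSNorm over the hidden dim (eps is an assumed detail)
    k = v * torch.rsqrt(v.pow(2).mean(dim=-1, keepdim=True) + eps)
    scores = torch.einsum("nbtd,d->nbt", k, w)      # [N, B, T]
    alpha = torch.softmax(scores, dim=0)            # softmax over the block dim
    return torch.einsum("nbt,nbtd->btd", alpha, v)  # [B, T, D] weighted sum
```

Note that with a single block (N=1) the softmax collapses to 1 and the output is just that block, which makes for an easy sanity check.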
### Implementation
- Single fused Triton kernel: RMSNorm + dot product + softmax + weighted sum
- Efficient memory usage: Scores stored in registers (supports N≤32 blocks)
- Complete autograd support: Forward and backward kernels with @ensure_contiguous decorator
- Benchmark script: Compares against PyTorch and torch.compile
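The actual implementation dispatches the fused Triton kernels; the sketch below shows only the `torch.autograd.Function` wiring, with plain torch ops standing in for the forward and backward kernels (class name and recompute-based backward are illustrative, not the PR's fused backward):

```python
import torch

_EPS = 1e-6  # assumed RMSNorm epsilon, for illustration

class AttnResFunctionSketch(torch.autograd.Function):
    """Autograd wiring sketch; plain torch ops stand in for the fused Triton kernels."""

    @staticmethod
    def forward(ctx, v, w):
        # v: [N, B, T, D] stacked blocks, w: [D] learned query
        k = v * torch.rsqrt(v.pow(2).mean(dim=-1, keepdim=True) + _EPS)
        alpha = torch.softmax(torch.einsum("nbtd,d->nbt", k, w), dim=0)
        ctx.save_for_backward(v, w)
        return torch.einsum("nbt,nbtd->btd", alpha, v)

    @staticmethod
    def backward(ctx, grad_out):
        # The PR implements this as a fused backward kernel; here we simply
        # recompute the forward graph and let autograd derive the gradients.
        v, w = ctx.saved_tensors
        with torch.enable_grad():
            v_ = v.detach().requires_grad_()
            w_ = w.detach().requires_grad_()
            k = v_ * torch.rsqrt(v_.pow(2).mean(dim=-1, keepdim=True) + _EPS)
            alpha = torch.softmax(torch.einsum("nbtd,d->nbt", k, w_), dim=0)
            out = torch.einsum("nbt,nbtd->btd", alpha, v_)
            gv, gw = torch.autograd.grad(out, (v_, w_), grad_out)
        return gv, gw
```

A structure like this is what lets the op pass `torch.autograd.gradcheck` in double precision.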
### Files Added
- `src/liger_kernel/ops/attn_res.py` - core kernel implementation (318 lines)
- `benchmark/scripts/benchmark_attn_res.py` - benchmark and correctness tests (246 lines)
- `src/liger_kernel/ops/__init__.py` - updated to export `LigerAttnResFunction`
### Test Plan
- Run correctness tests: `python benchmark/scripts/benchmark_attn_res.py --quick`
- Run full benchmark: `python benchmark/scripts/benchmark_attn_res.py`
- Test on RTX 5090 (Blackwell architecture)
- Verify forward pass correctness (fp16, bf16, fp32)
- Verify backward pass correctness
### Benchmark Results
Tested on: **NVIDIA GeForce RTX 5090**, CUDA 12.8
> Note: Due to resource constraints, testing was only performed on RTX 5090. Additional testing on datacenter GPUs (A100, H100) would be valuable to validate performance across different architectures. Maintainers are welcome to run benchmarks on other hardware configurations.
To reproduce: `python benchmark/scripts/benchmark_attn_res.py`
#### Forward Pass Performance
| Config | PyTorch | torch.compile | Liger AttnRes | Speedup vs PyTorch | Speedup vs compile |
|--------|---------|---------------|---------------|--------------------|--------------------|
| N=4, D=4096, fp16 | 5.164 ms | 0.691 ms | **0.206 ms** | **25.11x** ✨ | 3.35x |
| N=8, D=4096, fp16 | 10.011 ms | 2.716 ms | **0.394 ms** | **25.39x** ✨ | 6.89x |
| N=8, D=8192, fp16 | 20.076 ms | 3.183 ms | **0.780 ms** | **25.74x** ✨ | 4.08x |
| N=16, D=4096, fp16 | 19.946 ms | 2.996 ms | **1.004 ms** | **19.86x** ✨ | 2.98x |
| N=8, D=4096, bf16 | 10.009 ms | 1.596 ms | **0.393 ms** | **25.47x** ✨ | 4.06x |
#### Forward + Backward Performance
| Config | PyTorch | Liger AttnRes | Speedup |
|--------|---------|---------------|---------|
| N=4, D=4096, fp16 | 20.525 ms | **1.003 ms** | **20.46x** ✨ |
| N=8, D=4096, fp16 | 39.880 ms | **1.747 ms** | **22.83x** ✨ |
| N=8, D=8192, fp16 | 79.723 ms | **3.317 ms** | **24.03x** ✨ |
| N=16, D=4096, fp16 | 80.480 ms | **3.555 ms** | **22.64x** ✨ |
| N=8, D=4096, bf16 | 39.888 ms | **1.744 ms** | **22.88x** ✨ |
#### Visualizations
**Speed - Full (Forward + Backward):**
<img width="1000" alt="attn_res_speed_full" src="https://github.com/user-attachments/assets/a1bb7710-1d37-4824-8d30-79fd9e14d774" />
**Speed - Forward:**
<img width="1000" alt="attn_res_speed_forward" src="https://github.com/user-attachments/assets/a5649c68-fbf4-4d8e-9cb6-8a119ca6f8a3" />
**Speed - Backward:**
<img width="1000" alt="attn_res_speed_backward" src="https://github.com/user-attachments/assets/78161703-fea0-405f-a2ec-ae5be44f72ae" />
**Memory - Full:**
<img width="1000" alt="attn_res_memory_full" src="https://github.com/user-attachments/assets/d902bf36-ff65-4657-a471-8e92524261e0" />
#### Correctness Tests
All tests passed with expected numerical precision:
- Forward pass: max diff < 2e-3 (fp16/bf16), < 2e-6 (fp32)
- Backward pass: max diff < 4e-3 (fp16/bf16), < 2e-6 (fp32)
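Checks of this shape can be reproduced by comparing a low-precision result against an fp32 reference (a sketch; the actual tests live in the benchmark script, and the tensors here are stand-ins):

```python
import torch

def max_abs_diff(ref: torch.Tensor, out: torch.Tensor) -> float:
    """Max elementwise absolute difference, computed in fp32."""
    return (ref.float() - out.float()).abs().max().item()

# Stand-in example: an fp32 "reference" vs the same values round-tripped
# through fp16, checked against the fp16 forward tolerance from above.
torch.manual_seed(0)
ref = torch.rand(2, 8, 16)       # pretend fp32 reference output
out = ref.half().float()         # pretend fp16 kernel result, cast back
assert max_abs_diff(ref, out) < 2e-3
```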
### Key Insights
1. Exceptional speedup: 20-25x faster than PyTorch's einsum-based implementation
2. Beats torch.compile: 3-7x faster than torch.compile in forward pass
3. Scales well: Performance maintained across different N (number of blocks) and D (hidden dimension)
4. Memory efficient: Single-pass fused kernel minimizes memory traffic
The dramatic speedup is achieved by:
- Fusing RMSNorm + attention + weighted sum into a single kernel
- Storing attention scores in registers (no global memory roundtrip)
- Optimized memory access patterns for coalesced reads/writes
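A rough back-of-envelope estimate shows why fusion dominates here (illustrative shapes and a simplified traffic model, ignoring caches; the batch size is an assumption, not a benchmark config):

```python
# Rough global-memory traffic for N=8, B=8, T=4096, D=4096 in fp16 (2 bytes).
N, B, T, D, bytes_el = 8, 8, 4096, 4096, 2

v_bytes = N * B * T * D * bytes_el    # the stacked block tensor V
out_bytes = B * T * D * bytes_el      # the [B, T, D] output

# Unfused: each stage streams the big tensors through global memory.
unfused = 2 * v_bytes        # RMSNorm: read V, write K
unfused += v_bytes           # dot product: read K (scores are tiny, ignored)
unfused += v_bytes + out_bytes  # weighted sum: re-read V, write output

# Fused single-pass kernel: read V once, write the output once;
# scores/softmax stay in registers.
fused = v_bytes + out_bytes

print(f"unfused ~{unfused / 2**30:.1f} GiB, fused ~{fused / 2**30:.1f} GiB, "
      f"~{unfused / fused:.1f}x less traffic")
```

Memory traffic alone predicts a healthy multiple; the rest of the observed speedup comes from eliminating kernel-launch overhead and intermediate allocations.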
---
Closes #1158
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
---
File tree: 5 files changed, +726 −0 (`benchmark/scripts`, `src/liger_kernel/ops`, `src/liger_kernel/transformers`, `test/transformers`)