
Commit 07dd9be

Hailey-Zh and lowdy1 authored
[NPU]support fused neighborhood attention for npu (#1034)
## Summary

This PR adds support for Fused Neighborhood Attention (FNA) optimized for NPU architectures. The implementation focuses on memory efficiency and hardware affinity to avoid performance bottlenecks. Key changes:

- **Grid dimension refactoring:** the attention grid is flattened to a 1D structure. This improves thread-block mapping and prevents Unified Buffer (UB) overflow, keeping the workload within the NPU's local-memory limits.
- **NPU-affinity softmax:** the softmax tiling and grid dimensions are aligned with the NPU's compute-unit sizes, maximizing throughput and reducing synchronization overhead.

## Details

Hardware type: NPU (Atlas A2)

## Testing Done

Tests passed with:

- `python benchmark/scripts/benchmark_fused_neighborhood_attention.py`
- `pytest -v test/transformers/test_fused_neighborhood_attention.py`

Checklist:

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Co-authored-by: lowdy1 <xiahouweidong@gmail.com>
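For readers unfamiliar with the operation being fused: neighborhood attention restricts each query to a local window of keys and values around its own position, instead of attending to the full sequence. The sketch below is a generic NumPy reference (not the PR's kernel, and not the Liger Kernel API); `kernel_size` and the function name are illustrative assumptions:

```python
import numpy as np

def neighborhood_attention_1d(q, k, v, kernel_size=3):
    """Reference 1D neighborhood attention: each query attends only to
    keys/values inside a local window centered on its own position.
    Illustrative sketch only; the fused NPU kernel computes this without
    materializing per-query score vectors in global memory."""
    seq_len, head_dim = q.shape
    half = kernel_size // 2
    scale = 1.0 / np.sqrt(head_dim)
    out = np.empty_like(q)
    for i in range(seq_len):
        # Clamp the window at the sequence boundaries.
        lo = max(0, i - half)
        hi = min(seq_len, i + half + 1)
        scores = (q[i] @ k[lo:hi].T) * scale
        # Numerically stable softmax over the local window.
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        out[i] = weights @ v[lo:hi]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 4))
out = neighborhood_attention_1d(q, q, q, kernel_size=3)
print(out.shape)  # (8, 4)
```

When `kernel_size` covers the whole sequence, this degenerates to ordinary full self-attention, which is a convenient correctness check.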
1 parent d8d6630 commit 07dd9be

File tree

2 files changed: +897 −0 lines


src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -29,6 +29,8 @@
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward
 from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward
+from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
+from liger_kernel.ops.backends._ascend.ops.fused_neighborhood_attention import fused_neighborhood_attention_forward
 from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
 from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
@@ -136,4 +138,6 @@
     "LigerSparsemaxFunction",
     "sparsemax_forward",
     "sparsemax_backward",
+    "LigerFusedNeighborhoodAttentionFunction",
+    "fused_neighborhood_attention_forward",
 ]
```
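The "NPU-affinity softmax" mentioned in the summary is a tiling concern: when scores are processed in fixed-size tiles sized to the compute unit, the softmax must be computed incrementally with a running max and running sum. The sketch below shows that online-softmax recurrence in plain NumPy; it illustrates the technique only and is an assumption about the approach, not code from this PR:

```python
import numpy as np

def tiled_softmax(x, tile=4):
    """Numerically stable softmax computed over fixed-size tiles.
    Keeps a running max `m` and running sum `s` of exp(x - m), rescaling
    `s` whenever a tile raises the max (the 'online softmax' recurrence)."""
    m = -np.inf  # running max seen so far
    s = 0.0      # running sum of exp(x - m)
    for start in range(0, len(x), tile):
        chunk = x[start:start + tile]
        new_m = max(m, chunk.max())
        # Rescale the old sum to the new max, then add the tile's contribution.
        s = s * np.exp(m - new_m) + np.exp(chunk - new_m).sum()
        m = new_m
    return np.exp(x - m) / s

x = np.array([1.0, 3.0, -2.0, 0.5, 4.0])
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(tiled_softmax(x), reference))  # True
```

Because each tile only needs the scalars `m` and `s` from the previous tiles, the tile size can be chosen to fit the local buffer without changing the result.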
