Skip to content

Commit 35cf6a0

Browse files
authored
[NPU]: optimize cross entropy kernel gradient computation (#1178)
- Hoist loop-invariant scalar computations (z_scale, one_minus_ls, z_deriv) out of the inner loop to avoid redundant recalculation
- Fuse softmax, z-loss derivative, and smoothing term into a single vector expression in the non-weighted branch
- Guard tl.where with a block-range check (y >= i and y < i + BLOCK_SIZE) to skip unnecessary vector operations when the target index is not in the current block

Hardware Type: Atlas 800I A2

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent d991472 commit 35cf6a0

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ def liger_cross_entropy_kernel(
194194
# dx_y = softmax(x_y) - 1
195195
# dx_i = softmax(x_i), for i ≠ y
196196
if HAS_GRADIENTS:
197+
# Hoist loop-invariant scalar computations
198+
z_scale = 1.0 + 2.0 * lse_square_scale * lse # (1 + 2 * lse_square_scale * lse)
199+
one_minus_ls = 1.0 - label_smoothing
200+
z_deriv = 2.0 * lse_square_scale * lse # for weighted branch
201+
197202
for i in range(0, n_cols, BLOCK_SIZE):
198203
X_offsets = i + tl.arange(0, BLOCK_SIZE)
199204
X_block = tl.load(
@@ -207,29 +212,28 @@ def liger_cross_entropy_kernel(
207212
X_block = softcap * intermediate
208213

209214
if not HAS_WEIGHT:
210-
# softmax(x_i)
211-
X_block = tl.exp(X_block - m) / d
212-
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
213-
X_block += 2 * lse_square_scale * lse * X_block
214-
# smoothing term
215-
X_block += -eps
215+
# softmax(x_i) * (1 + 2 * lse_square_scale * lse) - eps
216+
# Fuses: softmax, z-loss derivative, and smoothing term into fewer vector ops
217+
X_block = tl.exp(X_block - m) / d * z_scale - eps
216218
# special handle dx_y
217-
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
219+
if y >= i and y < i + BLOCK_SIZE:
220+
X_block = tl.where(X_offsets != y, X_block, X_block - one_minus_ls)
218221
# reduction scale
219222
if reduction == "mean":
220223
X_block = X_block / n_non_ignore
221224
else:
222225
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
223226
softmax_X = tl.exp(X_block - m) / d
224227
# derivative of original_loss
225-
dloss_ori = (1 - label_smoothing) * softmax_X
228+
dloss_ori = one_minus_ls * softmax_X
226229
# specially handle dx_y
227-
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
230+
if y >= i and y < i + BLOCK_SIZE:
231+
dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - one_minus_ls)
228232
dloss_ori = dloss_ori * weight_y
229233
# derivative of smooth_loss
230234
dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
231235
# derivative of z-loss
232-
dz_loss = 2 * lse_square_scale * lse * softmax_X
236+
dz_loss = z_deriv * softmax_X
233237
# reduction scale
234238
if reduction == "mean":
235239
dloss_ori = dloss_ori / sum_non_ignore_weight

0 commit comments

Comments (0)