Skip to content

Commit cf633a3

Browse files
committed
XNNPACK graph runtime: layer norm in-tree operator
Adds the first in-tree (non-XNNPACK) operator: a fused float32 layer norm with scalar and aarch64 NEON kernels selected at runtime via cpuinfo, wired into create_operator so LayerNorm nodes dispatch to it. Also makes the Operator interface fallible: setup/prepare/reshape/execute now return runtime::Error, propagated through execution planning and the executor, so an unsupported dtype or malformed node fails cleanly instead of asserting or reading out of bounds. The operator is runtime-only for now; AOT serialization and partitioner support land separately. Covered by native graph e2e tests, including the error path. Authored with Claude. ghstack-source-id: c89130b ghstack-comment-id: 4713392085 Pull-Request: #20293
1 parent 75b1935 commit cf633a3

15 files changed

Lines changed: 615 additions & 14 deletions

File tree

backends/xnnpack/runtime/executor/executor.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,12 @@ runtime::Error Executor::run_step(size_t step_idx, const plan::PlanStep& step) {
6464
}
6565

6666
auto t0 = std::chrono::steady_clock::now();
67-
s.op->execute(
67+
err = s.op->execute(
6868
{inputs.data(), inputs.size()},
6969
{outputs.data(), outputs.size()});
70+
if (err != runtime::Error::Ok) {
71+
return;
72+
}
7073
auto t1 = std::chrono::steady_clock::now();
7174
auto us =
7275
std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0)
@@ -231,7 +234,8 @@ runtime::Error Executor::update_planned_memory(Span<core::Tensor> inputs) {
231234
for (auto slot : op_step->input_slots) {
232235
input_specs.push_back(memory_plan.value_specs[slot]);
233236
}
234-
op_step->op->reshape({input_specs.data(), input_specs.size()});
237+
ET_CHECK_OK_OR_RETURN_ERROR(
238+
op_step->op->reshape({input_specs.data(), input_specs.size()}));
235239
}
236240

237241
return runtime::Error::Ok;
@@ -296,8 +300,8 @@ runtime::Result<Executor> Executor::build(graph::Graph& graph) {
296300
for (auto slot : op_step->output_slots)
297301
outputs.push_back(&values[slot]);
298302

299-
op_step->op->prepare(
300-
{inputs.data(), inputs.size()}, {outputs.data(), outputs.size()});
303+
ET_CHECK_OK_OR_RETURN_ERROR(op_step->op->prepare(
304+
{inputs.data(), inputs.size()}, {outputs.data(), outputs.size()}));
301305
}
302306

303307
auto t4 = std::chrono::steady_clock::now();
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include <executorch/backends/xnnpack/runtime/kernels/layer_norm/layer_norm.h>
2+
#include <executorch/backends/xnnpack/runtime/kernels/layer_norm/layer_norm_scalar.h>
3+
#ifdef __aarch64__
4+
#include <executorch/backends/xnnpack/runtime/kernels/layer_norm/layer_norm_neon.h>
5+
#endif
6+
7+
#include <cpuinfo.h>
8+
9+
namespace executorch::backends::xnnpack::kernels {
10+
11+
// Chooses the float32 layer norm implementation for the running CPU.
//
// The portable scalar kernel is the default. On aarch64 builds it is replaced
// by the NEON kernel when cpuinfo initializes successfully and reports NEON
// support. On every other target the NEON branch is compiled out entirely.
LayerNormF32Fn select_layer_norm_f32_kernel() {
  LayerNormF32Fn kernel = layer_norm_f32_scalar;
#ifdef __aarch64__
  // && short-circuits, so the feature query only runs once initialization
  // has reported success.
  const bool neon_usable = cpuinfo_initialize() && cpuinfo_has_arm_neon();
  if (neon_usable) {
    kernel = layer_norm_f32_neon;
  }
#endif
  return kernel;
}
19+
20+
} // namespace executorch::backends::xnnpack::kernels
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
5+
namespace executorch::backends::xnnpack::kernels {
6+
7+
// Function-pointer type shared by the float32 layer norm kernels.
//
//   input      - source values, addressed as outer_size rows of inner_size
//                contiguous floats (row i starts at input + i * inner_size).
//   output     - destination buffer, addressed the same way as input.
//   weight     - optional per-column scale (inner_size floats), or nullptr.
//   bias       - optional per-column shift (inner_size floats), or nullptr.
//   outer_size - number of rows to normalize.
//   inner_size - number of elements in each row; every row is normalized
//                independently over these elements.
//   eps        - added to the row variance before the square root is taken.
using LayerNormF32Fn = void (*)(
    const float* input,
    float* output,
    const float* weight,
    const float* bias,
    size_t outer_size,
    size_t inner_size,
    float eps);
15+
16+
// Returns the float32 layer norm kernel to use on this machine: the NEON
// kernel on aarch64 builds when cpuinfo initializes and reports NEON support,
// otherwise the portable scalar kernel. The result is never null.
LayerNormF32Fn select_layer_norm_f32_kernel();
17+
18+
} // namespace executorch::backends::xnnpack::kernels
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
#ifdef __aarch64__
2+
3+
#include <executorch/backends/xnnpack/runtime/kernels/layer_norm/layer_norm_neon.h>
4+
5+
#include <arm_neon.h>
6+
#include <cassert>
7+
#include <cmath>
8+
9+
namespace executorch::backends::xnnpack::kernels {
10+
11+
namespace {
12+
// Returns the sum of the `len` floats starting at `data`.
//
// The work is done in three phases:
//   1. 32 floats per iteration, spread over eight independent vector
//      accumulators so consecutive additions do not depend on each other.
//   2. 4 floats per iteration into the combined accumulator.
//   3. One float at a time for the final 0-3 elements.
// A `len` of zero yields 0.
float sum_f32_neon(const float* data, size_t len) {
  float32x4_t part0 = vdupq_n_f32(0);
  float32x4_t part1 = vdupq_n_f32(0);
  float32x4_t part2 = vdupq_n_f32(0);
  float32x4_t part3 = vdupq_n_f32(0);
  float32x4_t part4 = vdupq_n_f32(0);
  float32x4_t part5 = vdupq_n_f32(0);
  float32x4_t part6 = vdupq_n_f32(0);
  float32x4_t part7 = vdupq_n_f32(0);

  const float* cursor = data;
  size_t remaining = len;

  // Phase 1: eight 4-lane vectors (32 floats) per iteration.
  while (remaining >= 32) {
    const float32x4x2_t chunk_a = vld1q_f32_x2(cursor);
    const float32x4x2_t chunk_b = vld1q_f32_x2(cursor + 8);
    const float32x4x2_t chunk_c = vld1q_f32_x2(cursor + 16);
    const float32x4x2_t chunk_d = vld1q_f32_x2(cursor + 24);

    part0 = vaddq_f32(part0, chunk_a.val[0]);
    part1 = vaddq_f32(part1, chunk_a.val[1]);
    part2 = vaddq_f32(part2, chunk_b.val[0]);
    part3 = vaddq_f32(part3, chunk_b.val[1]);
    part4 = vaddq_f32(part4, chunk_c.val[0]);
    part5 = vaddq_f32(part5, chunk_c.val[1]);
    part6 = vaddq_f32(part6, chunk_d.val[0]);
    part7 = vaddq_f32(part7, chunk_d.val[1]);

    cursor += 32;
    remaining -= 32;
  }

  // Collapse the eight partial sums with a pairwise tree:
  // ((0+1) + (2+3)) + ((4+5) + (6+7)).
  const float32x4_t pair01 = vaddq_f32(part0, part1);
  const float32x4_t pair23 = vaddq_f32(part2, part3);
  const float32x4_t pair45 = vaddq_f32(part4, part5);
  const float32x4_t pair67 = vaddq_f32(part6, part7);
  const float32x4_t quad_lo = vaddq_f32(pair01, pair23);
  const float32x4_t quad_hi = vaddq_f32(pair45, pair67);
  float32x4_t lanes = vaddq_f32(quad_lo, quad_hi);

  // Phase 2: one vector at a time.
  while (remaining >= 4) {
    const float32x4_t chunk = vld1q_f32(cursor);
    lanes = vaddq_f32(lanes, chunk);
    cursor += 4;
    remaining -= 4;
  }

  // Horizontal add of the four lanes, then the scalar tail.
  float total = vaddvq_f32(lanes);
  while (remaining > 0) {
    total += *cursor;
    ++cursor;
    --remaining;
  }

  return total;
}
66+
67+
// Returns the sum of squared deviations from `mean` over the `len` floats
// starting at `data`, i.e. sum((data[k] - mean)^2). The caller divides by the
// element count to obtain a variance.
//
// Processes 16 floats per iteration across four independent accumulators,
// then 4 floats per iteration, then the final 0-3 elements one at a time.
float var_sum_f32_neon(const float* data, float mean, size_t len) {
  const float32x4_t mean_lanes = vdupq_n_f32(mean);

  float32x4_t part0 = vdupq_n_f32(0);
  float32x4_t part1 = vdupq_n_f32(0);
  float32x4_t part2 = vdupq_n_f32(0);
  float32x4_t part3 = vdupq_n_f32(0);

  const float* cursor = data;
  size_t remaining = len;

  // Four 4-lane vectors (16 floats) per iteration.
  while (remaining >= 16) {
    const float32x4x2_t chunk_a = vld1q_f32_x2(cursor);
    const float32x4x2_t chunk_b = vld1q_f32_x2(cursor + 8);

    const float32x4_t dev0 = vsubq_f32(chunk_a.val[0], mean_lanes);
    const float32x4_t dev1 = vsubq_f32(chunk_a.val[1], mean_lanes);
    const float32x4_t dev2 = vsubq_f32(chunk_b.val[0], mean_lanes);
    const float32x4_t dev3 = vsubq_f32(chunk_b.val[1], mean_lanes);

    const float32x4_t sq0 = vmulq_f32(dev0, dev0);
    const float32x4_t sq1 = vmulq_f32(dev1, dev1);
    const float32x4_t sq2 = vmulq_f32(dev2, dev2);
    const float32x4_t sq3 = vmulq_f32(dev3, dev3);

    part0 = vaddq_f32(part0, sq0);
    part1 = vaddq_f32(part1, sq1);
    part2 = vaddq_f32(part2, sq2);
    part3 = vaddq_f32(part3, sq3);

    cursor += 16;
    remaining -= 16;
  }

  // Pairwise reduction: (0+1) + (2+3).
  const float32x4_t pair01 = vaddq_f32(part0, part1);
  const float32x4_t pair23 = vaddq_f32(part2, part3);
  float32x4_t lanes = vaddq_f32(pair01, pair23);

  // One vector at a time.
  while (remaining >= 4) {
    const float32x4_t chunk = vld1q_f32(cursor);
    const float32x4_t dev = vsubq_f32(chunk, mean_lanes);
    const float32x4_t sq = vmulq_f32(dev, dev);
    lanes = vaddq_f32(lanes, sq);
    cursor += 4;
    remaining -= 4;
  }

  // Horizontal add of the four lanes, then the scalar tail.
  float total = vaddvq_f32(lanes);
  while (remaining > 0) {
    const float value = *cursor;
    const float dev = value - mean;
    const float sq = dev * dev;
    total += sq;
    ++cursor;
    --remaining;
  }

  return total;
}
122+
123+
// Writes the normalized form of one row to `out`:
//   out[k] = (input[k] - mean) * inv_std                        (UseWeightBias == false)
//   out[k] = (input[k] - mean) * inv_std * weight[k] + bias[k]  (UseWeightBias == true)
//
// `weight` and `bias` are only read when UseWeightBias is true, in which case
// both must point to at least `len` floats. The branch is resolved at compile
// time, so the false instantiation contains no affine code.
//
// 16 floats are handled per vector iteration; the remaining 0-15 elements are
// processed one at a time.
template <bool UseWeightBias>
void normalize_f32_neon(
    const float* input,
    float mean,
    float inv_std,
    const float* weight,
    const float* bias,
    float* out,
    size_t len) {
  const float32x4_t mean_lanes = vdupq_n_f32(mean);
  const float32x4_t scale_lanes = vdupq_n_f32(inv_std);

  const float* src = input;
  const float* gamma = weight;
  const float* beta = bias;
  float* dst = out;
  size_t remaining = len;

  while (remaining >= 16) {
    const float32x4x2_t chunk_a = vld1q_f32_x2(src);
    const float32x4x2_t chunk_b = vld1q_f32_x2(src + 8);

    float32x4_t res0 =
        vmulq_f32(vsubq_f32(chunk_a.val[0], mean_lanes), scale_lanes);
    float32x4_t res1 =
        vmulq_f32(vsubq_f32(chunk_a.val[1], mean_lanes), scale_lanes);
    float32x4_t res2 =
        vmulq_f32(vsubq_f32(chunk_b.val[0], mean_lanes), scale_lanes);
    float32x4_t res3 =
        vmulq_f32(vsubq_f32(chunk_b.val[1], mean_lanes), scale_lanes);

    if constexpr (UseWeightBias) {
      const float32x4x2_t gamma_a = vld1q_f32_x2(gamma);
      const float32x4x2_t gamma_b = vld1q_f32_x2(gamma + 8);

      const float32x4x2_t beta_a = vld1q_f32_x2(beta);
      const float32x4x2_t beta_b = vld1q_f32_x2(beta + 8);

      // vmlaq_f32(a, b, c) computes a + b * c, i.e. bias + normalized * weight.
      res0 = vmlaq_f32(beta_a.val[0], res0, gamma_a.val[0]);
      res1 = vmlaq_f32(beta_a.val[1], res1, gamma_a.val[1]);
      res2 = vmlaq_f32(beta_b.val[0], res2, gamma_b.val[0]);
      res3 = vmlaq_f32(beta_b.val[1], res3, gamma_b.val[1]);

      gamma += 16;
      beta += 16;
    }

    vst1q_f32(dst, res0);
    vst1q_f32(dst + 4, res1);
    vst1q_f32(dst + 8, res2);
    vst1q_f32(dst + 12, res3);

    src += 16;
    dst += 16;
    remaining -= 16;
  }

  // Scalar tail for whatever did not fill a 16-float iteration.
  while (remaining > 0) {
    const float value = *src;
    float result = (value - mean) * inv_std;

    if constexpr (UseWeightBias) {
      const auto scale = *gamma;
      const auto shift = *beta;

      result = (result * scale) + shift;

      ++gamma;
      ++beta;
    }

    *dst = result;

    ++src;
    ++dst;
    --remaining;
  }
}
190+
} // anonymous namespace
191+
192+
void layer_norm_f32_neon(
193+
const float* input,
194+
float* output,
195+
const float* weight,
196+
const float* bias,
197+
size_t outer_size,
198+
size_t inner_size,
199+
float eps) {
200+
for (size_t i = 0; i < outer_size; i++) {
201+
const float* in_row = input + i * inner_size;
202+
float* out_row = output + i * inner_size;
203+
204+
float sum = sum_f32_neon(in_row, inner_size);
205+
float mean = sum / static_cast<float>(inner_size);
206+
207+
float var_sum = var_sum_f32_neon(in_row, mean, inner_size);
208+
float inv_std =
209+
1.0f / std::sqrt(var_sum / static_cast<float>(inner_size) + eps);
210+
211+
if (weight != nullptr) {
212+
assert(bias != nullptr);
213+
normalize_f32_neon<true>(
214+
in_row, mean, inv_std, weight, bias, out_row, inner_size);
215+
} else {
216+
assert(bias == nullptr);
217+
normalize_f32_neon<false>(
218+
in_row, mean, inv_std, nullptr, nullptr, out_row, inner_size);
219+
}
220+
}
221+
}
222+
223+
} // namespace executorch::backends::xnnpack::kernels
224+
225+
#endif
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
5+
namespace executorch::backends::xnnpack::kernels {
6+
7+
// NEON implementation of float32 layer normalization.
//
// Normalizes each of the `outer_size` rows of `inner_size` contiguous floats
// in `input` to zero mean and unit variance (with `eps` added to the variance
// before the square root) and writes the result to the matching row of
// `output`.
//
// `weight` and `bias` are optional per-column scale and shift arrays of
// `inner_size` floats; pass nullptr to omit them. Callers should supply both
// or neither.
//
// Only defined when building for aarch64.
void layer_norm_f32_neon(
    const float* input,
    float* output,
    const float* weight,
    const float* bias,
    size_t outer_size,
    size_t inner_size,
    float eps);
15+
16+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include <executorch/backends/xnnpack/runtime/kernels/layer_norm/layer_norm_scalar.h>
2+
3+
#include <cmath>
4+
5+
namespace executorch::backends::xnnpack::kernels {
6+
7+
// Portable float32 layer normalization over the innermost dimension.
//
// `input` and `output` are addressed as `outer_size` rows of `inner_size`
// contiguous floats. For every row:
//   mean    = sum(row) / inner_size
//   inv_std = 1 / sqrt(sum((row - mean)^2) / inner_size + eps)
//   out[j]  = (row[j] - mean) * inv_std, then scaled by weight[j] if `weight`
//             is non-null, then shifted by bias[j] if `bias` is non-null.
//
// `weight` and `bias` are independent: either, both, or neither may be
// nullptr. When present, each must point to `inner_size` floats.
void layer_norm_f32_scalar(
    const float* input,
    float* output,
    const float* weight,
    const float* bias,
    size_t outer_size,
    size_t inner_size,
    float eps) {
  const float count = static_cast<float>(inner_size);

  for (size_t row = 0; row != outer_size; ++row) {
    const size_t offset = row * inner_size;
    const float* const src = input + offset;
    const float* const src_end = src + inner_size;
    float* const dst = output + offset;

    // Pass 1: mean of the row.
    float total = 0.0f;
    for (const float* p = src; p != src_end; ++p) {
      total += *p;
    }
    const float mean = total / count;

    // Pass 2: population variance (divides by inner_size, not inner_size - 1).
    float squared_dev_total = 0.0f;
    for (const float* p = src; p != src_end; ++p) {
      const float dev = *p - mean;
      squared_dev_total += dev * dev;
    }
    const float inv_std = 1.0f / std::sqrt(squared_dev_total / count + eps);

    // Pass 3: normalize, then apply whichever affine parameters were given.
    for (size_t col = 0; col != inner_size; ++col) {
      float value = (src[col] - mean) * inv_std;
      if (weight != nullptr) {
        value *= weight[col];
      }
      if (bias != nullptr) {
        value += bias[col];
      }
      dst[col] = value;
    }
  }
}
45+
46+
} // namespace executorch::backends::xnnpack::kernels
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
5+
namespace executorch::backends::xnnpack::kernels {
6+
7+
// Portable (non-SIMD) implementation of float32 layer normalization.
//
// Normalizes each of the `outer_size` rows of `inner_size` contiguous floats
// in `input` to zero mean and unit variance (with `eps` added to the variance
// before the square root) and writes the result to the matching row of
// `output`.
//
// `weight` (per-column scale) and `bias` (per-column shift) are independently
// optional: pass nullptr for either to skip it. When present, each must point
// to `inner_size` floats.
void layer_norm_f32_scalar(
    const float* input,
    float* output,
    const float* weight,
    const float* bias,
    size_t outer_size,
    size_t inner_size,
    float eps);
15+
16+
}

0 commit comments

Comments
 (0)