Skip to content

Commit bab43f6

Browse files
Update
[ghstack-poisoned]
1 parent 4eced3b commit bab43f6

4 files changed

Lines changed: 383 additions & 0 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(WEBGPU_SRCS
3737
runtime/ops/update_cache/UpdateCache.cpp
3838
runtime/ops/sdpa/Sdpa.cpp
3939
runtime/ops/select_as_symint/SelectAsSymint.cpp
40+
runtime/ops/quantized_linear/QuantizedLinear.cpp
4041
)
4142

4243
add_library(webgpu_backend ${WEBGPU_SRCS})
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
11+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>
13+
14+
#include <webgpu/webgpu.h>
15+
16+
#include <cstdint>
17+
#include <cstring>
18+
#include <stdexcept>
19+
20+
namespace executorch::backends::webgpu {
21+
22+
namespace {
23+
24+
// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes).
25+
struct Q4gswParams {
26+
uint32_t M;
27+
uint32_t N;
28+
uint32_t K;
29+
uint32_t K_packed;
30+
uint32_t group_size;
31+
uint32_t padded_N;
32+
uint32_t has_bias;
33+
uint32_t _pad;
34+
};
35+
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");
36+
37+
// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
38+
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
39+
const int in_id = args.at(0);
40+
const int weight_id = args.at(1);
41+
const int scales_id = args.at(2);
42+
const int group_size_id = args.at(3);
43+
const int bias_id = args.at(4);
44+
const int out_id = args.at(5);
45+
46+
WGPUDevice device = graph.device();
47+
48+
const auto& in = graph.get_tensor(in_id);
49+
const auto& weight = graph.get_tensor(weight_id);
50+
const auto& scales = graph.get_tensor(scales_id);
51+
const auto& out = graph.get_tensor(out_id);
52+
53+
if (in.dims.empty() || weight.dims.size() < 2 || scales.dims.size() < 2) {
54+
throw std::runtime_error("WebGPU linear_q4gsw: malformed input dims");
55+
}
56+
57+
// Shapes from the tensors' own dims (no dtype field at runtime).
58+
const uint32_t K = static_cast<uint32_t>(in.dims.back());
59+
if (K == 0) {
60+
throw std::runtime_error("WebGPU linear_q4gsw: K == 0");
61+
}
62+
uint64_t in_numel = 1;
63+
for (int64_t d : in.dims) {
64+
in_numel *= static_cast<uint64_t>(d);
65+
}
66+
const uint32_t M = static_cast<uint32_t>(in_numel / K);
67+
const uint32_t N = static_cast<uint32_t>(weight.dims[0]);
68+
const uint32_t K_packed = static_cast<uint32_t>(weight.dims[1]);
69+
const uint32_t num_groups = static_cast<uint32_t>(scales.dims[0]);
70+
const uint32_t padded_N = static_cast<uint32_t>(scales.dims[1]);
71+
if (M == 0 || N == 0) {
72+
throw std::runtime_error("WebGPU linear_q4gsw: M or N == 0");
73+
}
74+
// int4 packing is 2 nibbles/byte, so K_packed must be ceil(K/2) (guards OOB).
75+
if (K_packed != (K + 1) / 2) {
76+
throw std::runtime_error("WebGPU linear_q4gsw: K_packed must be ceil(K/2)");
77+
}
78+
79+
// One workgroup per output row (M); validate dispatch before any alloc.
80+
const uint32_t workgroup_count =
81+
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");
82+
83+
// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
84+
const uint64_t scales_numel =
85+
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
86+
const uint64_t weight_numel =
87+
static_cast<uint64_t>(N) * static_cast<uint64_t>(K_packed);
88+
if (in.nbytes != in_numel * sizeof(float) ||
89+
out.nbytes != static_cast<uint64_t>(M) * N * sizeof(float) ||
90+
scales.nbytes != scales_numel * sizeof(float) ||
91+
weight.nbytes != weight_numel) {
92+
throw std::runtime_error(
93+
"WebGPU linear_q4gsw: fp32-only (byte-size mismatch)");
94+
}
95+
96+
int64_t group_size = 0;
97+
if (graph.get_value_type(group_size_id) == WebGPUGraph::ValueType::Int) {
98+
group_size = graph.get_int(group_size_id);
99+
}
100+
if (group_size <= 0) {
101+
throw std::runtime_error("WebGPU linear_q4gsw: group_size <= 0");
102+
}
103+
104+
// Optional bias: real buffer if present, else a dummy for the fixed layout.
105+
uint32_t has_bias = 0;
106+
WGPUBuffer bias_buffer = nullptr;
107+
uint64_t bias_size = 4;
108+
if (graph.get_value_type(bias_id) == WebGPUGraph::ValueType::Tensor) {
109+
const auto& bias = graph.get_tensor(bias_id);
110+
if (bias.buffer == nullptr || bias.nbytes < N * sizeof(float)) {
111+
throw std::runtime_error(
112+
"WebGPU linear_q4gsw: bias present but null/undersized");
113+
}
114+
has_bias = 1;
115+
bias_buffer = bias.buffer;
116+
bias_size = bias.nbytes;
117+
}
118+
if (bias_buffer == nullptr) {
119+
bias_buffer = graph.create_scratch_buffer(4);
120+
bias_size = 4;
121+
}
122+
123+
Q4gswParams params = {};
124+
params.M = M;
125+
params.N = N;
126+
params.K = K;
127+
params.K_packed = K_packed;
128+
params.group_size = static_cast<uint32_t>(group_size);
129+
params.padded_N = padded_N;
130+
params.has_bias = has_bias;
131+
132+
WGPUBufferDescriptor uniform_desc = {};
133+
uniform_desc.size = sizeof(Q4gswParams);
134+
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
135+
uniform_desc.mappedAtCreation = true;
136+
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
137+
void* mapped =
138+
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(Q4gswParams));
139+
std::memcpy(mapped, &params, sizeof(Q4gswParams));
140+
wgpuBufferUnmap(uniform_buffer);
141+
graph.add_uniform_buffer_bytes(sizeof(Q4gswParams));
142+
143+
WGPUShaderSourceWGSL wgsl_desc = {};
144+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
145+
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
146+
WGPUShaderModuleDescriptor shader_desc = {};
147+
shader_desc.nextInChain = &wgsl_desc.chain;
148+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
149+
150+
// Bind group layout: out (rw) + in/weight/scales/bias (ro storage) + uniform.
151+
WGPUBindGroupLayoutEntry entries[6] = {};
152+
entries[0].binding = 0;
153+
entries[0].visibility = WGPUShaderStage_Compute;
154+
entries[0].buffer.type = WGPUBufferBindingType_Storage;
155+
for (uint32_t i = 1; i <= 4; i++) {
156+
entries[i].binding = i;
157+
entries[i].visibility = WGPUShaderStage_Compute;
158+
entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
159+
}
160+
entries[5].binding = 5;
161+
entries[5].visibility = WGPUShaderStage_Compute;
162+
entries[5].buffer.type = WGPUBufferBindingType_Uniform;
163+
164+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
165+
bgl_desc.entryCount = 6;
166+
bgl_desc.entries = entries;
167+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
168+
169+
WGPUPipelineLayoutDescriptor pl_desc = {};
170+
pl_desc.bindGroupLayoutCount = 1;
171+
pl_desc.bindGroupLayouts = &bgl;
172+
WGPUPipelineLayout pipeline_layout =
173+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
174+
175+
const uint32_t wg_size =
176+
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
177+
WGPUConstantEntry wg_size_constant = {};
178+
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
179+
wg_size_constant.value = static_cast<double>(wg_size);
180+
181+
WGPUComputePipelineDescriptor pipeline_desc = {};
182+
pipeline_desc.layout = pipeline_layout;
183+
pipeline_desc.compute.module = shader;
184+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
185+
pipeline_desc.compute.constantCount = 1;
186+
pipeline_desc.compute.constants = &wg_size_constant;
187+
WGPUComputePipeline pipeline =
188+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
189+
190+
WGPUBindGroupEntry bg_entries[6] = {};
191+
bg_entries[0].binding = 0;
192+
bg_entries[0].buffer = out.buffer;
193+
bg_entries[0].size = out.nbytes;
194+
bg_entries[1].binding = 1;
195+
bg_entries[1].buffer = in.buffer;
196+
bg_entries[1].size = in.nbytes;
197+
bg_entries[2].binding = 2;
198+
bg_entries[2].buffer = weight.buffer;
199+
bg_entries[2].size = weight.nbytes;
200+
bg_entries[3].binding = 3;
201+
bg_entries[3].buffer = scales.buffer;
202+
bg_entries[3].size = scales.nbytes;
203+
bg_entries[4].binding = 4;
204+
bg_entries[4].buffer = bias_buffer;
205+
bg_entries[4].size = bias_size;
206+
bg_entries[5].binding = 5;
207+
bg_entries[5].buffer = uniform_buffer;
208+
bg_entries[5].size = sizeof(Q4gswParams);
209+
210+
WGPUBindGroupDescriptor bg_desc = {};
211+
bg_desc.layout = bgl;
212+
bg_desc.entryCount = 6;
213+
bg_desc.entries = bg_entries;
214+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
215+
216+
graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"});
217+
218+
wgpuShaderModuleRelease(shader);
219+
wgpuBindGroupLayoutRelease(bgl);
220+
wgpuPipelineLayoutRelease(pipeline_layout);
221+
wgpuBufferRelease(uniform_buffer);
222+
}
223+
224+
} // namespace
225+
226+
WEBGPU_REGISTER_OPERATORS {
227+
WEBGPU_REGISTER_OP(et_vk.linear_q4gsw.default, q4gsw_linear_impl);
228+
}
229+
230+
} // namespace executorch::backends::webgpu
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
2+
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
3+
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
4+
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
5+
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;
6+
7+
struct Params {
8+
M: u32,
9+
N: u32,
10+
K: u32,
11+
K_packed: u32,
12+
group_size: u32,
13+
padded_N: u32,
14+
has_bias: u32,
15+
_pad: u32,
16+
}
17+
@group(0) @binding(5) var<uniform> params: Params;
18+
19+
override wg_size: u32 = 64u;
20+
21+
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
22+
@compute @workgroup_size(wg_size, 1, 1)
23+
fn main(
24+
@builtin(workgroup_id) wid: vec3<u32>,
25+
@builtin(local_invocation_id) lid: vec3<u32>) {
26+
let m = wid.x;
27+
if (m >= params.M) {
28+
return;
29+
}
30+
let in_base = m * params.K;
31+
32+
var n: u32 = lid.x;
33+
loop {
34+
if (n >= params.N) {
35+
break;
36+
}
37+
var acc: f32 = 0.0;
38+
var k: u32 = 0u;
39+
loop {
40+
if (k >= params.K) {
41+
break;
42+
}
43+
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
44+
let byte_idx = n * params.K_packed + (k >> 1u);
45+
let word = t_weight[byte_idx >> 2u];
46+
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
47+
var nib: u32;
48+
if ((k & 1u) == 0u) {
49+
nib = b & 0x0Fu; // even k -> low nibble
50+
} else {
51+
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
52+
}
53+
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
54+
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
55+
acc = acc + t_input[in_base + k] * q * scale;
56+
k = k + 1u;
57+
}
58+
if (params.has_bias != 0u) {
59+
acc = acc + t_bias[n];
60+
}
61+
t_out[m * params.N + n] = acc;
62+
n = n + wg_size;
63+
}
64+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
17+
inline constexpr const char* kQ4gswLinearWGSL = R"(
18+
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
19+
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
20+
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
21+
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
22+
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;
23+
24+
struct Params {
25+
M: u32,
26+
N: u32,
27+
K: u32,
28+
K_packed: u32,
29+
group_size: u32,
30+
padded_N: u32,
31+
has_bias: u32,
32+
_pad: u32,
33+
}
34+
@group(0) @binding(5) var<uniform> params: Params;
35+
36+
override wg_size: u32 = 64u;
37+
38+
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
39+
@compute @workgroup_size(wg_size, 1, 1)
40+
fn main(
41+
@builtin(workgroup_id) wid: vec3<u32>,
42+
@builtin(local_invocation_id) lid: vec3<u32>) {
43+
let m = wid.x;
44+
if (m >= params.M) {
45+
return;
46+
}
47+
let in_base = m * params.K;
48+
49+
var n: u32 = lid.x;
50+
loop {
51+
if (n >= params.N) {
52+
break;
53+
}
54+
var acc: f32 = 0.0;
55+
var k: u32 = 0u;
56+
loop {
57+
if (k >= params.K) {
58+
break;
59+
}
60+
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
61+
let byte_idx = n * params.K_packed + (k >> 1u);
62+
let word = t_weight[byte_idx >> 2u];
63+
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
64+
var nib: u32;
65+
if ((k & 1u) == 0u) {
66+
nib = b & 0x0Fu; // even k -> low nibble
67+
} else {
68+
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
69+
}
70+
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
71+
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
72+
acc = acc + t_input[in_base + k] * q * scale;
73+
k = k + 1u;
74+
}
75+
if (params.has_bias != 0u) {
76+
acc = acc + t_bias[n];
77+
}
78+
t_out[m * params.N + n] = acc;
79+
n = n + wg_size;
80+
}
81+
}
82+
)";
83+
84+
inline constexpr uint32_t kQ4gswLinearWorkgroupSizeX = 64;
85+
inline constexpr uint32_t kQ4gswLinearWorkgroupSizeY = 1;
86+
inline constexpr uint32_t kQ4gswLinearWorkgroupSizeZ = 1;
87+
88+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)