Skip to content

Commit d43568a

Browse files
[ExecuTorch][WebGPU] linear_q4gsw test suite: Llama-1B shapes + 4k/8k sweep
Pull Request resolved: #20227 Adds the numerical test suite for `et_vk.linear_q4gsw` (stacked on the op diff), mirroring the SDPA test suite. A named CONFIGS sweep covers real Llama-3.2-1B linear shapes — q/o-proj (2048->2048), k/v-proj (2048->512), gate/up-proj (2048->8192), down-proj (8192->2048), lm_head (2048->128256) — plus 4k/8k large-token prefill (M=4096/8192 on the 2048->2048 and 2048->512 projections). `test/ops/quantized_linear/test_quantized_linear.py` exports each config's `.pte` + an fp64 dequant-matmul "truth" golden; `test/test_webgpu_native.cpp` reconstructs the deterministic ramp input bit-for-bit, runs the op on the GPU, and compares per element; `scripts/test_webgpu_native_ci.sh` wires the fixtures into the Dawn(Tint)+SwiftShader CI. ghstack-source-id: 392908895 @exported-using-ghexport Differential Revision: [D108314849](https://our.internmc.facebook.com/intern/diff/D108314849/)
1 parent fe2e07b commit d43568a

4 files changed

Lines changed: 356 additions & 0 deletions

File tree

backends/webgpu/scripts/test_webgpu_native_ci.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ export_add_model('${PTE_MODEL}')
5454
export_chained_add_model('${PTE_CHAINED_MODEL}')
5555
" || echo "WARN: add export failed; webgpu_native_test self-skips models whose .pte is absent"
5656

57+
$PYTHON_EXECUTABLE -c "
58+
from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import export_all_quantized_linear_models
59+
export_all_quantized_linear_models('/tmp')
60+
" || echo "WARN: q4gsw export failed; required configs will FAIL in webgpu_native_test"
61+
5762
$PYTHON_EXECUTABLE -c "
5863
from executorch.backends.webgpu.test.ops.rms_norm.test_rms_norm import export_rms_norm_cases
5964
export_rms_norm_cases('${RMS_NORM_DIR}')
@@ -143,6 +148,7 @@ if [[ -x "${BIN_DIR}/webgpu_native_test" && -f "${PTE_MODEL}" ]]; then
143148
env WEBGPU_TEST_MODEL="${PTE_MODEL}" \
144149
WEBGPU_TEST_CHAINED_MODEL="${PTE_CHAINED_MODEL}" \
145150
WEBGPU_TEST_SDPA_DIR=/tmp/ \
151+
WEBGPU_TEST_QUANTIZED_LINEAR_DIR=/tmp/ \
146152
"${BIN_DIR}/webgpu_native_test"
147153
else
148154
echo "(skipping webgpu_native_test: no exported .pte — needs the executorch python wheel)"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""4-bit weight-only quantized linear (`et_vk.linear_q4gsw`) export + fp64 golden.
8+
9+
Mirrors test_sdpa.py: a named CONFIGS sweep over real Llama-3.2-1B linear shapes
10+
(q/o/k/v/gate/up/down proj + lm_head) plus large-M (4k/8k) prefill stress, each
11+
exported through VulkanPartitioner (which fuses dq+linear into
12+
`et_vk.linear_q4gsw.default`). The golden is the fp64 dequant-matmul truth
13+
(x @ dequant(W).T), so the GPU's fp32 error is measured against truth, not another
14+
fp32 approximation. The native test (test_webgpu_native.cpp) mirrors the same
15+
CONFIGS table and reconstructs the identical deterministic ramp input bit-for-bit.
16+
"""
17+
18+
import os
19+
import unittest
20+
from dataclasses import dataclass
21+
22+
import numpy as np
23+
import torch
24+
25+
from executorch.backends.vulkan import VulkanPartitioner
26+
from executorch.exir import to_edge_transform_and_lower
27+
from torchao.quantization.granularity import PerGroup
28+
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
29+
30+
31+
@dataclass(frozen=True)
class Q4gswConfig:
    """One named (M, K, N) shape of the linear_q4gsw sweep (immutable)."""

    # Config name; also used to build the fixture file names.
    name: str
    # Number of input rows (tokens).
    m: int
    # in_features, i.e. the reduction dimension.
    k: int
    # out_features.
    n: int
    # Quantization group length along K. Constraints noted for the op:
    # K % group_size == 0, K % 8 == 0, N % 8 == 0.
    group_size: int = 32
    # Huge fixture / slow on a CPU rasterizer; export_all skips it unless asked.
    heavy: bool = False
40+
41+
42+
# Canonical sweep table; the C++ kQ4gswConfigs table mirrors it. Llama-3.2-1B
# dimensions: hidden=2048; n_heads=32 with head_dim=64 (q/o = 2048->2048);
# n_kv=8 (k/v = 2048->512); FFN=8192 (gate/up = 2048->8192, down = 8192->2048);
# vocab=128256 (lm_head).
CONFIGS = [
    # Single-token (M=1) projections.
    Q4gswConfig(name="q_proj", m=1, k=2048, n=2048),  # o_proj has the same shape
    Q4gswConfig(name="kv_proj", m=1, k=2048, n=512),  # k_proj / v_proj
    Q4gswConfig(name="gate_proj", m=1, k=2048, n=8192),  # gate_proj / up_proj
    Q4gswConfig(name="down_proj", m=1, k=8192, n=2048),  # large reduction K
    Q4gswConfig(name="lm_head", m=1, k=2048, n=128256, heavy=True),  # 131MB packed .pte
    # 4k-token prefill.
    Q4gswConfig(name="q_proj_4k", m=4096, k=2048, n=2048),
    Q4gswConfig(name="kv_proj_4k", m=4096, k=2048, n=512),
    # 8k-token prefill (heavy).
    Q4gswConfig(name="q_proj_8k", m=8192, k=2048, n=2048, heavy=True),  # 67MB golden
    Q4gswConfig(name="kv_proj_8k", m=8192, k=2048, n=512, heavy=True),
]
57+
58+
59+
def _make_quantized_model(k: int, n: int, group_size: int) -> torch.nn.Module:
    """Build a bias-free Linear(k, n) in eval mode with int4 per-group weights.

    The global torch RNG is reseeded to 0 on every call, so the initial weights
    (and therefore any golden derived from them) are reproducible.
    """
    # Do not remove: the seed pins the weights the golden is computed from.
    torch.manual_seed(0)
    linear = torch.nn.Linear(k, n, bias=False).eval()
    int4_per_group = IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(group_size),
    )
    quantize_(linear, int4_per_group)
    return linear
67+
68+
69+
def _ramp_input(m_rows: int, k: int) -> torch.Tensor:
70+
"""Deterministic fp32 input [M,K]; C++ q4gsw_ramp reconstructs it bit-for-bit.
71+
72+
x[flat] = ((flat % 17) - 8) / 16 over the flat row-major index -- exact in fp32
73+
(small modulus, power-of-two denominator).
74+
"""
75+
flat = np.arange(m_rows * k, dtype=np.int64)
76+
x = ((flat % 17) - 8).astype(np.float32) / np.float32(16.0)
77+
return torch.from_numpy(x).reshape(m_rows, k)
78+
79+
80+
def _fp64_golden(m: torch.nn.Module, x: torch.Tensor) -> np.ndarray:
81+
"""fp64 truth: x @ dequant(W).T. The kernel computes the same dequant-matmul, so
82+
fp64 makes this the true answer -- GPU fp32 error is measured vs truth, not vs a
83+
second fp32 approximation. torchao handles the signed-nibble recovery in dequantize().
84+
"""
85+
wq = m.weight.dequantize() # AffineQuantizedTensor -> dequantized weight [N,K]
86+
golden = x.double() @ wq.double().t() # [M,N] in fp64
87+
return golden.to(torch.float32).numpy().astype("<f4")
88+
89+
90+
def _export(m: torch.nn.Module, x: torch.Tensor):
    """Export m with example input x and lower it through VulkanPartitioner.

    Returns whatever to_executorch() produces for the lowered program.
    """
    exported = torch.export.export(m, (x,))
    lowered = to_edge_transform_and_lower(
        exported,
        partitioner=[VulkanPartitioner()],
    )
    return lowered.to_executorch()
95+
96+
97+
class TestQuantizedLinear(unittest.TestCase):
    """Host-side checks for the q4gsw export path and its fp64 golden."""

    def test_export_delegates(self) -> None:
        # Every non-heavy config has to lower to a VulkanBackend delegate (q4gsw).
        # Fusion does not depend on shape, so leaving out the heavy 131MB+
        # fixtures costs no coverage.
        for cfg in (c for c in CONFIGS if not c.heavy):
            with self.subTest(config=cfg.name):
                model = _make_quantized_model(cfg.k, cfg.n, cfg.group_size)
                program = _export(model, _ramp_input(1, cfg.k))
                delegate_ids = [
                    delegate.id
                    for plan in program.executorch_program.execution_plan
                    for delegate in plan.delegates
                ]
                self.assertTrue(
                    "VulkanBackend" in delegate_ids,
                    f"no VulkanBackend delegate in {cfg.name}",
                )

    def test_golden_matches_eager(self) -> None:
        # Dual oracle (mirrors SDPA test_golden_matches_eager_op): the fp64
        # dequant-matmul truth and torchao's own fp32 quantized forward are
        # independent references that must agree. This guards against a bug in
        # the fp64 oracle or in the dequantize() accessor. Only the M=1,
        # non-heavy shapes are checked (cheap; the math is shape-independent).
        for cfg in CONFIGS:
            if cfg.heavy or cfg.m != 1:
                continue
            with self.subTest(config=cfg.name):
                model = _make_quantized_model(cfg.k, cfg.n, cfg.group_size)
                ramp = _ramp_input(1, cfg.k)
                expected = torch.from_numpy(_fp64_golden(model, ramp))
                torch.testing.assert_close(
                    model(ramp), expected, atol=5e-4, rtol=1e-3
                )
127+
128+
129+
def export_quantized_linear_model(
    cfg: Q4gswConfig, pte_path: str, golden_path: str
) -> None:
    """Write one config's q4gsw .pte and its fp64 golden (raw little-endian fp32).

    Args:
        cfg: shape/config entry to export.
        pte_path: destination of the serialized program.
        golden_path: destination of the raw golden array (cfg.m * cfg.n floats).
    """
    model = _make_quantized_model(cfg.k, cfg.n, cfg.group_size)
    ramp = _ramp_input(cfg.m, cfg.k)
    program = _export(model, ramp)
    with open(pte_path, "wb") as pte_file:
        pte_file.write(program.buffer)
    golden = _fp64_golden(model, ramp)
    golden.tofile(golden_path)
    n_floats = cfg.m * cfg.n
    print(f"Exported {pte_path}; golden {golden_path} ({n_floats} floats)")
140+
141+
142+
def export_all_quantized_linear_models(
    out_dir: str, include_heavy: bool = False
) -> None:
    """Write q4gsw_<name>.pte + q4gsw_<name>.golden.bin for each config.

    Heavy configs (lm_head 131MB .pte; M=8k 67MB goldens) are skipped unless
    include_heavy -- plain CI never writes them; a real-GPU run opts in.

    Args:
        out_dir: destination directory; created if it does not exist. An empty
            string means the current working directory.
        include_heavy: also export the configs flagged heavy.
    """
    # Create the destination first: otherwise a missing directory is only
    # reported by open(), after the first config has already been quantized
    # and exported. os.makedirs("") raises, so the cwd case is skipped.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    for cfg in CONFIGS:
        if cfg.heavy and not include_heavy:
            print(f"(skipping heavy config {cfg.name}; set include_heavy=True)")
            continue
        pte = os.path.join(out_dir, f"q4gsw_{cfg.name}.pte")
        golden = os.path.join(out_dir, f"q4gsw_{cfg.name}.golden.bin")
        export_quantized_linear_model(cfg, pte, golden)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main()

backends/webgpu/test/test_webgpu_native.cpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,166 @@ static bool sdpa_within_tol(
375375
return ok;
376376
}
377377

378+
// One entry of the linear_q4gsw sweep; mirrors CONFIGS in
// test_quantized_linear.py. Entries are brace-initialized positionally, so the
// member order below must not change.
struct Q4gswConfig {
  const char* name; // config name, used to build the fixture file names
  int m; // number of rows (tokens)
  int k; // in_features (the reduction dimension)
  int n; // out_features
  float tol_abs; // absolute per-element error gate
  float tol_rel; // relative per-element error gate
  bool required; // with the dir set, a missing .pte is a FAIL rather than a skip
  bool heavy; // huge/slow and export-gated: runs only if WEBGPU_TEST_HEAVY
};
389+
390+
// Llama-3.2-1B linear shapes (q/o/k/v/gate/up/down + lm_head) + 4k/8k prefill.
391+
// tol scales with K (fp32 accum depth), not M; down_proj (K=8192) is looser.
392+
static const Q4gswConfig kQ4gswConfigs[] = {
393+
// name M K N tol_abs tol_rel req heavy
394+
{"q_proj", 1, 2048, 2048, 1e-4f, 1e-3f, true, false},
395+
{"kv_proj", 1, 2048, 512, 1e-4f, 1e-3f, true, false},
396+
{"gate_proj", 1, 2048, 8192, 1e-4f, 1e-3f, true, false},
397+
{"down_proj", 1, 8192, 2048, 1e-3f, 1e-2f, true, false}, // big-K accum
398+
{"lm_head", 1, 2048, 128256, 1e-4f, 1e-3f, false, true},
399+
{"q_proj_4k", 4096, 2048, 2048, 1e-4f, 1e-3f, true, false},
400+
{"kv_proj_4k", 4096, 2048, 512, 1e-4f, 1e-3f, true, false},
401+
{"q_proj_8k", 8192, 2048, 2048, 1e-4f, 1e-3f, false, true},
402+
{"kv_proj_8k", 8192, 2048, 512, 1e-4f, 1e-3f, false, true},
403+
};
404+
405+
// Value of the deterministic test input at flat index i: ((i % 17) - 8) / 16.
// Mirrors _ramp_input in test_quantized_linear.py; the numerator is a small
// integer and the divisor a power of two, so the result is exact in fp32.
static float q4gsw_ramp(int i) {
  const int centered = (i % 17) - 8; // in [-8, 8] for non-negative i
  return static_cast<float>(centered) / 16.0f;
}
409+
410+
// Per-element dual-tolerance check, parameterized like sdpa_within_tol.
//
// An element passes when its absolute error is <= atol OR its relative error
// is <= rtol. The relative error divides by max(|golden|, 1e-6) so a golden of
// zero does not divide by zero.
//
//   out, golden : arrays of n floats to compare
//   atol, rtol  : absolute / relative gates
//   ma, mr      : receive the largest absolute / relative error seen
//
// Returns true only if every element passes (trivially true for n <= 0).
static bool quant_within_tol(
    const float* out,
    const float* golden,
    int n,
    float atol,
    float rtol,
    float* ma,
    float* mr) {
  float max_abs = 0.0f, max_rel = 0.0f;
  bool ok = true;
  for (int i = 0; i < n; i++) {
    const float ae = std::abs(out[i] - golden[i]);
    const float re = ae / std::max(std::abs(golden[i]), 1e-6f);
    // Note: std::max keeps its first argument when the comparison is
    // unordered, so a NaN error does not show up in the reported maxima.
    max_abs = std::max(max_abs, ae);
    max_rel = std::max(max_rel, re);
    // Written as "not within tolerance" rather than "ae > atol && re > rtol":
    // every comparison with NaN is false, so the latter form would let a NaN
    // output pass the check silently.
    const bool within = (ae <= atol) || (re <= rtol);
    if (!within) {
      ok = false;
    }
  }
  *ma = max_abs;
  *mr = max_rel;
  return ok;
}
434+
435+
// Reconstruct _ramp_input bit-for-bit, run the op, compare to the fp64 golden.
436+
static bool test_q4gsw_config(
437+
const Q4gswConfig& cfg,
438+
const std::string& pte,
439+
const std::string& golden_path) {
440+
printf(
441+
"\n--- Test: linear_q4gsw (%s: M=%d,K=%d,N=%d) ---\n",
442+
cfg.name,
443+
cfg.m,
444+
cfg.k,
445+
cfg.n);
446+
447+
Module module(pte);
448+
if (module.load_forward() != Error::Ok) {
449+
printf("FAIL: could not load %s\n", pte.c_str());
450+
return false;
451+
}
452+
453+
const int in_numel = cfg.m * cfg.k;
454+
const int out_numel = cfg.m * cfg.n;
455+
std::vector<float> input(in_numel);
456+
for (int i = 0; i < in_numel; i++) {
457+
input[i] = q4gsw_ramp(i);
458+
}
459+
460+
auto x = make_tensor_ptr({cfg.m, cfg.k}, std::vector<float>(input));
461+
auto result = module.forward({EValue(x)});
462+
if (!result.ok()) {
463+
printf("FAIL: forward failed (error %d)\n", (int)result.error());
464+
return false;
465+
}
466+
const auto& outputs = result.get();
467+
if (outputs.empty() || !outputs[0].isTensor()) {
468+
printf("FAIL: no tensor output\n");
469+
return false;
470+
}
471+
const auto& out_tensor = outputs[0].toTensor();
472+
if (out_tensor.numel() != out_numel) {
473+
printf(
474+
"FAIL: output numel %zu != expected %d\n",
475+
(size_t)out_tensor.numel(),
476+
out_numel);
477+
return false;
478+
}
479+
const float* out_data = out_tensor.const_data_ptr<float>();
480+
481+
std::vector<float> golden = load_golden(golden_path, out_numel);
482+
if (golden.empty()) {
483+
printf("FAIL: could not load golden %s\n", golden_path.c_str());
484+
return false;
485+
}
486+
487+
float ma = 0.0f, mr = 0.0f;
488+
const bool pass = quant_within_tol(
489+
out_data, golden.data(), out_numel, cfg.tol_abs, cfg.tol_rel, &ma, &mr);
490+
printf(
491+
"Max abs error: %e Max rel error: %e (checked %d elements)\n",
492+
ma,
493+
mr,
494+
out_numel);
495+
if (!pass) {
496+
printf(
497+
"FAIL: linear_q4gsw %s exceeds tolerance (abs %g OR rel %g)\n",
498+
cfg.name,
499+
cfg.tol_abs,
500+
cfg.tol_rel);
501+
return false;
502+
}
503+
printf("PASS: linear_q4gsw %s\n", cfg.name);
504+
return true;
505+
}
506+
507+
// q4gsw sweep: self-discover q4gsw_<name>.pte; required=FAIL, heavy=gate, *ran.
508+
static bool test_q4gsw_sweep(const std::string& dir, bool* ran) {
509+
bool ok = true;
510+
const bool heavy_run = std::getenv("WEBGPU_TEST_HEAVY") != nullptr;
511+
for (const auto& cfg : kQ4gswConfigs) {
512+
const std::string pte = dir + "q4gsw_" + cfg.name + ".pte";
513+
FILE* f = std::fopen(pte.c_str(), "rb");
514+
if (!f) {
515+
if (cfg.required && !dir.empty()) {
516+
printf(
517+
"FAIL: required q4gsw config %s has no .pte in %s\n",
518+
cfg.name,
519+
dir.c_str());
520+
ok = false;
521+
}
522+
continue;
523+
}
524+
std::fclose(f);
525+
if (cfg.heavy && !heavy_run) {
526+
printf(
527+
"SKIP: heavy q4gsw config %s (set WEBGPU_TEST_HEAVY=1 on a real GPU)\n",
528+
cfg.name);
529+
continue;
530+
}
531+
const std::string golden = dir + "q4gsw_" + cfg.name + ".golden.bin";
532+
*ran = true;
533+
ok = test_q4gsw_config(cfg, pte, golden) && ok;
534+
}
535+
return ok;
536+
}
537+
378538
// Fused sdpa_with_kv_cache sweep config. Mirrors the Python CONFIGS table in
379539
// test_sdpa.py exactly (name, Hq, Hkv, D, S, Cmax, input_pos).
380540
struct SdpaConfig {
@@ -1289,6 +1449,15 @@ int main(int argc, char** argv) {
12891449
update_cache_model_path = env;
12901450
}
12911451

1452+
// Quantized-linear sweep dir (mirrors WEBGPU_TEST_SDPA_DIR).
1453+
std::string qlinear_dir;
1454+
if (const char* env = std::getenv("WEBGPU_TEST_QUANTIZED_LINEAR_DIR")) {
1455+
qlinear_dir = env;
1456+
if (!qlinear_dir.empty() && qlinear_dir.back() != '/') {
1457+
qlinear_dir += '/';
1458+
}
1459+
}
1460+
12921461
// SDPA sweep: configs self-discover their sdpa_<name>.pte/.golden.bin under
12931462
// this directory (default "" = the embedded-file root / cwd). Set
12941463
// WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/).
@@ -1326,6 +1495,22 @@ int main(int argc, char** argv) {
13261495
ok = test_update_cache(update_cache_model_path) && ok;
13271496
}
13281497

1498+
bool q4gsw_ran = false;
1499+
bool q4gsw_ok = test_q4gsw_sweep(qlinear_dir, &q4gsw_ran);
1500+
if (q4gsw_ran) {
1501+
ok = q4gsw_ok && ok;
1502+
}
1503+
// Guard python<->C++ ramp bit-identity: q4gsw_ramp(0) = -0.5 exactly.
1504+
if (std::abs(q4gsw_ramp(0) - (-0.5f)) > 1e-12f) {
1505+
printf("FAIL: q4gsw_ramp bit-identity check\n");
1506+
ok = false;
1507+
}
1508+
if (!qlinear_dir.empty() && !q4gsw_ran) {
1509+
printf(
1510+
"FAIL: WEBGPU_TEST_QUANTIZED_LINEAR_DIR set but no q4gsw config ran\n");
1511+
ok = false;
1512+
}
1513+
13291514
bool sdpa_ran = false;
13301515
bool sdpa_ok = test_sdpa_sweep(sdpa_dir, &sdpa_ran);
13311516
if (sdpa_ran) {

0 commit comments

Comments
 (0)