Skip to content

Commit 51e3de5

Browse files
committed
Update
[ghstack-poisoned]
2 parents 0413e92 + 6c97971 commit 51e3de5

3 files changed

Lines changed: 118 additions & 1 deletion

File tree

backends/xnnpack/runtime/core/quant_params.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,15 @@ struct PerRowQuantParams {
6767
int8_t axis = -1;
6868
DType scale_dtype = DType::Float32;
6969
bool has_zero_point = false;
70+
// When true, this is a dynamically-quantized activation (XNNPACK qdint8):
71+
// the per-row scale/zero point are computed at runtime rather than stored.
72+
// `axis` is the reduced (channel) dim, so the number of trailing "row" dims
73+
// (XNNPACK's num_nonbatch_dims) is -axis for the usual negative axis.
74+
bool is_dynamic = false;
7075

7176
bool operator==(const PerRowQuantParams& o) const {
7277
return axis == o.axis && scale_dtype == o.scale_dtype &&
73-
has_zero_point == o.has_zero_point;
78+
has_zero_point == o.has_zero_point && is_dynamic == o.is_dynamic;
7479
}
7580
};
7681

backends/xnnpack/runtime/plan/xnn_subgraph.cpp

Lines changed: 111 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,13 +44,29 @@ runtime::Result<xnn_datatype> map_xnn_datatype(const graph::TensorSpec& spec) {
4444
case DType::QUInt8:
4545
return xnn_datatype_quint8;
4646
case DType::QInt8:
47+
if (auto* pr = std::get_if<core::PerRowQuantParams>(&*spec.quant_params);
48+
pr != nullptr && pr->is_dynamic) {
49+
return xnn_datatype_qdint8;
50+
}
4751
if (std::holds_alternative<core::PerAxisQuantParams>(
4852
*spec.quant_params)) {
4953
return xnn_datatype_qcint8;
5054
}
5155
return xnn_datatype_qint8;
5256
case DType::QInt32:
57+
// Per-channel bias is channelwise int32 (qcint32); per-tensor is qint32.
58+
if (std::holds_alternative<core::PerAxisQuantParams>(
59+
*spec.quant_params)) {
60+
return xnn_datatype_qcint32;
61+
}
5362
return xnn_datatype_qint32;
63+
case DType::QInt4:
64+
// 4-bit weights: per-channel (qcint4) or blockwise/group (qbint4).
65+
if (std::holds_alternative<core::PerBlockQuantParams>(
66+
*spec.quant_params)) {
67+
return xnn_datatype_qbint4;
68+
}
69+
return xnn_datatype_qcint4;
5470
default:
5571
ET_LOG(Error, "Unsupported quantized dtype for XNNPACK delegation");
5672
return runtime::Error::NotSupported;
@@ -727,6 +743,28 @@ runtime::Result<uint32_t> define_tensor(
727743
constant_tensor != nullptr && !constant_tensor->aux_storage.empty(),
728744
NotSupported,
729745
"Per-axis quantized tensor is missing scale data");
746+
// Per-channel asymmetric quantization (per-channel zero points) is not
747+
// supported; fail cleanly rather than silently using a zero zero-point.
748+
ET_CHECK_OR_RETURN_ERROR(
749+
!pa->has_zero_point,
750+
NotSupported,
751+
"Per-channel asymmetric quantization is not supported");
752+
// XNNPACK requires one scale per element of the quantized (channel) dim.
753+
ET_CHECK_OR_RETURN_ERROR(
754+
pa->axis >= 0 && static_cast<size_t>(pa->axis) < dims.size(),
755+
Internal,
756+
"Per-axis quant axis %d out of range (%zu dims)",
757+
(int)pa->axis,
758+
dims.size());
759+
size_t num_scales =
760+
constant_tensor->aux_storage[0].size_in_bytes / sizeof(float);
761+
ET_CHECK_OR_RETURN_ERROR(
762+
num_scales == dims[pa->axis],
763+
Internal,
764+
"Per-axis scale count %zu != channel dim %zu (axis %d)",
765+
num_scales,
766+
dims[pa->axis],
767+
(int)pa->axis);
730768
auto* scales =
731769
static_cast<const float*>(constant_tensor->aux_storage[0].data);
732770
int32_t zero_point = (xnn_dtype == xnn_datatype_qcint4) ? 8 : 0;
@@ -742,6 +780,79 @@ runtime::Result<uint32_t> define_tensor(
742780
external_id,
743781
flags,
744782
&id);
783+
} else if (
784+
auto* pb = std::get_if<core::PerBlockQuantParams>(&*spec.quant_params)) {
785+
// Blockwise 4-bit weight (qbint4): one bf16 scale per group of block_size
786+
// elements along the channel axis.
787+
ET_CHECK_OR_RETURN_ERROR(
788+
constant_tensor != nullptr && !constant_tensor->aux_storage.empty(),
789+
NotSupported,
790+
"Blockwise quantized tensor is missing scale data");
791+
ET_CHECK_OR_RETURN_ERROR(
792+
!pb->has_zero_point,
793+
NotSupported,
794+
"Blockwise asymmetric quantization is not supported");
795+
ET_CHECK_OR_RETURN_ERROR(
796+
pb->axis >= 0 && static_cast<size_t>(pb->axis) < dims.size() &&
797+
pb->block_size > 0,
798+
Internal,
799+
"Invalid blockwise quant axis/block_size");
800+
size_t num_elements = 1;
801+
for (auto d : dims)
802+
num_elements *= d;
803+
size_t expected_scales = num_elements / static_cast<size_t>(pb->block_size);
804+
size_t num_scales =
805+
constant_tensor->aux_storage[0].size_in_bytes / sizeof(uint16_t);
806+
ET_CHECK_OR_RETURN_ERROR(
807+
num_scales == expected_scales,
808+
Internal,
809+
"Blockwise scale count %zu != elements %zu / block_size %d",
810+
num_scales,
811+
num_elements,
812+
pb->block_size);
813+
// Scales are bf16 (uint16) per the blockwise quant convention.
814+
auto* scales =
815+
static_cast<const uint16_t*>(constant_tensor->aux_storage[0].data);
816+
int32_t zero_point = (xnn_dtype == xnn_datatype_qbint4) ? 8 : 0;
817+
status = xnn_define_blockwise_quantized_tensor_value(
818+
subgraph,
819+
xnn_dtype,
820+
zero_point,
821+
scales,
822+
spec.sizes.size(),
823+
static_cast<size_t>(pb->axis),
824+
static_cast<size_t>(pb->block_size),
825+
dims.data(),
826+
data,
827+
external_id,
828+
flags,
829+
&id);
830+
} else if (
831+
auto* pr = std::get_if<core::PerRowQuantParams>(&*spec.quant_params)) {
832+
// Dynamically-quantized activation: XNNPACK computes per-row scales at
833+
// runtime. Static per-row is not an XNNPACK path.
834+
ET_CHECK_OR_RETURN_ERROR(
835+
pr->is_dynamic,
836+
NotSupported,
837+
"Static per-row quantization is not supported");
838+
ET_CHECK_OR_RETURN_ERROR(
839+
data == nullptr,
840+
Internal,
841+
"Dynamically-quantized tensor must not have constant data");
842+
// num_nonbatch_dims is the count of trailing "row" dims, encoded as -axis.
843+
size_t ndims = spec.sizes.size();
844+
size_t num_nonbatch_dims = pr->axis < 0
845+
? static_cast<size_t>(-pr->axis)
846+
: ndims - static_cast<size_t>(pr->axis);
847+
status = xnn_define_dynamically_quantized_tensor_value(
848+
subgraph,
849+
xnn_dtype,
850+
ndims,
851+
num_nonbatch_dims,
852+
dims.data(),
853+
external_id,
854+
flags,
855+
&id);
745856
} else {
746857
ET_LOG(Error, "Unsupported quantization scheme for XNNPACK delegation");
747858
return runtime::Error::NotSupported;

backends/xnnpack/runtime/plan/xnn_support.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ bool check_xnn_dtype_support(core::DType dtype) {
1717
case core::DType::QUInt8:
1818
case core::DType::QInt8:
1919
case core::DType::QInt32:
20+
case core::DType::QInt4:
2021
return true;
2122
default:
2223
return false;

0 commit comments

Comments
 (0)