Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions olive/passes/onnx/kquant_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
depending on onnxruntime's quantization modules.
"""

import fnmatch
import logging
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -248,7 +249,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"nodes_to_exclude": PassConfigParam(
type_=list,
default_value=None,
description="List of node names to exclude from quantization.",
description=(
"List of node names to exclude from quantization. An entry that contains a "
"wildcard ('*' or '?') is treated as a Unix shell-style glob pattern (e.g. "
"'*/projector/*' to exclude all projector MatMuls); all other entries are "
"matched by exact node name. A node is excluded if its name equals any exact "
"entry or matches any glob entry."
),
),
**get_external_data_config(),
}
Expand Down Expand Up @@ -290,13 +297,22 @@ def _quantize_model(
nodes_to_exclude = nodes_to_exclude or []
customized_weight_config = customized_weight_config or {}

# Build the two lookups used for exclusion. Every entry is kept in the exact-name set; entries
# that contain a wildcard ('*' or '?') are additionally treated as glob patterns. Entries
# without a wildcard are never glob-matched, which keeps exact names that happen to contain
# other fnmatch metacharacters (e.g. '[' / ']') from unintentionally matching unrelated nodes.
exclude_exact = set(nodes_to_exclude)
exclude_globs = [pattern for pattern in nodes_to_exclude if "*" in pattern or "?" in pattern]

globally_registered = {}

ir_model.graph.sort()
for node in ir_model.graph.all_nodes():
node_name = node.name

if node_name in nodes_to_exclude:
if node_name in exclude_exact or any(
fnmatch.fnmatchcase(node_name or "", pattern) for pattern in exclude_globs
):
logger.debug("Exclude quantization of %s as specified by nodes_to_exclude.", node_name)
continue

Expand Down
59 changes: 59 additions & 0 deletions test/passes/onnx/test_kquant_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,62 @@ def test_kquant_with_nodes_to_exclude(self, matmul_model_path, tmp_path):

assert len(matmul_nbits_nodes) == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)"
assert len(matmul_nodes) == 1, "Expected 1 original MatMul node (MatMul_1 excluded)"

def test_kquant_with_nodes_to_exclude_glob(self, matmul_model_path, tmp_path):
    """Check that a wildcard entry in nodes_to_exclude is applied as a glob.

    The single entry "*_1" is expected to cover MatMul_1 only, so exactly one
    MatMul is left untouched and exactly one MatMulNBits node is produced.
    """
    input_model = ONNXModelHandler(model_path=str(matmul_model_path))
    cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")

    # Only MatMul_1 ends in "_1"; MatMul_2 is not covered by the pattern.
    kquant_config = {"bits": 4, "block_size": 32, "nodes_to_exclude": ["*_1"]}
    kquant_pass = create_pass_from_dict(
        OnnxKQuantQuantization,
        kquant_config,
        disable_search=True,
        accelerator_spec=cpu_spec,
    )

    destination = tmp_path / "quantized_glob_model.onnx"
    result = kquant_pass.run(input_model, destination)

    # The pass must have written a model file to disk.
    assert os.path.exists(result.model_path)

    # Count quantized and original MatMul nodes in the saved graph.
    saved_model = onnx.load(result.model_path)
    quantized_count = sum(
        1 for node in saved_model.graph.node if node.op_type == str(OpType.MatMulNBits)
    )
    plain_count = sum(1 for node in saved_model.graph.node if node.op_type == "MatMul")

    assert quantized_count == 1, "Expected 1 MatMulNBits node (MatMul_2 quantized)"
    assert plain_count == 1, "Expected 1 original MatMul node (MatMul_1 excluded via glob)"

def test_kquant_exclude_entry_with_metachars_is_exact(self, matmul_model_path, tmp_path):
    """Check that bracket metacharacters alone do not turn an entry into a glob.

    The entry 'MatMul_[12]' has no '*' or '?' wildcard, so it is compared as an
    exact node name and must NOT match 'MatMul_1'. Consequently nothing is
    excluded and both MatMuls end up quantized.
    """
    input_model = ONNXModelHandler(model_path=str(matmul_model_path))
    cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")

    kquant_config = {"bits": 4, "block_size": 32, "nodes_to_exclude": ["MatMul_[12]"]}
    kquant_pass = create_pass_from_dict(
        OnnxKQuantQuantization,
        kquant_config,
        disable_search=True,
        accelerator_spec=cpu_spec,
    )

    destination = tmp_path / "quantized_meta_model.onnx"
    result = kquant_pass.run(input_model, destination)

    # The pass must have written a model file to disk.
    assert os.path.exists(result.model_path)

    # Every MatMul should have been converted, since the bracket entry matched no node.
    saved_model = onnx.load(result.model_path)
    quantized_count = sum(
        1 for node in saved_model.graph.node if node.op_type == str(OpType.MatMulNBits)
    )

    assert quantized_count == 2, "Expected both MatMuls quantized (bracket entry not glob-matched)"
Loading