Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,17 +519,33 @@ def make_quant_init(self, config):
)

def make_tied_embeddings_init(self, config):
# Determine if lm_head is unquantized. int4/8 can have options to int4_nodes_to_exclude. FP models are always unquantized.
self.unquantized_lm_head = "/lm_head/MatMul" in self.quant_attrs["nodes_to_exclude"] or self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}
self.shared_embeddings = self.extra_options.get("shared_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False)
shared_embeddings = self.extra_options.get("shared_embeddings", config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False)

# Determine dtype for quantized lm_head
self.int4_lm_head = self.extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}
self.int8_lm_head = self.extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last", "k_quant_linear", "rtn_last"}

# shared_embeddings conflicts with exclude_embeds and exclude_lm_head
if self.exclude_embeds or self.exclude_lm_head:
self.shared_embeddings = False
elif self.shared_embeddings and not self.unquantized_lm_head:
# matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
self.shared_embeddings = self.int8_lm_head or self.extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}
# Determine if embeddings and lm_head will be quantized or not
self.quantized_embeds = (
self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}
and "Gather" in self.quant_attrs["op_types_to_quantize"]
and "/model/embed_tokens/Gather" not in self.quant_attrs["nodes_to_exclude"]
and not self.exclude_embeds
)
self.quantized_lm_head = (
self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}
and "MatMul" in self.quant_attrs["op_types_to_quantize"]
and "/lm_head/MatMul" not in self.quant_attrs["nodes_to_exclude"]
and not self.exclude_lm_head
and not self.prune_lm_head
)

if shared_embeddings:
self.tied_quantized_embeddings = self.quantized_embeds and self.quantized_lm_head
self.tied_unquantized_embeddings = not self.tied_quantized_embeddings
else:
self.tied_unquantized_embeddings = False
self.tied_quantized_embeddings = False
Comment on lines +525 to +529

def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
# Create config with attributes from config.json and generation_config.json (if latter file exists)
Expand Down Expand Up @@ -752,9 +768,8 @@ def to_int4(self) -> ir.Model:
def save_model(self, out_dir):
print(f"Saving ONNX model in {out_dir}")

already_quantized_in_qdq_format = (
self.quant_type is not None and self.quant_attrs["use_qdq"]
) # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
# Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
already_quantized_in_qdq_format = self.quant_type is not None and self.quant_attrs["use_qdq"]
if self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4} and not already_quantized_in_qdq_format:
model = self.to_int4()
else:
Expand Down Expand Up @@ -1413,7 +1428,7 @@ def make_embedding(self, embedding):
basename = "/model/embed_tokens"

# Use GatherBlockQuantized if and only if tied embeddings are enabled and export model is quantized. quantized d_type in set_onnx_dtype is INT4/UINT4
if self.shared_embeddings and self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
if self.tied_quantized_embeddings:
gather_name = f"{basename}/GatherBlockQuantized"
gather_output = f"{gather_name}/output_0"

Expand All @@ -1425,6 +1440,7 @@ def make_embedding(self, embedding):
f"/model/constants/INT64/[{self.vocab_size}, {flat_dim}]",
]
weight_reshape_output = f"{weight_reshape_name}/output_0"

# Quantized weight dtype is uint8. See here for more info:
# https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=[self.vocab_size, flat_dim])
Expand All @@ -1445,7 +1461,7 @@ def make_embedding(self, embedding):
)

# Use Transpose + Gather for tied embeddings for float embedding layers
elif self.shared_embeddings and self.unquantized_lm_head:
elif self.tied_unquantized_embeddings:
transpose_name = f"{basename}/Transpose"
transpose_output = f"{transpose_name}/output_0"
self.make_transpose(
Expand Down
239 changes: 239 additions & 0 deletions test/python/builder/test_tied_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from __future__ import annotations

import importlib.util
import sys
import types
from pathlib import Path

import onnx_ir as ir
import pytest

# Location of the builder sources, resolved relative to this test file:
# <three directories above this file's directory>/src/python/py/models/builders.
BUILDERS_DIR = Path(__file__).parents[3] / "src" / "python" / "py" / "models" / "builders"
# Prepend the directory two levels above the builders (i.e. .../src/python/py) to
# sys.path. Presumably this lets `models...` imports made by the builder sources
# resolve against the source tree -- TODO confirm against base.py's imports.
sys.path.insert(0, str(BUILDERS_DIR.parents[1]))


def _load_builder_module(module_name):
    """Load ``<module_name>.py`` from ``BUILDERS_DIR`` as ``models.builders.<module_name>``.

    Args:
        module_name: File stem of the builder module to load (e.g. ``"base"``).

    Returns:
        The executed module object, which is also registered in ``sys.modules``
        under its qualified name.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
        Exception: Whatever the module itself raises while executing; in that
            case the partially initialised module is removed from
            ``sys.modules`` again before the exception propagates.
    """
    qualified_name = f"models.builders.{module_name}"
    module_path = BUILDERS_DIR / f"{module_name}.py"

    spec = importlib.util.spec_from_file_location(qualified_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {qualified_name} from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing (as in the importlib "import a source file
    # directly" recipe) so lookups of the qualified name during execution
    # find this module object.
    sys.modules[qualified_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later imports.
        sys.modules.pop(qualified_name, None)
        raise
    return module


# Register stand-in `models` and `models.builders` module objects unless entries
# with those names already exist in sys.modules (setdefault keeps existing ones).
# NOTE(review): the test runner also collects a `models` test directory alongside
# `builder`; confirm the two uses of the top-level name `models` cannot clash
# when both are imported in one pytest session.
sys.modules.setdefault("models", types.ModuleType("models"))
builders_package = sys.modules.setdefault("models.builders", types.ModuleType("models.builders"))
# Setting __path__ marks the module as a package whose submodules are searched
# for in BUILDERS_DIR.
builders_package.__path__ = [str(BUILDERS_DIR)]

# Load base.py under the name `models.builders.base` and expose its Model class
# for the tests below. Must run after the package stubs above are in place.
base_module = _load_builder_module("base")
Model = base_module.Model


def _make_model_for_tied_embeddings(
    *,
    shared_embeddings=None,
    tie_word_embeddings=None,
    onnx_dtype=ir.DataType.FLOAT16,
    op_types=("MatMul", "Gather"),
    nodes_to_exclude=(),
    exclude_embeds=False,
    exclude_lm_head=False,
    prune_lm_head=False,
    int4_algo_config="default",
):
    """Build a bare ``Model`` and run ``make_tied_embeddings_init`` on it.

    The instance is created with ``Model.__new__`` so ``Model.__init__`` never
    runs; only the attributes assigned here are populated before the init
    method under test is called.

    Args:
        shared_embeddings: Value for ``extra_options["shared_embeddings"]``;
            the key is omitted entirely when this is ``None``.
        tie_word_embeddings: Value of the attribute on the stand-in config.
        onnx_dtype: Stored as ``model.onnx_dtype``.
        op_types: Stored as ``quant_attrs["op_types_to_quantize"]``.
        nodes_to_exclude: Copied into a list as ``quant_attrs["nodes_to_exclude"]``.
        exclude_embeds: Stored as ``model.exclude_embeds``.
        exclude_lm_head: Stored as ``model.exclude_lm_head``.
        prune_lm_head: Stored as ``model.prune_lm_head``.
        int4_algo_config: Stored as ``extra_options["int4_algo_config"]``.

    Returns:
        The model after ``make_tied_embeddings_init`` has been called.
    """
    stub = Model.__new__(Model)

    options = {"int4_algo_config": int4_algo_config}
    if shared_embeddings is not None:
        # Only an explicit value is forwarded, so the config fallback stays reachable.
        options["shared_embeddings"] = shared_embeddings
    stub.extra_options = options

    stub.onnx_dtype = onnx_dtype
    stub.quant_attrs = {
        "op_types_to_quantize": op_types,
        "nodes_to_exclude": list(nodes_to_exclude),
    }
    stub.exclude_embeds = exclude_embeds
    stub.exclude_lm_head = exclude_lm_head
    stub.prune_lm_head = prune_lm_head

    stub.make_tied_embeddings_init(types.SimpleNamespace(tie_word_embeddings=tie_word_embeddings))
    return stub


def test_shared_embeddings_option_overrides_config_tie_word_embeddings():
    """An explicit ``shared_embeddings=False`` wins over ``tie_word_embeddings=True``."""
    builder_kwargs = {
        "shared_embeddings": False,
        "tie_word_embeddings": True,
        "onnx_dtype": ir.DataType.INT4,
    }
    built = _make_model_for_tied_embeddings(**builder_kwargs)

    # Both layers remain eligible for quantization ...
    assert built.quantized_embeds is True
    assert built.quantized_lm_head is True
    # ... but neither tying mode is enabled because sharing was switched off.
    assert built.tied_quantized_embeddings is False
    assert built.tied_unquantized_embeddings is False


def test_tie_word_embeddings_defaults_to_false_when_unset_or_none():
    """Tying stays off when the config omits ``tie_word_embeddings`` or sets it to ``None``."""
    # Case 1: the attribute exists on the config but is None.
    model_none = _make_model_for_tied_embeddings(tie_word_embeddings=None)

    # Case 2: the attribute is missing altogether. The helper always puts the
    # attribute on its stand-in config, so re-run the init with an empty
    # namespace to exercise the "attribute not present" path as well.
    model_unset = _make_model_for_tied_embeddings()
    model_unset.make_tied_embeddings_init(types.SimpleNamespace())

    assert model_unset.tied_quantized_embeddings is False
    assert model_unset.tied_unquantized_embeddings is False
    assert model_none.tied_quantized_embeddings is False
    assert model_none.tied_unquantized_embeddings is False


@pytest.mark.parametrize(
    "onnx_dtype, op_types, nodes_to_exclude, exclude_embeds, exclude_lm_head, prune_lm_head, expected_embeds, expected_lm_head",
    [
        # 4-bit dtypes with both op types enabled: both layers quantized.
        (ir.DataType.INT4, ("MatMul", "Gather"), (), False, False, False, True, True),
        (ir.DataType.UINT4, ("MatMul", "Gather"), (), False, False, False, True, True),
        # Float dtype: nothing is quantized.
        (ir.DataType.FLOAT16, ("MatMul", "Gather"), (), False, False, False, False, False),
        # Only one op type listed for quantization.
        (ir.DataType.INT4, ("MatMul",), (), False, False, False, False, True),
        (ir.DataType.INT4, ("Gather",), (), False, False, False, True, False),
        # One node explicitly excluded by name.
        (ir.DataType.INT4, ("MatMul", "Gather"), ("/model/embed_tokens/Gather",), False, False, False, False, True),
        (ir.DataType.INT4, ("MatMul", "Gather"), ("/lm_head/MatMul",), False, False, False, True, False),
        # Layer dropped via exclude_embeds / exclude_lm_head / prune_lm_head.
        (ir.DataType.INT4, ("MatMul", "Gather"), (), True, False, False, False, True),
        (ir.DataType.INT4, ("MatMul", "Gather"), (), False, True, False, True, False),
        (ir.DataType.INT4, ("MatMul", "Gather"), (), False, False, True, True, False),
    ],
)
def test_quantization_eligibility_for_embeddings_and_lm_head(
    onnx_dtype,
    op_types,
    nodes_to_exclude,
    exclude_embeds,
    exclude_lm_head,
    prune_lm_head,
    expected_embeds,
    expected_lm_head,
):
    """Check ``quantized_embeds`` / ``quantized_lm_head`` for each configuration row."""
    scenario = {
        "onnx_dtype": onnx_dtype,
        "op_types": op_types,
        "nodes_to_exclude": nodes_to_exclude,
        "exclude_embeds": exclude_embeds,
        "exclude_lm_head": exclude_lm_head,
        "prune_lm_head": prune_lm_head,
    }
    built = _make_model_for_tied_embeddings(
        shared_embeddings=True,
        tie_word_embeddings=False,
        **scenario,
    )

    assert built.quantized_embeds is expected_embeds
    assert built.quantized_lm_head is expected_lm_head


@pytest.mark.parametrize(
    "quantized_embeds, quantized_lm_head, expected_tied_quantized, expected_tied_unquantized",
    [
        (True, True, True, False),
        (True, False, False, True),
        (False, True, False, True),
        (False, False, False, True),
    ],
)
def test_shared_embeddings_prefers_quantized_path_only_when_both_layers_are_quantized(
    quantized_embeds,
    quantized_lm_head,
    expected_tied_quantized,
    expected_tied_unquantized,
):
    """With sharing on, the quantized tie is chosen only if both layers are quantized."""
    # Steer each layer's quantization by listing (or omitting) its op type.
    enabled_ops = []
    if quantized_embeds:
        enabled_ops.append("Gather")
    if quantized_lm_head:
        enabled_ops.append("MatMul")

    built = _make_model_for_tied_embeddings(
        shared_embeddings=True,
        tie_word_embeddings=False,
        onnx_dtype=ir.DataType.INT4,
        op_types=tuple(enabled_ops),
    )

    assert built.tied_quantized_embeddings is expected_tied_quantized
    assert built.tied_unquantized_embeddings is expected_tied_unquantized


@pytest.mark.parametrize(
    "int4_algo_config, expected_int4_lm_head, expected_int8_lm_head",
    [
        # Neither flag is set for the default algorithm.
        ("default", False, False),
        # Algorithms expected to set int4_lm_head.
        ("rtn", True, False),
        ("k_quant", True, False),
        # Algorithms expected to set int8_lm_head.
        ("k_quant_mixed", False, True),
        ("k_quant_last", False, True),
        ("k_quant_linear", False, True),
        ("rtn_last", False, True),
    ],
)
def test_lm_head_quantized_dtype_flags_derive_from_int4_algo_config(
    int4_algo_config,
    expected_int4_lm_head,
    expected_int8_lm_head,
):
    """``int4_lm_head`` / ``int8_lm_head`` follow the ``int4_algo_config`` option."""
    builder_kwargs = {
        "shared_embeddings": True,
        "tie_word_embeddings": True,
        "onnx_dtype": ir.DataType.INT4,
        "int4_algo_config": int4_algo_config,
    }
    built = _make_model_for_tied_embeddings(**builder_kwargs)

    assert built.int4_lm_head is expected_int4_lm_head
    assert built.int8_lm_head is expected_int8_lm_head


def _make_minimal_model_for_int4_matmul():
    """Return a bare ``Model`` whose graph-building hooks are replaced by recorders.

    ``Model.__init__`` is skipped. The stubbed ``make_matmul_float``,
    ``make_initializer``, ``make_node`` and ``make_value`` only record their
    calls on the instance (``_float_called``, ``_initializers``, ``_nodes``,
    ``_values``) so tests can inspect what was requested.
    """
    stub = Model.__new__(Model)
    stub.io_dtype = ir.DataType.FLOAT16
    stub.quant_attrs = {"accuracy_level": 0}

    # Recording state inspected by the tests.
    stub._float_called = False
    stub._initializers = []
    stub._nodes = []
    stub._values = []

    def record_matmul_float(_matmul, _basename, _root_input, **_kwargs):
        # Flag the fallback and hand back a sentinel the tests can compare against.
        stub._float_called = True
        return "float_fallback"

    def record_initializer(tensor, name, to=None):
        stub._initializers.append((name, to, tensor))

    def record_node(op_type, **kwargs):
        stub._nodes.append((op_type, kwargs))

    def record_value(name, dtype, shape):
        stub._values.append((name, dtype, shape))

    # Plain functions stored on the instance are not bound, so they are
    # called without ``self``.
    stub.make_matmul_float = record_matmul_float
    stub.make_initializer = record_initializer
    stub.make_node = record_node
    stub.make_value = record_value
    return stub


def test_int4_matmul_uses_float_fallback_when_model_not_already_quantized():
    """A matmul carrying only a plain ``weight`` is routed to ``make_matmul_float``."""
    stub = _make_minimal_model_for_int4_matmul()
    float_matmul = types.SimpleNamespace(weight=object())

    outcome = stub.make_matmul_int4(float_matmul, "/lm_head/MatMul", "hidden_states")

    # The stubbed float path was taken and no graph node was emitted.
    assert outcome == "float_fallback"
    assert stub._float_called is True
    assert stub._nodes == []


def test_int4_matmul_emits_matmul_nbits_when_model_already_quantized():
    """A matmul exposing pre-quantized tensors yields a ``MatMulNBits`` node."""
    stub = _make_minimal_model_for_int4_matmul()
    quantized_fields = {
        "qweight": object(),
        "scales": object(),
        "qzeros": object(),
        "g_idx": object(),
        "bits": 4,
        "group_size": 32,
        "in_features": 64,
        "out_features": 128,
    }
    quantized_matmul = types.SimpleNamespace(**quantized_fields)

    outcome = stub.make_matmul_int4(quantized_matmul, "/lm_head/MatMul", "hidden_states")

    emitted_ops = [op_type for op_type, _ in stub._nodes]
    initializer_names = [name for name, _, _ in stub._initializers]

    assert outcome == "/lm_head/MatMulNBits"
    assert stub._float_called is False
    assert "MatMulNBits" in emitted_ops
    assert "lm_head.MatMulNBits.qweight" in initializer_names
    assert "lm_head.MatMulNBits.scales" in initializer_names
assert any(name == "lm_head.MatMulNBits.scales" for name, _, _ in model._initializers)
10 changes: 3 additions & 7 deletions test/python/test_onnxruntime_genai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

import onnxruntime_genai as og
from _test_utils import download_models, is_webgpu_ep_available, run_subprocess
from models.test_gemma4_models import run_gemma4_vision_tests
from models.test_qwen_fara_models import run_qwen_fara_vision_tests

logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG)
log = logging.getLogger("onnxruntime-genai-tests")
Expand All @@ -21,14 +19,16 @@ def run_onnxruntime_genai_api_tests(
log: logging.Logger,
test_models: str | bytes | os.PathLike,
):
log.debug("Running: ONNX Runtime GenAI API Tests")
log.debug("Running: ONNX Runtime GenAI API, builder, and model tests")

command = [
sys.executable,
"-m",
"pytest",
"-sv",
"test_onnxruntime_genai_api.py",
"builder",
"models",
"--test_models",
test_models,
]
Expand Down Expand Up @@ -108,10 +108,6 @@ def main():
# Run ONNX Runtime GenAI tests
run_onnxruntime_genai_api_tests(os.path.abspath(args.cwd), log, os.path.abspath(args.test_models))

# Run vision model tests (tests auto-skip if models are not present)
run_gemma4_vision_tests(os.path.abspath(args.cwd), log, os.path.abspath(args.test_models))
run_qwen_fara_vision_tests(os.path.abspath(args.cwd), log, os.path.abspath(args.test_models))

if args.e2e:
run_onnxruntime_genai_e2e_tests(os.path.abspath(args.cwd), log, output_paths)

Expand Down
Loading