Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions olive/cli/capture_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ def register_subcommand(parser: ArgumentParser):
"for the CUDA graph to be used correctly."
),
)
mb_group.add_argument(
"--extra_options",
Comment thread
xiaoyu-work marked this conversation as resolved.
Outdated
type=str,
required=False,
help="Extra key-value pairs options to pass to the model builder. e.g., 'int4_is_symmetric=true,int4_op_types_to_quantize=MatMul/Gemm'.",
)

sub_parser.add_argument(
"--use_ort_genai", action="store_true", help="Use OnnxRuntime generate() API to run the model"
Expand Down Expand Up @@ -194,6 +200,10 @@ def _get_run_config(self, tempdir: str) -> dict:
(("passes", "m", "enable_cuda_graph"), self.args.enable_cuda_graph),
]
)
if self.args.extra_options:
to_replace.append(
(("passes", "m", "extra_options"), self._parse_extra_options(self.args.extra_options.split(",")))
)
if self.args.int4_block_size is not None:
to_replace.append((("passes", "m", "int4_block_size"), self.args.int4_block_size))
if self.args.int4_accuracy_level is not None:
Expand Down Expand Up @@ -235,6 +245,18 @@ def _get_run_config(self, tempdir: str) -> dict:

return config

@staticmethod
def _parse_extra_options(kv_items):
kv_pairs = {}

if kv_items:
for kv_str in kv_items:
kv = kv_str.split("=")
kv_pairs[kv[0].strip()] = kv[1].strip()

print(f"Extra options: {kv_pairs}")
return kv_pairs


TEMPLATE = {
"systems": {
Expand Down
11 changes: 10 additions & 1 deletion olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"for the CUDA graph to be used correctly."
),
),
"extra_options": PassConfigParam(
type_=dict[str, Any],
required=False,
description="Extra key-value pairs options to pass to the model builder.",
),
}

@classmethod
Expand Down Expand Up @@ -191,7 +196,7 @@ def _run_for_config(
output_model_path: str,
) -> ONNXModelHandler:
try:
from onnxruntime_genai.models.builder import create_model
from onnxruntime_genai.models.builder import check_extra_options, create_model
except ImportError:
raise ImportError(
"onnxruntime-genai package is required to run ModelBuilder pass. Please install the package"
Expand Down Expand Up @@ -229,6 +234,10 @@ def _run_for_config(
if model.adapter_path:
extra_args["adapter_path"] = model.adapter_path

# Add extra options support for model builder
if config.extra_options:
Comment thread
jambayk marked this conversation as resolved.
Outdated
extra_args.update(check_extra_options(config.extra_options))

Comment thread
jambayk marked this conversation as resolved.
Outdated
extra_args.update(
{
key: value.value if isinstance(value, IntEnumBase) else value
Expand Down
6 changes: 5 additions & 1 deletion test/passes/onnx/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
def test_model_builder(tmp_path, metadata_only):
input_model = make_local_tiny_llama(tmp_path / "input_model", "onnx" if metadata_only else "hf")

p = create_pass_from_dict(ModelBuilder, {"precision": "fp32", "metadata_only": metadata_only}, disable_search=True)
p = create_pass_from_dict(
ModelBuilder,
{"precision": "fp32", "metadata_only": metadata_only, "extra_options": {"int4_is_symmetric": "true"}},
Comment thread
jambayk marked this conversation as resolved.
Outdated
Comment thread
jambayk marked this conversation as resolved.
Outdated
disable_search=True,
)
output_folder = tmp_path / "output_model"

# execute the pass
Expand Down
Loading