Skip to content

Commit 491cd7f

Browse files
authored
Add context_lengths option to QairtGenAIBuilder (#2505)
## Add context_lengths option to QairtGenAIBuilder (#2505) Introduces a `context_lengths` parameter to `QairtGenAIBuilder` that allows users to specify an explicit list of context lengths (CLs) to compile for HTP backends. This provides fine-grained control over which context-length binaries are produced, as an alternative to the fixed CL set generated by `multi_graph`. The two options are mutually exclusive. **Usage:** ```json "qgab": { "type": "QairtGenAIBuilder", "backend": "HTP", "context_lengths": [512, 1024, 2048, 3072, 4096, 6144, 8192, 10240, 13312, 16384] } ``` Sets `arn_cl_options.context_length` directly on the underlying GenAI builder when provided, bypassing the `multi_graph` path. ## Checklist before requesting a review - [x] I have added unit tests for the new parameter - [x] All tests pass locally - [x] I have updated relevant documentation/descriptions - [x] I have run linting
1 parent a2742ca commit 491cd7f

2 files changed

Lines changed: 121 additions & 3 deletions

File tree

olive/passes/qairt/gen_ai_builder.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
9494
default_value=False,
9595
description="Produces context binaries with additional context length combinations. "
9696
"Improves token generation performance for different context lengths but increases preparation time. "
97-
"HTP only.",
97+
"Mutually exclusive with context_lengths. HTP only.",
98+
),
99+
"context_lengths": PassConfigParam(
100+
type_=list[int],
101+
default_value=None,
102+
description="Explicit list of context lengths (CLs) to compile. "
103+
"Overrides the default CL set produced by multi_graph. "
104+
"Mutually exclusive with multi_graph. HTP only.",
98105
),
99106
}
100107

@@ -134,6 +141,13 @@ def validate_config(
134141
if config.multi_graph:
135142
logger.error("multi_graph is unsupported on non-HTP backends")
136143
return False
144+
if config.context_lengths:
145+
logger.error("context_lengths is unsupported on non-HTP backends")
146+
return False
147+
148+
if config.context_lengths and config.multi_graph:
149+
logger.error("context_lengths and multi_graph are mutually exclusive")
150+
return False
137151

138152
native_kv_supported_sequence_lengths = [[32, 128]]
139153
if config.native_kv and config.sequence_lengths not in native_kv_supported_sequence_lengths:
@@ -237,7 +251,12 @@ def _run_for_config(
237251
config.num_splits
238252
)
239253

240-
gen_ai_builder.multi_graph = config.multi_graph
254+
if config.context_lengths:
255+
gen_ai_builder._transformation_config.model_transformer_config.arn_cl_options.context_length = (
256+
config.context_lengths
257+
)
258+
else:
259+
gen_ai_builder.multi_graph = config.multi_graph
241260

242261
gen_ai_container = gen_ai_builder.build()
243262
gen_ai_container.save(output_model_path, exist_ok=True)

test/passes/qairt/test_gen_ai_builder.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import builtins
88
from pathlib import Path
9-
from unittest.mock import MagicMock, patch
9+
from unittest.mock import MagicMock, PropertyMock, patch
1010

1111
import pytest
1212

@@ -36,6 +36,8 @@ def test_gen_ai_builder_default_config(mock_accelerator_spec):
3636
assert config["num_splits"].default_value == -1
3737
assert "multi_graph" in config
3838
assert config["multi_graph"].default_value is False
39+
assert "context_lengths" in config
40+
assert config["context_lengths"].default_value is None
3941

4042

4143
def test_gen_ai_builder_cpu_backend_success(tmp_path, mock_hf_model, mock_qairt_modules):
@@ -495,3 +497,100 @@ def test_gen_ai_builder_native_kv_validation_valid_sequence_lengths(mock_acceler
495497
).config
496498

497499
assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True
500+
501+
502+
def test_gen_ai_builder_context_lengths_configuration(tmp_path, mock_qairt_prepared_model, mock_qairt_modules):
503+
"""Test that context_lengths directly sets arn_cl_options.context_length."""
504+
output_path = tmp_path / "output"
505+
506+
mock_builder = MagicMock()
507+
mock_container = MagicMock()
508+
mock_builder.build.return_value = mock_container
509+
mock_builder._compilation_config = MagicMock()
510+
mock_builder._compilation_config.graph_custom_configs = [MagicMock()]
511+
mock_builder._compilation_config.device_custom_configs = [MagicMock()]
512+
mock_builder._compilation_config.context_custom_configs = [MagicMock()]
513+
mock_builder._transformation_config = MagicMock()
514+
mock_builder._transformation_config.model_transformer_config = MagicMock()
515+
mock_builder._transformation_config.model_transformer_config.arn_cl_options = MagicMock()
516+
mock_builder._transformation_config.model_transformer_config.split_model = MagicMock()
517+
518+
mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder
519+
520+
custom_cls = [1024, 2048, 4096]
521+
gen_ai_pass = create_pass_from_dict(
522+
QairtGenAIBuilder,
523+
{"backend": "HTP", "context_lengths": custom_cls},
524+
disable_search=True,
525+
)
526+
527+
result = gen_ai_pass.run(mock_qairt_prepared_model, str(output_path))
528+
529+
assert mock_builder._transformation_config.model_transformer_config.arn_cl_options.context_length == custom_cls
530+
assert isinstance(result, QairtModelHandler)
531+
532+
533+
def test_gen_ai_builder_context_lengths_skips_multi_graph_setter(
534+
tmp_path, mock_qairt_prepared_model, mock_qairt_modules
535+
):
536+
"""Test that multi_graph setter is not invoked when context_lengths is set."""
537+
output_path = tmp_path / "output"
538+
539+
mock_builder_cls = type("mock_builder_cls", (MagicMock,), {})
540+
mock_builder = mock_builder_cls()
541+
mock_container = MagicMock()
542+
mock_builder.build.return_value = mock_container
543+
mock_builder._compilation_config = MagicMock()
544+
mock_builder._compilation_config.graph_custom_configs = [MagicMock()]
545+
mock_builder._compilation_config.device_custom_configs = [MagicMock()]
546+
mock_builder._compilation_config.context_custom_configs = [MagicMock()]
547+
mock_builder._transformation_config = MagicMock()
548+
mock_builder._transformation_config.model_transformer_config = MagicMock()
549+
mock_builder._transformation_config.model_transformer_config.arn_cl_options = MagicMock()
550+
mock_builder._transformation_config.model_transformer_config.split_model = MagicMock()
551+
552+
multi_graph_mock = PropertyMock()
553+
type(mock_builder).multi_graph = multi_graph_mock
554+
555+
mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder
556+
557+
gen_ai_pass = create_pass_from_dict(
558+
QairtGenAIBuilder,
559+
{"backend": "HTP", "context_lengths": [512, 2048]},
560+
disable_search=True,
561+
)
562+
563+
gen_ai_pass.run(mock_qairt_prepared_model, str(output_path))
564+
multi_graph_mock.assert_not_called()
565+
566+
567+
def test_gen_ai_builder_validate_config_context_lengths_cpu_rejected(mock_accelerator_spec, mock_qairt_modules):
568+
"""Test that context_lengths is rejected on non-HTP backends."""
569+
config = create_pass_from_dict(
570+
QairtGenAIBuilder,
571+
{"backend": "CPU", "context_lengths": [1024, 2048]},
572+
disable_search=True,
573+
).config
574+
assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is False
575+
576+
577+
def test_gen_ai_builder_validate_config_context_lengths_and_multi_graph_rejected(
578+
mock_accelerator_spec, mock_qairt_modules
579+
):
580+
"""Test that context_lengths and multi_graph are mutually exclusive."""
581+
config = create_pass_from_dict(
582+
QairtGenAIBuilder,
583+
{"backend": "HTP", "context_lengths": [1024, 2048], "multi_graph": True},
584+
disable_search=True,
585+
).config
586+
assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is False
587+
588+
589+
def test_gen_ai_builder_validate_config_context_lengths_htp_valid(mock_accelerator_spec, mock_qairt_modules):
590+
"""Test that context_lengths passes validation on HTP."""
591+
config = create_pass_from_dict(
592+
QairtGenAIBuilder,
593+
{"backend": "HTP", "context_lengths": [1024, 2048, 4096]},
594+
disable_search=True,
595+
).config
596+
assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True

0 commit comments

Comments
 (0)