|
6 | 6 |
|
7 | 7 | import builtins |
8 | 8 | from pathlib import Path |
9 | | -from unittest.mock import MagicMock, patch |
| 9 | +from unittest.mock import MagicMock, PropertyMock, patch |
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 |
|
@@ -36,6 +36,8 @@ def test_gen_ai_builder_default_config(mock_accelerator_spec): |
36 | 36 | assert config["num_splits"].default_value == -1 |
37 | 37 | assert "multi_graph" in config |
38 | 38 | assert config["multi_graph"].default_value is False |
| 39 | + assert "context_lengths" in config |
| 40 | + assert config["context_lengths"].default_value is None |
39 | 41 |
|
40 | 42 |
|
41 | 43 | def test_gen_ai_builder_cpu_backend_success(tmp_path, mock_hf_model, mock_qairt_modules): |
@@ -495,3 +497,100 @@ def test_gen_ai_builder_native_kv_validation_valid_sequence_lengths(mock_acceler |
495 | 497 | ).config |
496 | 498 |
|
497 | 499 | assert QairtGenAIBuilder.validate_config(config, mock_accelerator_spec) is True |
| 500 | + |
| 501 | + |
def test_gen_ai_builder_context_lengths_configuration(tmp_path, mock_qairt_prepared_model, mock_qairt_modules):
    """Check that the configured context_lengths end up on arn_cl_options.context_length."""
    # Compilation config exposing one entry in each custom-config list.
    compilation_config = MagicMock()
    compilation_config.graph_custom_configs = [MagicMock()]
    compilation_config.device_custom_configs = [MagicMock()]
    compilation_config.context_custom_configs = [MagicMock()]

    # Transformer config whose arn_cl_options is inspected after the run.
    transformer_config = MagicMock()
    transformer_config.arn_cl_options = MagicMock()
    transformer_config.split_model = MagicMock()

    mock_builder = MagicMock()
    mock_builder.build.return_value = MagicMock()
    mock_builder._compilation_config = compilation_config
    mock_builder._transformation_config = MagicMock()
    mock_builder._transformation_config.model_transformer_config = transformer_config

    # The factory hands the prepared mock builder to the pass.
    mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder

    custom_cls = [1024, 2048, 4096]
    pass_config = {"backend": "HTP", "context_lengths": custom_cls}
    gen_ai_pass = create_pass_from_dict(QairtGenAIBuilder, pass_config, disable_search=True)

    output_dir = str(tmp_path / "output")
    result = gen_ai_pass.run(mock_qairt_prepared_model, output_dir)

    assert transformer_config.arn_cl_options.context_length == custom_cls
    assert isinstance(result, QairtModelHandler)
| 531 | + |
| 532 | + |
def test_gen_ai_builder_context_lengths_skips_multi_graph_setter(
    tmp_path, mock_qairt_prepared_model, mock_qairt_modules
):
    """Check that the builder's multi_graph property is never touched when context_lengths is set."""
    # Compilation config exposing one entry in each custom-config list.
    compilation_config = MagicMock()
    compilation_config.graph_custom_configs = [MagicMock()]
    compilation_config.device_custom_configs = [MagicMock()]
    compilation_config.context_custom_configs = [MagicMock()]

    transformer_config = MagicMock()
    transformer_config.arn_cl_options = MagicMock()
    transformer_config.split_model = MagicMock()

    # Dedicated MagicMock subclass so the property is installed on a type
    # used only by this builder mock.
    builder_type = type("mock_builder_cls", (MagicMock,), {})
    mock_builder = builder_type()
    mock_builder.build.return_value = MagicMock()
    mock_builder._compilation_config = compilation_config
    mock_builder._transformation_config = MagicMock()
    mock_builder._transformation_config.model_transformer_config = transformer_config

    # A PropertyMock records every get or set of `multi_graph` as a call.
    multi_graph_mock = PropertyMock()
    type(mock_builder).multi_graph = multi_graph_mock

    mock_qairt_modules["gen_ai_api"].GenAIBuilderFactory.create.return_value = mock_builder

    pass_config = {"backend": "HTP", "context_lengths": [512, 2048]}
    gen_ai_pass = create_pass_from_dict(QairtGenAIBuilder, pass_config, disable_search=True)

    output_dir = str(tmp_path / "output")
    gen_ai_pass.run(mock_qairt_prepared_model, output_dir)

    multi_graph_mock.assert_not_called()
| 565 | + |
| 566 | + |
def test_gen_ai_builder_validate_config_context_lengths_cpu_rejected(mock_accelerator_spec, mock_qairt_modules):
    """Check that validate_config returns False for context_lengths with the CPU backend."""
    pass_config = {"backend": "CPU", "context_lengths": [1024, 2048]}
    gen_ai_pass = create_pass_from_dict(QairtGenAIBuilder, pass_config, disable_search=True)

    is_valid = QairtGenAIBuilder.validate_config(gen_ai_pass.config, mock_accelerator_spec)

    assert is_valid is False
| 575 | + |
| 576 | + |
def test_gen_ai_builder_validate_config_context_lengths_and_multi_graph_rejected(
    mock_accelerator_spec, mock_qairt_modules
):
    """Check that validate_config returns False when context_lengths and multi_graph are both set."""
    pass_config = {"backend": "HTP", "context_lengths": [1024, 2048], "multi_graph": True}
    gen_ai_pass = create_pass_from_dict(QairtGenAIBuilder, pass_config, disable_search=True)

    is_valid = QairtGenAIBuilder.validate_config(gen_ai_pass.config, mock_accelerator_spec)

    assert is_valid is False
| 587 | + |
| 588 | + |
def test_gen_ai_builder_validate_config_context_lengths_htp_valid(mock_accelerator_spec, mock_qairt_modules):
    """Check that validate_config returns True for context_lengths with the HTP backend."""
    pass_config = {"backend": "HTP", "context_lengths": [1024, 2048, 4096]}
    gen_ai_pass = create_pass_from_dict(QairtGenAIBuilder, pass_config, disable_search=True)

    is_valid = QairtGenAIBuilder.validate_config(gen_ai_pass.config, mock_accelerator_spec)

    assert is_valid is True
0 commit comments