2 changes: 1 addition & 1 deletion ax/adapter/tests/test_torch_adapter.py

@@ -1205,7 +1205,7 @@ def test_pairwise_preference_generator(self) -> None:
                 surrogate=surrogate,
             ),
             optimization_config=OptimizationConfig(
-                Objective(
+                objective=Objective(
                     metric=Metric(Keys.PAIRWISE_PREFERENCE_QUERY.value),
                     minimize=False,
                 )
284 changes: 210 additions & 74 deletions ax/core/optimization_config.py

Large diffs are not rendered by default.
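The core diff is collapsed above, but the tests added below pin down its public contract. As a rough sketch only (not the actual implementation; the class name, private attribute, and error messages here are inferred from the tests, and the real signature carries more parameters), the new construction logic is consistent with:

from ax.exceptions.core import UnsupportedError, UserInputError

class OptimizationConfigSketch:
    """Sketch of the dual objective=/objectives= construction path."""

    def __init__(self, objective=None, objectives=None, outcome_constraints=None):
        # The two keywords are mutually exclusive; exactly one must be given.
        if objective is not None and objectives is not None:
            raise UserInputError("Cannot specify both `objective` and `objectives`.")
        if objective is not None:
            objectives = [objective]  # legacy single-objective path
        if objectives is None:
            raise UserInputError("Must specify either `objective` or `objectives`.")
        if not objectives:
            raise UserInputError("`objectives` must not be empty.")
        self._objectives = list(objectives)
        self.outcome_constraints = list(outcome_constraints or [])

    @property
    def objectives(self):
        return self._objectives

    @property
    def objective(self):
        # Single-objective accessor; raises on MOO configs, mirroring the
        # UnsupportedError asserted in the tests below.
        if len(self._objectives) > 1:
            raise UnsupportedError(
                "Config has multiple objectives; use `objectives` instead."
            )
        return self._objectives[0]

    @property
    def is_moo_problem(self) -> bool:
        return len(self._objectives) > 1

The real class additionally validates that each entry is a single or scalarized objective, rejects duplicate metric names across objectives, and updates cloning and repr, as exercised by the tests below.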

2 changes: 1 addition & 1 deletion ax/core/tests/test_multi_type_experiment.py

@@ -154,7 +154,7 @@ def test_setting_opt_config(self) -> None:
         m3 = BraninMetric("m3", ["x1", "x2"])
         self.experiment.add_tracking_metric(m3)
         self.experiment.optimization_config = OptimizationConfig(
-            Objective(metric=m3, minimize=True)
+            objective=Objective(metric=m3, minimize=True)
         )
         self.assertDictEqual(
             self.experiment._metric_to_trial_type,
109 changes: 107 additions & 2 deletions ax/core/tests/test_optimization_config.py

@@ -21,14 +21,14 @@
     ScalarizedOutcomeConstraint,
 )
 from ax.core.types import ComparisonOp
-from ax.exceptions.core import UserInputError
+from ax.exceptions.core import UnsupportedError, UserInputError
 from ax.utils.common.testutils import TestCase
 from pyre_extensions import assert_is_instance


 OC_STR = (
     "OptimizationConfig("
-    'objective=Objective(expression="m1"), '
+    'objectives=[Objective(expression="m1")], '
     "outcome_constraints=[OutcomeConstraint(m3 >= -0.25), "
     "OutcomeConstraint(m4 <= 0.25), "
     "ScalarizedOutcomeConstraint(0.5*m3 + 0.5*m4 >= 0.9975 * baseline)])"
@@ -271,6 +271,111 @@ def test_CloneWithArgs(self) -> None:
         )


+class OptimizationConfigObjectivesListTest(TestCase):
+    """Tests for the new `OptimizationConfig(objectives=[...])` construction path."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.metrics = {
+            "m1": Metric(name="m1"),
+            "m2": Metric(name="m2"),
+            "m3": Metric(name="m3"),
+        }
+        self.sig = {m: m for m in self.metrics}
+        self.obj1 = Objective(expression="m1", metric_name_to_signature=self.sig)
+        self.obj2 = Objective(expression="-m2", metric_name_to_signature=self.sig)
+        self.scalarized_obj = Objective(
+            expression="2*m1 + m2", metric_name_to_signature=self.sig
+        )
+
+    def test_objectives_kwarg_construction(self) -> None:
+        """Test single and multi-objective construction via objectives kwarg."""
+        # Single objective
+        config = OptimizationConfig(objectives=[self.obj1])
+        self.assertEqual(config.objectives, [self.obj1])
+        self.assertEqual(config.objective, self.obj1)
+        self.assertFalse(config.is_moo_problem)
+
+        # Multi-objective
+        config = OptimizationConfig(objectives=[self.obj1, self.obj2])
+        self.assertEqual(config.objectives, [self.obj1, self.obj2])
+        self.assertTrue(config.is_moo_problem)
+        with self.assertRaisesRegex(UnsupportedError, "multiple objectives"):
+            config.objective
+
+    def test_objectives_kwarg_metric_aggregation(self) -> None:
+        """Test metric_names, metric_name_to_signature, metric_signatures."""
+        constraint = OutcomeConstraint(
+            expression="m3 >= 0.5", metric_name_to_signature=self.sig
+        )
+        config = OptimizationConfig(
+            objectives=[self.obj1, self.obj2],
+            outcome_constraints=[constraint],
+        )
+        self.assertEqual(config.metric_names, {"m1", "m2", "m3"})
+        self.assertEqual(
+            config.metric_name_to_signature, {"m1": "m1", "m2": "m2", "m3": "m3"}
+        )
+        self.assertEqual(config.metric_signatures, {"m1", "m2", "m3"})
+
+    def test_objectives_kwarg_validation(self) -> None:
+        """Test validation errors for objectives kwarg."""
+        with self.subTest("mutual_exclusivity"):
+            with self.assertRaisesRegex(UserInputError, "Cannot specify both"):
+                OptimizationConfig(objective=self.obj1, objectives=[self.obj1])
+
+        with self.subTest("neither_specified"):
+            with self.assertRaisesRegex(UserInputError, "Must specify either"):
+                OptimizationConfig()
+
+        with self.subTest("empty_list"):
+            with self.assertRaisesRegex(UserInputError, "must not be empty"):
+                OptimizationConfig(objectives=[])
+
+        with self.subTest("multi_objective_expression"):
+            multi_obj = Objective(
+                expression="m1, -m2", metric_name_to_signature=self.sig
+            )
+            with self.assertRaisesRegex(ValueError, "single or scalarized"):
+                OptimizationConfig(objectives=[multi_obj])
+
+        with self.subTest("duplicate_metric_names"):
+            obj_dup = Objective(expression="m1", metric_name_to_signature=self.sig)
+            with self.assertRaisesRegex(UserInputError, "appears in multiple"):
+                OptimizationConfig(objectives=[self.obj1, obj_dup])
+
+    def test_objectives_kwarg_clone_and_repr(self) -> None:
+        """Test clone, clone_with_args, and repr for objectives-list configs."""
+        config = OptimizationConfig(objectives=[self.obj1, self.obj2])
+
+        # clone preserves objectives
+        cloned = config.clone()
+        self.assertEqual(len(cloned.objectives), 2)
+        self.assertEqual(cloned.objectives[0].expression, "m1")
+        self.assertEqual(cloned.objectives[1].expression, "-m2")
+        self.assertTrue(cloned.is_moo_problem)
+
+        # clone_with_args(objective=) replaces the list with a single objective
+        cloned = config.clone_with_args(objective=self.obj1)
+        self.assertEqual(len(cloned.objectives), 1)
+        self.assertFalse(cloned.is_moo_problem)
+
+        # clone_with_args(objectives=) replaces the list
+        obj3 = Objective(expression="m3", metric_name_to_signature=self.sig)
+        cloned = config.clone_with_args(objectives=[self.obj1, obj3])
+        self.assertEqual(len(cloned.objectives), 2)
+        self.assertEqual(cloned.objectives[1].expression, "m3")
+
+        # objective= and objectives= are mutually exclusive in clone_with_args
+        with self.assertRaisesRegex(UserInputError, "Cannot specify both"):
+            config.clone_with_args(objective=self.obj1, objectives=[self.obj1])
+
+        # repr always uses "objectives="
+        self.assertIn("objectives=", repr(config))
+        single_config = OptimizationConfig(objectives=[self.obj1])
+        self.assertIn("objectives=", repr(single_config))
+
+
 class MultiObjectiveOptimizationConfigTest(TestCase):
     def setUp(self) -> None:
         super().setUp()
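For reference, the construction paths exercised above as they would appear in user code. A short sketch; the `Objective(expression=..., metric_name_to_signature=...)` signature is taken from the tests in this diff, the import paths are assumed from the repository layout, and the metric names are placeholders:

from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig

sig = {"m1": "m1", "m2": "m2"}
obj1 = Objective(expression="m1", metric_name_to_signature=sig)   # maximize m1
obj2 = Objective(expression="-m2", metric_name_to_signature=sig)  # minimize m2

# The legacy single-objective keyword still works:
single = OptimizationConfig(objective=obj1)
assert not single.is_moo_problem

# The new list form; two entries make the config multi-objective:
multi = OptimizationConfig(objectives=[obj1, obj2])
assert multi.is_moo_problem
assert multi.metric_names == {"m1", "m2"}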
4 changes: 2 additions & 2 deletions ax/orchestration/tests/test_orchestrator.py

@@ -2727,7 +2727,7 @@ def test_generate_candidates_does_not_generate_if_missing_data(self) -> None:
         )
         self.branin_experiment.add_tracking_metric(custom_metric)
         self.branin_experiment.optimization_config = OptimizationConfig(
-            Objective(
+            objective=Objective(
                 metric=CustomTestMetric(
                     name="custom_test_metric", test_attribute="test"
                 ),
@@ -2974,7 +2974,7 @@ def setUp(self) -> None:
         self.branin_experiment_no_impl_runner_or_metrics = MultiTypeExperiment(
             search_space=get_branin_search_space(),
             optimization_config=OptimizationConfig(
-                Objective(metric=Metric(name="branin"), minimize=True)
+                objective=Objective(metric=Metric(name="branin"), minimize=True)
             ),
             default_trial_type="type1",
             default_runner=None,
6 changes: 3 additions & 3 deletions ax/service/tests/test_best_point_utils.py

@@ -615,7 +615,7 @@ def test_best_raw_objective_point_scalarized(self) -> None:
         exp = get_branin_experiment()
         gs = choose_generation_strategy_legacy(search_space=exp.search_space)
         exp.optimization_config = OptimizationConfig(
-            ScalarizedObjective(metrics=[get_branin_metric()], minimize=True)
+            objective=ScalarizedObjective(metrics=[get_branin_metric()], minimize=True)
         )
         with self.assertRaisesRegex(ValueError, "Cannot identify best "):
             get_best_raw_objective_point_with_trial_index(exp)
@@ -637,7 +637,7 @@ def test_best_raw_objective_point_scalarized_multi(self) -> None:
         exp = get_branin_experiment()
         gs = choose_generation_strategy_legacy(search_space=exp.search_space)
         exp.optimization_config = OptimizationConfig(
-            ScalarizedObjective(
+            objective=ScalarizedObjective(
                 metrics=[get_branin_metric(), get_branin_metric(lower_is_better=False)],
                 weights=[0.1, -0.9],
                 minimize=True,
@@ -1037,7 +1037,7 @@ def test_best_parameters_from_model_predictions_scalarized(self) -> None:
         )
         exp.add_tracking_metric(metric2)
         exp.optimization_config = OptimizationConfig(
-            ScalarizedObjective(
+            objective=ScalarizedObjective(
                 metrics=[metric1, metric2],
                 weights=[0.5, 0.5],
                 minimize=True,
20 changes: 12 additions & 8 deletions ax/service/tests/test_report_utils.py

@@ -449,10 +449,12 @@ def _test_get_standard_plots_moo_relative_constraints(
         names = obj.metric_names
         # Create a new Objective rather than mutating _expression_str to
         # avoid stale _parsed cached_property.
-        none_throws(exp.optimization_config)._objective = Objective(
-            expression=f"{names[0]}, -{names[1]}",
-            metric_name_to_signature={n: n for n in names},
-        )
+        none_throws(exp.optimization_config)._objectives = [
+            Objective(
+                expression=f"{names[0]}, -{names[1]}",
+                metric_name_to_signature={n: n for n in names},
+            )
+        ]
         exp.get_metric(names[0]).lower_is_better = False
         assert_is_instance(
             exp.optimization_config, MultiObjectiveOptimizationConfig
@@ -494,10 +496,12 @@ def test_get_standard_plots_moo_no_objective_thresholds(self) -> None:
         # first objective to maximize, second to minimize
         obj = none_throws(exp.optimization_config).objective
         names = obj.metric_names
-        none_throws(exp.optimization_config)._objective = Objective(
-            expression=f"{names[0]}, -{names[1]}",
-            metric_name_to_signature={n: n for n in names},
-        )
+        none_throws(exp.optimization_config)._objectives = [
+            Objective(
+                expression=f"{names[0]}, -{names[1]}",
+                metric_name_to_signature={n: n for n in names},
+            )
+        ]
         exp.trials[0].run()
         plots = get_standard_plots(
             experiment=exp,
7 changes: 7 additions & 0 deletions ax/storage/json_store/decoder.py

@@ -351,6 +351,13 @@ def object_from_json(
         object_json = _sanitize_inputs_to_surrogate_spec(object_json=object_json)
     if isclass(_class) and issubclass(_class, OptimizationConfig):
         object_json.pop("risk_measure", None)  # Deprecated.
+        # Backward compat: old JSON uses "objective", new uses "objectives".
+        if (
+            _class is OptimizationConfig
+            and "objective" in object_json
+            and "objectives" not in object_json
+        ):
+            object_json["objectives"] = [object_json.pop("objective")]
         return ax_class_from_json_dict(
             _class=_class, object_json=object_json, **vars(registry_kwargs)
         )
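A minimal illustration of the shim's effect on an old-format payload (field contents abbreviated; only the key migration matters here):

# Pre-change payloads carry a single "objective" key. The shim above moves
# it into a one-element "objectives" list before the class is constructed.
object_json = {"objective": {"__type": "Objective", "expression": "m1"}}
if "objective" in object_json and "objectives" not in object_json:
    object_json["objectives"] = [object_json.pop("objective")]
assert object_json == {"objectives": [{"__type": "Objective", "expression": "m1"}]}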
11 changes: 6 additions & 5 deletions ax/storage/json_store/encoders.py

@@ -383,7 +383,7 @@ def optimization_config_to_dict(
     """Convert Ax optimization config to a dictionary."""
     return {
         "__type": optimization_config.__class__.__name__,
-        "objective": optimization_config.objective,
+        "objectives": optimization_config.objectives,
         "outcome_constraints": optimization_config.outcome_constraints,
         "pruning_target_parameterization": (
             optimization_config.pruning_target_parameterization
@@ -782,16 +782,17 @@ def _build_opt_config_dict(
     will then recursively encode them via ``metric_to_dict``, capturing the
     full metric type.
     """
-    objective_dict = _build_objective_dict(
-        objective=opt_config.objective, experiment_metrics=experiment_metrics
-    )
+    objective_dicts = [
+        _build_objective_dict(objective=obj, experiment_metrics=experiment_metrics)
+        for obj in opt_config.objectives
+    ]
     constraint_dicts = [
        _build_constraint_dict(constraint=c, experiment_metrics=experiment_metrics)
        for c in opt_config.outcome_constraints
    ]
    result: dict[str, Any] = {
        "__type": opt_config.__class__.__name__,
-        "objective": objective_dict,
+        "objectives": objective_dicts,
        "outcome_constraints": constraint_dicts,
        "pruning_target_parameterization": opt_config.pruning_target_parameterization,
    }
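Note that the compatibility handling appears to be one-directional: the decoder shim above accepts both the old "objective" key and the new "objectives" key, while these encoders now emit only "objectives", so payloads written after this change presumably cannot be loaded by pre-change decoders.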
2 changes: 2 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py

@@ -137,6 +137,7 @@
     get_metric,
     get_mll_type,
     get_model_type,
+    get_moo_optimization_config,
     get_multi_objective,
     get_multi_objective_optimization_config,
     get_multi_type_experiment,
@@ -380,6 +381,7 @@
     ("Objective", get_objective),
     ("ObjectiveThreshold", get_objective_threshold),
     ("OptimizationConfig", get_optimization_config),
+    ("OptimizationConfig", get_moo_optimization_config),
     ("OrEarlyStoppingStrategy", get_or_early_stopping_strategy),
     ("OrderConstraint", get_order_constraint),
     ("OutcomeConstraint", get_outcome_constraint),
12 changes: 7 additions & 5 deletions ax/storage/sqa_store/decoder.py

@@ -640,7 +640,7 @@ def opt_config_and_tracking_metrics_from_sqa(
         register the full metric types (e.g. BraninMetric) rather than
         plain Metric placeholders.
         """
-        objective = None
+        objectives: list[Objective] = []
         objective_thresholds = []
         outcome_constraints = []
         tracking_metrics = []
@@ … @@
             result = self.metric_from_sqa(metric_sqa=metric_sqa)
             if isinstance(result, Objective):
-                objective = result
+                objectives.append(result)
                 # Collect metrics from the objective
                 if metric_sqa.intent in (
                     MetricIntent.MULTI_OBJECTIVE,
@@ -729,14 +729,15 @@ def opt_config_and_tracking_metrics_from_sqa(
                 tracking_metrics.append(result)
                 all_metrics.append(raw_metric)

-        if objective is None:
+        if not objectives:
             return None, tracking_metrics, all_metrics

         if preference_objective_sqa is not None:
             if objective_thresholds:
                 raise SQADecodeError(
                     "PreferenceOptimizationConfig cannot have objective thresholds."
                 )
+            objective = objectives[0]
             properties = preference_objective_sqa.properties or {}
             optimization_config = PreferenceOptimizationConfig(
                 objective=assert_is_instance(objective, MultiObjective),
@@ … @@
                 outcome_constraints=outcome_constraints,
                 pruning_target_parameterization=pruning_target_parameterization,
             )
-        elif objective_thresholds or type(objective) is MultiObjective:
+        elif objective_thresholds or type(objectives[0]) is MultiObjective:
+            objective = objectives[0]
             optimization_config = MultiObjectiveOptimizationConfig(
                 objective=assert_is_instance(
                     objective, Union[MultiObjective, ScalarizedObjective]
@@ … @@
                 ),
             )
         else:
             optimization_config = OptimizationConfig(
-                objective=objective,
+                objectives=objectives,
                 outcome_constraints=outcome_constraints,
                 pruning_target_parameterization=pruning_target_parameterization,
             )
13 changes: 7 additions & 6 deletions ax/storage/sqa_store/encoder.py

@@ -839,13 +839,14 @@ def optimization_config_to_sqa(
                 ),
                 experiment_metrics=experiment_metrics,
             )
+            metrics_sqa.append(obj_sqa)
         else:
-            obj_sqa = self.objective_to_sqa(
-                objective=optimization_config.objective,
-                experiment_metrics=experiment_metrics,
-            )
-
-        metrics_sqa.append(obj_sqa)
+            for obj in optimization_config.objectives:
+                obj_sqa = self.objective_to_sqa(
+                    objective=obj,
+                    experiment_metrics=experiment_metrics,
+                )
+                metrics_sqa.append(obj_sqa)
         for constraint in optimization_config.outcome_constraints:
             constraint_sqa = self.outcome_constraint_to_sqa(
                 outcome_constraint=constraint,
13 changes: 13 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py

@@ -159,6 +159,7 @@
     get_fixed_parameter,
     get_generator_run,
     get_model_predictions_per_arm,
+    get_moo_optimization_config,
     get_multi_objective_optimization_config,
     get_multi_type_experiment,
     get_objective,
@@ -1424,6 +1425,18 @@ def test_optimization_config_pruning_target_parameterization_sqa_roundtrip(
         )
         self.assertEqual(loaded_pruning_target_parameterization.parameters["z"], False)

+    def test_moo_optimization_config_sqa_roundtrip(self) -> None:
+        """Test SQA round-trip for OptimizationConfig with multiple objectives."""
+        experiment = get_experiment_with_batch_trial()
+        experiment.add_tracking_metric(Metric(name="m3", lower_is_better=True))
+        experiment.optimization_config = get_moo_optimization_config()
+        save_experiment(experiment)
+        loaded_experiment = load_experiment(experiment.name)
+        self.assertEqual(experiment, loaded_experiment)
+        loaded_oc = none_throws(loaded_experiment.optimization_config)
+        self.assertEqual(len(loaded_oc.objectives), 2)
+        self.assertTrue(loaded_oc.is_moo_problem)
+
     def test_multi_objective_optimization_config_pruning_target_sqa_roundtrip(
         self,
     ) -> None: