Skip to content
67 changes: 56 additions & 11 deletions ax/adapter/adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,31 +76,67 @@
from ax import adapter as adapter_module # noqa F401


def _extract_constraints(
    parameter_constraints: list[ParameterConstraint],
    param_names: list[str],
    is_equality: bool,
) -> TBounds:
    """Extract linear constraints into a tuple of NumPy arrays.

    Shared helper for extracting inequality (``A x <= b``) or equality
    (``A x = b``) constraints.

    Args:
        parameter_constraints: A list of parameter constraint objects.
        param_names: A list of parameter names; column ``j`` of ``A``
            corresponds to ``param_names[j]``.
        is_equality: If True, extract equality constraints; otherwise
            extract inequality constraints.

    Returns:
        An optional tuple of NumPy arrays ``(A, b)``, or ``None`` if there
        are no constraints of the requested kind.
    """
    filtered = [c for c in parameter_constraints if c.is_equality == is_equality]
    if not filtered:
        return None
    A = np.zeros((len(filtered), len(param_names)))
    b = np.zeros((len(filtered), 1))
    for i, c in enumerate(filtered):
        b[i, 0] = c.bound
        # Fill this constraint's row; parameters not mentioned in the
        # constraint keep a zero coefficient.
        for name, val in c.constraint_dict.items():
            A[i, param_names.index(name)] = val
    return (A, b)


def extract_inequality_constraints(
    parameter_constraints: list[ParameterConstraint], param_names: list[str]
) -> TBounds:
    """Convert Ax inequality parameter constraints into NumPy arrays.

    Equality constraints in the input list are ignored.

    Args:
        parameter_constraints: A list of parameter constraint objects.
        param_names: A list of parameter names.

    Returns:
        An optional tuple of NumPy arrays (A, b) with ``A x <= b``.
    """
    return _extract_constraints(
        parameter_constraints=parameter_constraints,
        param_names=param_names,
        is_equality=False,
    )


def extract_equality_constraints(
    parameter_constraints: list[ParameterConstraint], param_names: list[str]
) -> TBounds:
    """Convert Ax equality parameter constraints into NumPy arrays.

    Inequality constraints in the input list are ignored.

    Args:
        parameter_constraints: A list of parameter constraint objects.
        param_names: A list of parameter names.

    Returns:
        An optional tuple of NumPy arrays (A, b) with ``A x = b``.
    """
    return _extract_constraints(
        parameter_constraints=parameter_constraints,
        param_names=param_names,
        is_equality=True,
    )


def extract_search_space_digest(
search_space: SearchSpace, param_names: list[str]
) -> SearchSpaceDigest:
Expand Down Expand Up @@ -402,6 +438,7 @@ def validate_and_apply_final_transform(
pending_observations: list[npt.NDArray] | None,
objective_thresholds: npt.NDArray | None = None,
pruning_target_point: npt.NDArray | None = None,
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
final_transform: Callable[[npt.NDArray], Tensor] = torch.tensor,
) -> tuple[
Tensor,
Expand All @@ -410,6 +447,7 @@ def validate_and_apply_final_transform(
list[Tensor] | None,
Tensor | None,
Tensor | None,
tuple[Tensor, Tensor] | None,
]:
# TODO: use some container down the road (similar to
# SearchSpaceDigest) to limit the return arguments
Expand Down Expand Up @@ -437,13 +475,20 @@ def validate_and_apply_final_transform(
pruning_target_tensor: Tensor | None = None
if pruning_target_point is not None:
pruning_target_tensor = final_transform(pruning_target_point)
equality_constraints_tensors: tuple[Tensor, Tensor] | None = None
if equality_constraints is not None:
equality_constraints_tensors = (
final_transform(equality_constraints[0]),
final_transform(equality_constraints[1]),
)
return (
obj_weights_tensor,
outcome_constraints_tensors,
linear_constraints_tensors,
pending_obs_tensors,
obj_thresholds_tensor,
pruning_target_tensor,
equality_constraints_tensors,
)


Expand Down Expand Up @@ -700,7 +745,7 @@ def get_pareto_frontier_and_configs(
if obj_t is not None:
obj_t = array_to_tensor(obj_t)
# Transform to tensors.
obj_w, oc_c, _, _, _, _ = validate_and_apply_final_transform(
obj_w, oc_c, _, _, _, _, _ = validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
linear_constraints=None,
Expand Down
9 changes: 7 additions & 2 deletions ax/adapter/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import numpy as np
from ax.adapter.adapter_utils import (
extract_parameter_constraints,
extract_equality_constraints,
extract_inequality_constraints,
extract_search_space_digest,
get_fixed_features,
parse_observation_features,
Expand Down Expand Up @@ -92,7 +93,10 @@ def _gen(
# Get fixed features
fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
# Extract param constraints
linear_constraints = extract_parameter_constraints(
linear_constraints = extract_inequality_constraints(
search_space.parameter_constraints, self.parameters
)
equality_constraints_np = extract_equality_constraints(
search_space.parameter_constraints, self.parameters
)
# Extract generated points.
Expand Down Expand Up @@ -177,6 +181,7 @@ def _gen(
n=n,
search_space_digest=search_space_digest,
linear_constraints=linear_constraints,
equality_constraints=equality_constraints_np,
fixed_features=fixed_features_dict,
model_gen_options=model_gen_options,
rounding_func=transform_callback(self.parameters, self.transforms),
Expand Down
92 changes: 92 additions & 0 deletions ax/adapter/tests/test_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_get_fresh_pairwise_trial_indices,
arm_to_np_array,
can_map_to_binary,
extract_equality_constraints,
extract_inequality_constraints,
extract_objective_weight_matrix,
extract_search_space_digest,
feasible_hypervolume,
Expand All @@ -35,6 +37,7 @@
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
Expand Down Expand Up @@ -377,6 +380,7 @@ def test_validate_and_apply_final_transform_with_target_point(self) -> None:
_,
_,
target_p,
_,
) = validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
Expand Down Expand Up @@ -412,6 +416,7 @@ def test_validate_and_apply_final_transform_none_target_point(self) -> None:
_,
_,
target_p,
_,
) = validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
Expand Down Expand Up @@ -652,3 +657,90 @@ def _attach(
self.assertNotIn(0, result)
self.assertNotIn(1, result)
self.assertIn(2, result)

def test_extract_inequality_constraints(self) -> None:
param_names = ["x", "y"]
ineq = ParameterConstraint(inequality="x + y <= 1")
eq = ParameterConstraint(equality="x + y == 1")

# Only inequality constraints are extracted
result = extract_inequality_constraints([ineq, eq], param_names)
self.assertIsNotNone(result)
assert result is not None
A, b = result
self.assertEqual(A.shape, (1, 2))
self.assertEqual(b.shape, (1, 1))
np.testing.assert_array_equal(A[0], [1.0, 1.0])
np.testing.assert_array_equal(b[0], [1.0])

# Returns None when no inequality constraints
result = extract_inequality_constraints([eq], param_names)
self.assertIsNone(result)

# Returns None for empty list
result = extract_inequality_constraints([], param_names)
self.assertIsNone(result)

def test_extract_equality_constraints(self) -> None:
param_names = ["x", "y"]
ineq = ParameterConstraint(inequality="x + y <= 1")
eq = ParameterConstraint(equality="x + y == 1")

# Only equality constraints are extracted
result = extract_equality_constraints([ineq, eq], param_names)
self.assertIsNotNone(result)
assert result is not None
A, b = result
self.assertEqual(A.shape, (1, 2))
self.assertEqual(b.shape, (1, 1))
np.testing.assert_array_equal(A[0], [1.0, 1.0])
np.testing.assert_array_equal(b[0], [1.0])

# Returns None when no equality constraints
result = extract_equality_constraints([ineq], param_names)
self.assertIsNone(result)

def test_extract_constraints_mixed(self) -> None:
"""Both functions correctly partition a mixed list."""
param_names = ["x", "y"]
ineq1 = ParameterConstraint(inequality="x <= 0.5")
ineq2 = ParameterConstraint(inequality="y <= 0.8")
eq1 = ParameterConstraint(equality="x + y == 1")

ineq_result = extract_inequality_constraints([ineq1, eq1, ineq2], param_names)
eq_result = extract_equality_constraints([ineq1, eq1, ineq2], param_names)

assert ineq_result is not None
assert eq_result is not None
self.assertEqual(ineq_result[0].shape, (2, 2)) # 2 inequalities
self.assertEqual(eq_result[0].shape, (1, 2)) # 1 equality

def test_validate_and_apply_final_transform_equality_constraints(self) -> None:
"""equality_constraints are converted to tensors."""
objective_weights = np.array([1.0, 0.0])
A_eq = np.array([[1.0, 1.0]])
b_eq = np.array([[1.0]])

_, _, _, _, _, _, eq_c = validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=None,
linear_constraints=None,
pending_observations=None,
equality_constraints=(A_eq, b_eq),
)
self.assertIsNotNone(eq_c)
assert eq_c is not None
self.assertTrue(torch.equal(eq_c[0], torch.tensor(A_eq)))
self.assertTrue(torch.equal(eq_c[1], torch.tensor(b_eq)))

def test_validate_and_apply_final_transform_no_equality_constraints(self) -> None:
"""equality_constraints defaults to None."""
objective_weights = np.array([1.0])

_, _, _, _, _, _, eq_c = validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=None,
linear_constraints=None,
pending_observations=None,
)
self.assertIsNone(eq_c)
59 changes: 59 additions & 0 deletions ax/adapter/tests/test_random_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,65 @@ def test_gen_w_constraints(self) -> None:
self.assertEqual(obsf[1].parameters, {"x": 3.0, "y": 4.0, "z": 3.0})
self.assertTrue(np.array_equal(gen_results.weights, np.array([1.0, 2.0])))

def test_gen_w_equality_constraints(self) -> None:
# Verify that equality constraints from the search space are extracted
# and passed through to the generator's gen() call.
x = RangeParameter("x", ParameterType.FLOAT, lower=0, upper=1)
y = RangeParameter("y", ParameterType.FLOAT, lower=0, upper=1)
z = RangeParameter("z", ParameterType.FLOAT, lower=0, upper=1)
parameter_constraints = [
ParameterConstraint(equality="x + y == 0.5"),
]
search_space = SearchSpace([x, y, z], parameter_constraints)
experiment = Experiment(search_space=search_space)
adapter = RandomAdapter(experiment=experiment, generator=RandomGenerator())
with mock.patch.object(
adapter.generator,
"gen",
return_value=(
np.array([[0.2, 0.3, 0.4]]),
np.array([1.0]),
),
) as mock_gen:
adapter._gen(
n=1,
search_space=search_space,
pending_observations={},
fixed_features=ObservationFeatures({}),
optimization_config=None,
model_gen_options=self.model_gen_options,
)
gen_args = mock_gen.mock_calls[0][2]
eq_constraints = gen_args["equality_constraints"]
self.assertIsNotNone(eq_constraints)
A, b = eq_constraints
# x + y = 0.5 => A = [[1, 1, 0]], b = [[0.5]]
self.assertTrue(np.array_equal(A, np.array([[1.0, 1.0, 0.0]])))
self.assertTrue(np.array_equal(b, np.array([[0.5]])))

def test_gen_no_equality_constraints(self) -> None:
# Verify that equality_constraints is None when there are no equality
# constraints on the search space.
adapter = RandomAdapter(experiment=self.experiment, generator=RandomGenerator())
with mock.patch.object(
adapter.generator,
"gen",
return_value=(
np.array([[0.5, 1.5, 2.5]]),
np.array([1.0]),
),
) as mock_gen:
adapter._gen(
n=1,
search_space=self.search_space,
pending_observations={},
fixed_features=ObservationFeatures({}),
optimization_config=None,
model_gen_options=self.model_gen_options,
)
gen_args = mock_gen.mock_calls[0][2]
self.assertIsNone(gen_args["equality_constraints"])

def test_gen_simple(self) -> None:
# Test with no constraints, no fixed feature, no pending observations
search_space = SearchSpace(self.parameters[:2])
Expand Down
12 changes: 9 additions & 3 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
_get_fresh_pairwise_trial_indices,
arm_to_np_array,
array_to_observation_data,
extract_equality_constraints,
extract_inequality_constraints,
extract_objective_thresholds,
extract_objective_weight_matrix,
extract_outcome_constraints,
extract_parameter_constraints,
extract_search_space_digest,
get_fixed_features,
observation_data_to_array,
Expand Down Expand Up @@ -1044,7 +1045,10 @@ def _get_transformed_model_gen_args(
arm=optimization_config.pruning_target_parameterization,
parameters=self.parameters,
)
linear_constraints = extract_parameter_constraints(
linear_constraints = extract_inequality_constraints(
search_space.parameter_constraints, self.parameters
)
equality_constraints_np = extract_equality_constraints(
search_space.parameter_constraints, self.parameters
)
fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
Expand All @@ -1065,14 +1069,15 @@ def _get_transformed_model_gen_args(
pending_array = pending_observations_as_array_list(
pending_observations, self.outcomes, self.parameters
)
obj_w, out_c, lin_c, pend_o, obj_t, pruning_target_p = (
obj_w, out_c, lin_c, pend_o, obj_t, pruning_target_p, eq_c = (
validate_and_apply_final_transform(
objective_weights=objective_weights,
outcome_constraints=outcome_constraints,
linear_constraints=linear_constraints,
pending_observations=pending_array,
objective_thresholds=objective_thresholds,
pruning_target_point=pruning_target_point,
equality_constraints=equality_constraints_np,
final_transform=self._array_to_tensor,
)
)
Expand All @@ -1089,6 +1094,7 @@ def _get_transformed_model_gen_args(
outcome_constraints=out_c,
objective_thresholds=obj_t,
linear_constraints=lin_c,
equality_constraints=eq_c,
fixed_features=fixed_features_dict,
pending_observations=pend_o,
model_gen_options=model_gen_options or {},
Expand Down
Loading
Loading