24 commits
79894c5
Refactor shared validators and converters
Scienfitz Mar 20, 2026
0c8dd26
Update permutation augmentation utility interface
Scienfitz Mar 20, 2026
388f81a
Add mirror augmentation utility
Scienfitz Mar 20, 2026
61e9a49
Add Symmetry domain model
Scienfitz Mar 20, 2026
f612319
Add Parameter.is_equivalent and apply in PermutationSymmetry
Scienfitz Mar 20, 2026
a982cb8
Integrate symmetries into surrogates and recommenders
Scienfitz Mar 20, 2026
bccda12
Update constraints for symmetry support
Scienfitz Mar 20, 2026
0d7da57
Add hypothesis strategies for symmetries and conditions
Scienfitz Mar 20, 2026
756809b
Add symmetry tests
Scienfitz Mar 20, 2026
6f4855f
Add symmetry documentation
Scienfitz Mar 20, 2026
e221574
Add symmetry example
Scienfitz Mar 20, 2026
b1f704b
Handle CompositeSurrogate in symmetry integration
Scienfitz Mar 20, 2026
9c894a9
Fix mypy errors in categorical validator and dependency type ignore
Scienfitz Mar 20, 2026
2dac5a4
Add symmetry validation tests
Scienfitz Mar 20, 2026
f853266
Update CHANGELOG
Scienfitz Mar 20, 2026
3247695
Replace deprecated set_random_seed with Settings in example
Scienfitz Mar 20, 2026
dda3997
Fix Sphinx cross-references for symmetry classes
Scienfitz Mar 20, 2026
08a09b8
Fix bug in permutation constraint
Scienfitz Mar 20, 2026
1adf5c1
Improve docstring
Scienfitz Apr 2, 2026
180521c
Add docstrings to to_symmetries and to_symmetry methods
Scienfitz Apr 10, 2026
fe9cb4c
Add use_data_augmentation to symmetry summary
Scienfitz Apr 10, 2026
88035b3
Rework imports
Scienfitz Apr 10, 2026
dbb2693
Improve example
Scienfitz Apr 24, 2026
19af1b3
Rename partial functions
Scienfitz Apr 24, 2026
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -15,16 +15,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Interpoint constraints for continuous search spaces
- Transfer learning benchmarks for shifted and inverted Hartmann functions
- Coding convention instructions for agentic developers (`AGENTS.md`, `CLAUDE.md`)
- Symmetry classes (`PermutationSymmetry`, `MirrorSymmetry`, `DependencySymmetry`)

Collaborator

Maybe add the base class? You could even mention the symmetry framework as a whole first and only then list the classes: the bullet does not make clear that the framework itself is new (it reads the same as it would if we later added a fourth class).

for expressing invariances and configuring surrogate data augmentation
- `Parameter.is_equivalent` method for structural parameter comparison

### Breaking Changes
- `ContinuousLinearConstraint.to_botorch` now returns a collection of constraint tuples
instead of a single tuple (needed for interpoint constraints)
- `df_apply_permutation_augmentation` has a different interface and now expects
permutation groups instead of column groups

### Fixed
- `SHAPInsight` breaking with `numpy>=2.4` due to no longer accepted implicit array to
scalar conversion
- Using `np.isclose` for assessing equality of `Interval` bounds instead of hard
equality check
- `DiscretePermutationInvarianceConstraint` no longer erroneously removes diagonal
points (e.g., where all permuted parameters have the same value)
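As a side note on the `Interval` entry above, here is a minimal sketch of why a hard equality check on bounds can be fragile. It uses the standard library's `math.isclose` as a stand-in for `np.isclose`, and `bounds_equal` is a hypothetical helper that is not part of BayBE; the tolerances are illustrative assumptions.

```python
import math


def bounds_equal(a: float, b: float) -> bool:
    """Tolerance-based comparison of two bounds (hypothetical helper, not BayBE API)."""
    return math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)


# A bound produced by arithmetic can differ from its literal by rounding error.
lower = 0.1 + 0.2

print(lower == 0.3)              # hard equality check
print(bounds_equal(lower, 0.3))  # tolerance-based check
```

With the hard check, two intervals that are equal for all practical purposes would be treated as different; the tolerance-based check avoids that.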

### Changed
- The `Campaign.allow_*` flag mechanism is now based on `AutoBool` logic, providing
4 changes: 0 additions & 4 deletions baybe/constraints/base.py
@@ -37,10 +37,6 @@ class Constraint(ABC, SerialMixin):
eval_during_modeling: ClassVar[bool]
"""Class variable encoding whether the condition is evaluated during modeling."""

eval_during_augmentation: ClassVar[bool] = False
Comment thread
Scienfitz marked this conversation as resolved.
"""Class variable encoding whether the constraint could be considered during data
augmentation."""

numerical_only: ClassVar[bool] = False
"""Class variable encoding whether the constraint is valid only for numerical
parameters."""
2 changes: 1 addition & 1 deletion baybe/constraints/conditions.py
@@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin):
"""Abstract base class for all conditions.

Conditions always evaluate an expression regarding a single parameter.
Conditions are part of constraints, a constraint can have multiple conditions.
Conditions are part of constraints and symmetries.

Collaborator

Maybe links to constraints and symmetries would be good?

"""

@abstractmethod
94 changes: 65 additions & 29 deletions baybe/constraints/discrete.py
@@ -29,6 +29,9 @@
if TYPE_CHECKING:
import polars as pl

from baybe.symmetries.dependency import DependencySymmetry
from baybe.symmetries.permutation import PermutationSymmetry


@define
class DiscreteExcludeConstraint(DiscreteConstraint):
@@ -195,10 +198,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint):
a single constraint.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
conditions: list[Condition] = field()
"""The list of individual conditions."""
@@ -271,39 +270,56 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:

return inds_bad

def to_symmetries(
self, use_data_augmentation: bool = True
) -> tuple[DependencySymmetry, ...]:
"""Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects.

Create one symmetry object per dependency relationship, i.e., per
(parameter, condition, affected_parameters) triple.

Args:
use_data_augmentation: Flag indicating whether the resulting symmetry
objects should apply data augmentation. ``True`` means that
measurement augmentation will be performed by replacing inactive
affected parameter values with all possible values.

Returns:
A tuple of dependency symmetries, one for each dependency in the
constraint.
"""
from baybe.symmetries.dependency import DependencySymmetry

return tuple(
DependencySymmetry(
parameter_name=p,
condition=c,
affected_parameter_names=aps,
use_data_augmentation=use_data_augmentation,
)
for p, c, aps in zip(
self.parameters, self.conditions, self.affected_parameters, strict=True
)
)


@define
class DiscretePermutationInvarianceConstraint(DiscreteConstraint):
"""Constraint class for declaring that a set of parameters is permutation invariant.

More precisely, this means that, ``(val_from_param1, val_from_param2)`` is
equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense
to have this constraint with duplicated labels, this implementation also internally
applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`.
equivalent to ``(val_from_param2, val_from_param1)``.

*Note:* This constraint is evaluated during creation. In the future it might also be
evaluated during modeling to make use of the invariance.
"""

# class variables
eval_during_augmentation: ClassVar[bool] = True
# See base class

# object variables
dependencies: DiscreteDependenciesConstraint | None = field(default=None)
"""Dependencies connected with the invariant parameters."""

@override
def get_invalid(self, data: pd.DataFrame) -> pd.Index:
# Get indices of entries with duplicate label entries. These will also be

Collaborator

It's been a while .... Can you maybe quickly comment on the reason for the logic change again?

Collaborator

I guess this is related to the "diagonal" points you mention in the Changelog, though I do not understand what "diagonal" really means here.
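One possible reading of "diagonal", sketched under the assumption that it refers to rows where all permuted parameters carry the same value (as the CHANGELOG entry puts it). The example uses plain Python tuples rather than a DataFrame and covers only two interchangeable parameters; `drop_permutation_duplicates` is a hypothetical helper that mirrors the `frozenset` plus `duplicated(keep="first")` idea from the diff, not the actual implementation.

```python
from itertools import product

values = ["A", "B", "C"]
# Full Cartesian grid for two interchangeable parameters.
grid = list(product(values, repeat=2))


def drop_permutation_duplicates(rows):
    """Keep the first row of each permutation-equivalence class (hypothetical)."""
    seen = set()
    kept = []
    for row in rows:
        key = frozenset(row)  # order-independent representation of the pair
        if key not in seen:
            seen.add(key)
            kept.append(row)
    return kept


kept = drop_permutation_duplicates(grid)
diagonal = [row for row in kept if row[0] == row[1]]
print(kept)      # one representative per unordered pair
print(diagonal)  # rows such as ("A", "A") survive
```

Under this reading, the removed label-duplicate mask is what previously discarded rows like `("A", "A")`, even though they have no permuted twin and are therefore not redundant.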

# dropped by this constraint.
mask_duplicate_labels = pd.Series(False, index=data.index)
mask_duplicate_labels[
DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid(
data
)
] = True

# Merge a permutation invariant representation of all affected parameters with
# the other parameters and indicate duplicates. This ensures that variation in
# other parameters is also accounted for.
@@ -314,20 +330,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:
data[self.parameters].apply(cast(Callable, frozenset), axis=1),
],
axis=1,
).loc[
~mask_duplicate_labels # only consider label-duplicate-free part
]
)
mask_duplicate_permutations = df_eval.duplicated(keep="first")

# Indices of entries with label-duplicates
inds_duplicate_labels = data.index[mask_duplicate_labels]

# Indices of duplicate permutations in the (already label-duplicate-free) data
inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations]
# Indices of duplicate permutations
inds_invalid = data.index[mask_duplicate_permutations]

# If there are dependencies connected to the invariant parameters evaluate them
# here and remove resulting duplicates with a DependenciesConstraint
inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations)
if self.dependencies:
self.dependencies.permutation_invariant = True
inds_duplicate_independency_adjusted = self.dependencies.get_invalid(
@@ -337,6 +347,32 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index:

return inds_invalid

def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry:
"""Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`.

The constraint's parameters form the primary permutation group. If
dependencies are attached, their parameters are added as an additional
group that is permuted in lockstep.

Args:
use_data_augmentation: Flag indicating whether the resulting symmetry
object should apply data augmentation. ``True`` means that
measurement augmentation will be performed by generating all
permutations of parameter values within each group.

Returns:
The corresponding permutation symmetry.
"""
from baybe.symmetries.permutation import PermutationSymmetry

groups = [self.parameters]
if self.dependencies:
groups.append(list(self.dependencies.parameters))
return PermutationSymmetry(
permutation_groups=groups,
use_data_augmentation=use_data_augmentation,
)
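To make "permuted in lockstep" concrete, here is a rough sketch under the assumption that it means the same index permutation is applied to every group at once. `lockstep_permutations` and the column names are hypothetical and do not reflect how `PermutationSymmetry` is actually implemented.

```python
from itertools import permutations


def lockstep_permutations(row, groups):
    """Generate all variants of ``row`` by permuting every group identically.

    Hypothetical helper; all groups are assumed to have the same length.
    """
    size = len(groups[0])
    variants = []
    for perm in permutations(range(size)):
        new_row = dict(row)
        for group in groups:
            # The same index permutation is reused for each group.
            for target, source_idx in zip(group, perm):
                new_row[target] = row[group[source_idx]]
        variants.append(new_row)
    return variants


row = {"solvent1": "water", "solvent2": "ethanol", "frac1": 0.2, "frac2": 0.8}
groups = [["solvent1", "solvent2"], ["frac1", "frac2"]]
augmented = lockstep_permutations(row, groups)
print(augmented)
```

Swapping the solvents without swapping their fractions would describe a different experiment, which is why the dependent group has to follow the primary one.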


@define
class DiscreteCustomConstraint(DiscreteConstraint):
17 changes: 17 additions & 0 deletions baybe/parameters/base.py
@@ -7,6 +7,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar

import attrs
import pandas as pd
from attrs import define, field
from attrs.converters import optional as optional_c
@@ -88,6 +89,22 @@ def to_searchspace(self) -> SearchSpace:

return SearchSpace.from_parameter(self)

def is_equivalent(self, other: Parameter) -> bool:
"""Check if this parameter is equivalent to another, ignoring the name.

Two parameters are considered equivalent if they have the same type and
all attributes are equal except for the name.

Args:
other: The parameter to compare against.

Returns:
``True`` if the parameters are equivalent, ``False`` otherwise.
"""
if type(self) is not type(other):
return False
return attrs.evolve(self, name=other.name) == other
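The "copy with the other's name, then compare" idea behind `is_equivalent` can be tried out in isolation. The sketch below swaps in standard-library `dataclasses.replace` for `attrs.evolve` so that it runs without extra dependencies; `ToyParameter` is a hypothetical stand-in, not a BayBE class.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyParameter:
    """Hypothetical stand-in for a parameter with a name and values."""

    name: str
    values: tuple[float, ...]

    def is_equivalent(self, other: "ToyParameter") -> bool:
        """Compare all fields except the name."""
        if type(self) is not type(other):
            return False
        # Copy self with the other's name so the name no longer affects equality.
        return replace(self, name=other.name) == other


a = ToyParameter("x1", (1.0, 2.0))
b = ToyParameter("x2", (1.0, 2.0))
c = ToyParameter("x3", (1.0, 3.0))

print(a == b, a.is_equivalent(b))  # differ only by name
print(a.is_equivalent(c))          # differ in values as well
```

The benefit of this pattern is that new attributes are covered automatically, since the comparison is delegated to the generated `__eq__`.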

@abstractmethod
def summary(self) -> dict:
"""Return a custom summarization of the parameter."""
14 changes: 5 additions & 9 deletions baybe/parameters/categorical.py
@@ -13,13 +13,7 @@
from baybe.parameters.enum import CategoricalEncoding
from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
from baybe.utils.conversion import nonstring_to_tuple


def _convert_values(value, self, field) -> tuple[str, ...]:
"""Sort and convert values for categorical parameters."""
value = nonstring_to_tuple(value, self, field)
return tuple(sorted(value, key=lambda x: (str(type(x)), x)))
from baybe.utils.conversion import normalize_convertible2str_sequence


def _validate_label_min_len(self, attr, value) -> None:
@@ -38,8 +32,10 @@ class CategoricalParameter(_DiscreteLabelLikeParameter):
# object variables
_values: tuple[str | bool, ...] = field(
alias="values",
converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore
validator=(
converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter

Collaborator

what is the mypy part about? Haven't seen this before

normalize_convertible2str_sequence, takes_self=True, takes_field=True
),
validator=( # type: ignore[arg-type] # mypy: validator tuple

Collaborator

Why are ignores suddenly needed (and what is the mypy part about)? Due to a newer mypy release?

validate_unique_values,
deep_iterable(
member_validator=(instance_of((str, bool)), _validate_label_min_len),
3 changes: 2 additions & 1 deletion baybe/parameters/numerical.py
@@ -13,9 +13,10 @@

from baybe.exceptions import NumericalUnderflowError
from baybe.parameters.base import ContinuousParameter, DiscreteParameter
from baybe.parameters.validation import validate_is_finite, validate_unique_values
from baybe.parameters.validation import validate_unique_values
from baybe.settings import active_settings
from baybe.utils.interval import InfiniteIntervalError, Interval
from baybe.utils.validation import validate_is_finite

Collaborator

Question: this PR moves some stuff around (like the validate_is_finite and others), which technically represents breaking changes. Changelog entry or no?



@define(frozen=True, slots=False)
17 changes: 0 additions & 17 deletions baybe/parameters/validation.py
@@ -1,9 +1,7 @@
"""Validation functionality for parameters."""

from collections.abc import Sequence
from typing import Any

import numpy as np
from attrs.validators import gt, instance_of, lt


@@ -28,18 +26,3 @@ def validate_decorrelation(obj: Any, attribute: Any, value: float) -> None:
if isinstance(value, float):
gt(0.0)(obj, attribute, value)
lt(1.0)(obj, attribute, value)


def validate_is_finite( # noqa: DOC101, DOC103
obj: Any, _: Any, value: Sequence[float]
) -> None:
"""Validate that ``value`` contains no infinity/nan.

Raises:
ValueError: If ``value`` contains infinity/nan.
"""
if not all(np.isfinite(value)):
raise ValueError(
f"Cannot assign the following values containing infinity/nan to "
f"parameter {obj.name}: {value}."
)
30 changes: 30 additions & 0 deletions baybe/recommenders/pure/bayesian/base.py
@@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction
return qLogNEHVI() if objective.is_multi_output else qLogEI()
return self.acquisition_function

def _get_surrogate_for_augmentation(self) -> Surrogate | None:
"""Get the Surrogate instance for augmentation/validation, if available."""
from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping

model = self._surrogate_model

Collaborator

I'm not sure I understand the downstream consequences of this branching logic 🤔

  1. What happens if we have a CompositeSurrogate that is not a _ReplicationMapping? The method would return None in this case!?
  2. What's the reason for the isinstance(template, Surrogate) check?
  3. I'm not yet convinced by the None output type. This function is only ever called in places where a surrogate is (or should be!) available. So I think replacing the None path with an assert self._surrogate is not None check is the better approach.

if isinstance(model, Surrogate):
return model
if isinstance(model, CompositeSurrogate):
# All inner surrogates are copies of the same template
surrogates = model.surrogates
if isinstance(surrogates, _ReplicationMapping):
template = surrogates.template
if isinstance(template, Surrogate):
return template
return None
Comment thread
AVHopp marked this conversation as resolved.
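Regarding the first question in the review comment above, the branching can be traced with stand-in classes. This is only a sketch: `Surrogate`, `CompositeSurrogate`, and `ReplicationMapping` below are empty placeholders that copy the names from the diff, not the real BayBE types, and `surrogate_for_augmentation` is a free-function rewrite of the method's logic.

```python
class Surrogate:
    """Placeholder for a single surrogate model."""


class ReplicationMapping:
    """Placeholder for a mapping that replicates one template surrogate."""

    def __init__(self, template):
        self.template = template


class CompositeSurrogate:
    """Placeholder for a container of per-target surrogates."""

    def __init__(self, surrogates):
        self.surrogates = surrogates


def surrogate_for_augmentation(model):
    """Mirror the branching shown in the diff, using the placeholders above."""
    if isinstance(model, Surrogate):
        return model
    if isinstance(model, CompositeSurrogate):
        surrogates = model.surrogates
        if isinstance(surrogates, ReplicationMapping):
            if isinstance(surrogates.template, Surrogate):
                return surrogates.template
    return None


single = Surrogate()
replicated = CompositeSurrogate(ReplicationMapping(single))
heterogeneous = CompositeSurrogate({"t1": Surrogate(), "t2": Surrogate()})

print(surrogate_for_augmentation(single) is single)
print(surrogate_for_augmentation(replicated) is single)
print(surrogate_for_augmentation(heterogeneous))
```

In this sketch, a composite built from an ordinary mapping falls through to `None`, so neither augmentation nor symmetry validation would run for it.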

def get_surrogate(
self,
searchspace: SearchSpace,
@@ -114,6 +130,13 @@ def _setup_botorch_acqf(
f"{len(objective.targets)}-target multi-output context."
)

# Perform data augmentation if configured
surrogate_for_augmentation = self._get_surrogate_for_augmentation()
if surrogate_for_augmentation is not None:
Comment thread
AVHopp marked this conversation as resolved.
measurements = surrogate_for_augmentation.augment_measurements(
measurements, searchspace.parameters
)

surrogate = self.get_surrogate(searchspace, objective, measurements)
self._botorch_acqf = acqf.to_botorch(
surrogate,
@@ -156,6 +179,13 @@ def recommend(

validate_object_names(searchspace.parameters + objective.targets)

# Validate compatibility of surrogate symmetries with searchspace
surrogate_for_validation = self._get_surrogate_for_augmentation()
if surrogate_for_validation is not None:
for s in surrogate_for_validation.symmetries:
s.validate_searchspace_context(searchspace)

Collaborator Author

Important: So far, validation is only part of the recommend call here in the recommenders. It has not been included in the Campaign yet. This is due to two factors:

  • To properly validate the compatibility of symmetries and searchspace, there needs to be a mechanism that can iterate over all possible recommenders of a metarecommender. Otherwise, this upfront validation already fails for the two-phase recommender if the second recommender has symmetries.
  • There would be double validation with the campaign and the recommend call, so the context info of whether validation was already performed needs to be passed somewhere. This is likely fixable with the settings mechanism, which is not yet available.

Collaborator Author

@AdrianSosic I see now that the 2nd point could be solved with the Settings mechanism, but I have no idea how to solve issue 1.

In the absence of that, it's not really possible to turn it into an upfront validation, so I would probably not change the validation for the moment unless you have a smarter idea.

Collaborator

+1 for being pragmatic and not trying to come up with something potentially convoluted right now. Even if we find a better way for the validation later, including it is just a plain improvement without negative consequences to users, so we can add it later without problems.


# Experimental input validation
if (measurements is None) or measurements.empty:
raise NotImplementedError(
f"Recommenders of type '{BayesianRecommender.__name__}' do not support "
5 changes: 0 additions & 5 deletions baybe/searchspace/core.py
@@ -381,11 +381,6 @@ def transform(

return comp_rep

@property
def constraints_augmentable(self) -> tuple[Constraint, ...]:
"""The searchspace constraints that can be considered during augmentation."""
return tuple(c for c in self.constraints if c.eval_during_augmentation)

def get_parameters_by_name(self, names: Sequence[str]) -> tuple[Parameter, ...]:
"""Return parameters with the specified names.
