-
Notifications
You must be signed in to change notification settings - Fork 79
Symmetry and Data Augmentation #626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev/symmetry
Are you sure you want to change the base?
Changes from all commits
79894c5
0c8dd26
388f81a
61e9a49
f612319
a982cb8
bccda12
0d7da57
756809b
6f4855f
e221574
b1f704b
9c894a9
2dac5a4
f853266
3247695
dda3997
08a09b8
1adf5c1
180521c
fe9cb4c
88035b3
dbb2693
19af1b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,7 +95,7 @@ class Condition(ABC, SerialMixin): | |
| """Abstract base class for all conditions. | ||
|
|
||
| Conditions always evaluate an expression regarding a single parameter. | ||
| Conditions are part of constraints, a constraint can have multiple conditions. | ||
| Conditions are part of constraints and symmetries. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe links to |
||
| """ | ||
|
|
||
| @abstractmethod | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,9 @@ | |
| if TYPE_CHECKING: | ||
| import polars as pl | ||
|
|
||
| from baybe.symmetries.dependency import DependencySymmetry | ||
| from baybe.symmetries.permutation import PermutationSymmetry | ||
|
|
||
|
|
||
| @define | ||
| class DiscreteExcludeConstraint(DiscreteConstraint): | ||
|
|
@@ -195,10 +198,6 @@ class DiscreteDependenciesConstraint(DiscreteConstraint): | |
| a single constraint. | ||
| """ | ||
|
|
||
| # class variables | ||
| eval_during_augmentation: ClassVar[bool] = True | ||
| # See base class | ||
|
|
||
| # object variables | ||
| conditions: list[Condition] = field() | ||
| """The list of individual conditions.""" | ||
|
|
@@ -271,39 +270,56 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: | |
|
|
||
| return inds_bad | ||
|
|
||
| def to_symmetries( | ||
| self, use_data_augmentation: bool = True | ||
| ) -> tuple[DependencySymmetry, ...]: | ||
| """Convert to :class:`~baybe.symmetries.dependency.DependencySymmetry` objects. | ||
|
|
||
| Create one symmetry object per dependency relationship, i.e., per | ||
| (parameter, condition, affected_parameters) triple. | ||
|
|
||
| Args: | ||
| use_data_augmentation: Flag indicating whether the resulting symmetry | ||
| objects should apply data augmentation. ``True`` means that | ||
| measurement augmentation will be performed by replacing inactive | ||
| affected parameter values with all possible values. | ||
|
|
||
| Returns: | ||
| A tuple of dependency symmetries, one for each dependency in the | ||
| constraint. | ||
| """ | ||
| from baybe.symmetries.dependency import DependencySymmetry | ||
|
|
||
| return tuple( | ||
| DependencySymmetry( | ||
| parameter_name=p, | ||
| condition=c, | ||
| affected_parameter_names=aps, | ||
| use_data_augmentation=use_data_augmentation, | ||
| ) | ||
| for p, c, aps in zip( | ||
| self.parameters, self.conditions, self.affected_parameters, strict=True | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| @define | ||
| class DiscretePermutationInvarianceConstraint(DiscreteConstraint): | ||
| """Constraint class for declaring that a set of parameters is permutation invariant. | ||
|
|
||
| More precisely, this means that, ``(val_from_param1, val_from_param2)`` is | ||
| equivalent to ``(val_from_param2, val_from_param1)``. Since it does not make sense | ||
| to have this constraint with duplicated labels, this implementation also internally | ||
| applies the :class:`baybe.constraints.discrete.DiscreteNoLabelDuplicatesConstraint`. | ||
| equivalent to ``(val_from_param2, val_from_param1)``. | ||
|
|
||
| *Note:* This constraint is evaluated during creation. In the future it might also be | ||
| evaluated during modeling to make use of the invariance. | ||
| """ | ||
|
|
||
| # class variables | ||
| eval_during_augmentation: ClassVar[bool] = True | ||
| # See base class | ||
|
|
||
| # object variables | ||
| dependencies: DiscreteDependenciesConstraint | None = field(default=None) | ||
| """Dependencies connected with the invariant parameters.""" | ||
|
|
||
| @override | ||
| def get_invalid(self, data: pd.DataFrame) -> pd.Index: | ||
| # Get indices of entries with duplicate label entries. These will also be | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's been a while .... Can you maybe quickly comment on the reason for the logic change again?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is related to the "diagonal" points you mention in the Changelog, tho I do not understand what "diagonal" really means here |
||
| # dropped by this constraint. | ||
| mask_duplicate_labels = pd.Series(False, index=data.index) | ||
| mask_duplicate_labels[ | ||
| DiscreteNoLabelDuplicatesConstraint(parameters=self.parameters).get_invalid( | ||
| data | ||
| ) | ||
| ] = True | ||
|
|
||
| # Merge a permutation invariant representation of all affected parameters with | ||
| # the other parameters and indicate duplicates. This ensures that variation in | ||
| # other parameters is also accounted for. | ||
|
|
@@ -314,20 +330,14 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: | |
| data[self.parameters].apply(cast(Callable, frozenset), axis=1), | ||
| ], | ||
| axis=1, | ||
| ).loc[ | ||
| ~mask_duplicate_labels # only consider label-duplicate-free part | ||
| ] | ||
| ) | ||
| mask_duplicate_permutations = df_eval.duplicated(keep="first") | ||
|
|
||
| # Indices of entries with label-duplicates | ||
| inds_duplicate_labels = data.index[mask_duplicate_labels] | ||
|
|
||
| # Indices of duplicate permutations in the (already label-duplicate-free) data | ||
| inds_duplicate_permutations = df_eval.index[mask_duplicate_permutations] | ||
| # Indices of duplicate permutations | ||
| inds_invalid = data.index[mask_duplicate_permutations] | ||
|
|
||
| # If there are dependencies connected to the invariant parameters evaluate them | ||
| # here and remove resulting duplicates with a DependenciesConstraint | ||
| inds_invalid = inds_duplicate_labels.union(inds_duplicate_permutations) | ||
| if self.dependencies: | ||
| self.dependencies.permutation_invariant = True | ||
| inds_duplicate_independency_adjusted = self.dependencies.get_invalid( | ||
|
|
@@ -337,6 +347,32 @@ def get_invalid(self, data: pd.DataFrame) -> pd.Index: | |
|
|
||
| return inds_invalid | ||
|
|
||
| def to_symmetry(self, use_data_augmentation: bool = True) -> PermutationSymmetry: | ||
| """Convert to a :class:`~baybe.symmetries.permutation.PermutationSymmetry`. | ||
|
|
||
| The constraint's parameters form the primary permutation group. If | ||
| dependencies are attached, their parameters are added as an additional | ||
| group that is permuted in lockstep. | ||
|
|
||
| Args: | ||
| use_data_augmentation: Flag indicating whether the resulting symmetry | ||
| object should apply data augmentation. ``True`` means that | ||
| measurement augmentation will be performed by generating all | ||
| permutations of parameter values within each group. | ||
|
|
||
| Returns: | ||
| The corresponding permutation symmetry. | ||
| """ | ||
| from baybe.symmetries.permutation import PermutationSymmetry | ||
|
|
||
| groups = [self.parameters] | ||
| if self.dependencies: | ||
| groups.append(list(self.dependencies.parameters)) | ||
| return PermutationSymmetry( | ||
| permutation_groups=groups, | ||
| use_data_augmentation=use_data_augmentation, | ||
| ) | ||
|
|
||
|
|
||
| @define | ||
| class DiscreteCustomConstraint(DiscreteConstraint): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,13 +13,7 @@ | |
| from baybe.parameters.enum import CategoricalEncoding | ||
| from baybe.parameters.validation import validate_unique_values | ||
| from baybe.settings import active_settings | ||
| from baybe.utils.conversion import nonstring_to_tuple | ||
|
|
||
|
|
||
| def _convert_values(value, self, field) -> tuple[str, ...]: | ||
| """Sort and convert values for categorical parameters.""" | ||
| value = nonstring_to_tuple(value, self, field) | ||
| return tuple(sorted(value, key=lambda x: (str(type(x)), x))) | ||
| from baybe.utils.conversion import normalize_convertible2str_sequence | ||
|
|
||
|
|
||
| def _validate_label_min_len(self, attr, value) -> None: | ||
|
|
@@ -38,8 +32,10 @@ class CategoricalParameter(_DiscreteLabelLikeParameter): | |
| # object variables | ||
| _values: tuple[str | bool, ...] = field( | ||
| alias="values", | ||
| converter=Converter(_convert_values, takes_self=True, takes_field=True), # type: ignore | ||
| validator=( | ||
| converter=Converter( # type: ignore[misc,call-overload] # mypy: Converter | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the |
||
| normalize_convertible2str_sequence, takes_self=True, takes_field=True | ||
| ), | ||
| validator=( # type: ignore[arg-type] # mypy: validator tuple | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why suddenly ignores needed (and what is the mypy part about)? Due to newer mypy release? |
||
| validate_unique_values, | ||
| deep_iterable( | ||
| member_validator=(instance_of((str, bool)), _validate_label_min_len), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,9 +13,10 @@ | |
|
|
||
| from baybe.exceptions import NumericalUnderflowError | ||
| from baybe.parameters.base import ContinuousParameter, DiscreteParameter | ||
| from baybe.parameters.validation import validate_is_finite, validate_unique_values | ||
| from baybe.parameters.validation import validate_unique_values | ||
| from baybe.settings import active_settings | ||
| from baybe.utils.interval import InfiniteIntervalError, Interval | ||
| from baybe.utils.validation import validate_is_finite | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: this PR moves some stuff around (like the |
||
|
|
||
|
|
||
| @define(frozen=True, slots=False) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,6 +86,22 @@ def _get_acquisition_function(self, objective: Objective) -> AcquisitionFunction | |
| return qLogNEHVI() if objective.is_multi_output else qLogEI() | ||
| return self.acquisition_function | ||
|
|
||
| def _get_surrogate_for_augmentation(self) -> Surrogate | None: | ||
| """Get the Surrogate instance for augmentation/validation, if available.""" | ||
| from baybe.surrogates.composite import CompositeSurrogate, _ReplicationMapping | ||
|
|
||
| model = self._surrogate_model | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand the downstream consequences of this branching logic 🤔
|
||
| if isinstance(model, Surrogate): | ||
| return model | ||
| if isinstance(model, CompositeSurrogate): | ||
| # All inner surrogates are copies of the same template | ||
| surrogates = model.surrogates | ||
| if isinstance(surrogates, _ReplicationMapping): | ||
| template = surrogates.template | ||
| if isinstance(template, Surrogate): | ||
| return template | ||
| return None | ||
|
AVHopp marked this conversation as resolved.
|
||
|
|
||
| def get_surrogate( | ||
| self, | ||
| searchspace: SearchSpace, | ||
|
|
@@ -114,6 +130,13 @@ def _setup_botorch_acqf( | |
| f"{len(objective.targets)}-target multi-output context." | ||
| ) | ||
|
|
||
| # Perform data augmentation if configured | ||
| surrogate_for_augmentation = self._get_surrogate_for_augmentation() | ||
| if surrogate_for_augmentation is not None: | ||
|
AVHopp marked this conversation as resolved.
|
||
| measurements = surrogate_for_augmentation.augment_measurements( | ||
| measurements, searchspace.parameters | ||
| ) | ||
|
|
||
| surrogate = self.get_surrogate(searchspace, objective, measurements) | ||
| self._botorch_acqf = acqf.to_botorch( | ||
| surrogate, | ||
|
|
@@ -156,6 +179,13 @@ def recommend( | |
|
|
||
| validate_object_names(searchspace.parameters + objective.targets) | ||
|
|
||
| # Validate compatibility of surrogate symmetries with searchspace | ||
| surrogate_for_validation = self._get_surrogate_for_augmentation() | ||
| if surrogate_for_validation is not None: | ||
| for s in surrogate_for_validation.symmetries: | ||
| s.validate_searchspace_context(searchspace) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Important: Validation so far is only part of the recommend call here in the recommenders. Validation has not been included in the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @AdrianSosic I see now that the 2nd point could be solved with the In the absence of that its not realy possible to turn it into an upfront validation, so I would probably not change the validation for this moment unless you have a smarter idea
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 for being pragmatic and not trying to come up with something potentially convoluted right now. Even if we find a better way for the validation later, including it is just a plain improvement without negative consequences to users, so we can add it later without problems. |
||
|
|
||
| # Experimental input validation | ||
| if (measurements is None) or measurements.empty: | ||
| raise NotImplementedError( | ||
| f"Recommenders of type '{BayesianRecommender.__name__}' do not support " | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe add the base class? You could even mention the symmetry framework as a whole first and only then mention the classes, because that the former is new is not clear from the bullet (you cannot distinguish it from the case where we later add a fourth class)