Skip to content

Commit 68c5fae

Browse files
authored
Refactor serialization workarounds (#577)
To solve an open serialization issue in #564 , this PR finally cleans up a lot of technical debt that has accumulated in the serialization machinery, which was built heavily on workarounds. ### Summary of the important changes * Most problems could be traced back to the fact that the "subclass serialization machinery" fueled by `unstructure_base` and `get_base_structure_hook` did not reuse hooks that were already set in place for certain classes, which required passing around redundant attribute overrides. --> The new mechanism simply dispatches to the existing hooks of the respective subclass. No more override duplication needed. * Each newly defined abstract type (i.e., where it is clear that only concrete subclass will ever be encountered at runtime) required manual activation of the "subclass serialization machinery", resulting boilerplate code spread across many different modules. --> Since abstractly annotated attributes always require the mechanism to be in place, it is now globally activated for all abstract types using appropriate predicate/factory hooks in the serialization module, removing all boilerplate. ### Problems not addressed * A severe problem of the current mechanism is that additional keys are still silently dropped from configuration strings at deserialization time, which can lead to critical silent bugs. Setting `forbid_extra_fields` would theoretically solve the problem in that additional keys would no longer be supported and would raise an according error when encountered. * However, one consequence of the disallowed additional keys is that one could also no longer provide a `type` field in serialization strings when the concrete type is already fully dictated by the context. For example, the `simplex_parameters` argument of `SubspaceDiscrete.from_simplex` would **no longer accept** type information in the corresponding serialization string, since the parameter type is always `NumericalDiscreteParameter` (as dictated by the type annotation) while the `product_parameters` argument **requires** type information (since the annotation type of the individual entries is simply "Parameter"). Overall, this is just a consistent behavior, but we really want to be less strict in these cases, so we first need to find a reasonable solution to the this problem.
2 parents b956088 + 988f90d commit 68c5fae

60 files changed

Lines changed: 424 additions & 558 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- `farthest_point_sampling` now also supports a collection of integers for
1919
`initialization`, using them for pre-selecting points
2020

21+
### Changed
22+
- `unstructure_base` and `get_base_structure_hook` (de-)serialization utilities
23+
have been replaced with `unstructure_with_type` and `make_base_structure_hook`
24+
- `to_dict` and `to_json` now accept an optional Boolean `add_type` argument
25+
2126
### Fixed
2227
- It is no longer possible to use identical names between parameters and targets
2328
- Random seed context is correctly set within benchmarks
2429
- Measurement input validation now respects typical tolerances associated with floating
2530
point representation inaccuracy
2631
- Exotic serialization issues with constraints and conditions arising from missing
2732
converters for floats
33+
- `MetaRecommender`'s no longer expose their private attributes via the constructor
2834

2935
### Removed
3036
- Telemetry

baybe/acquisition/base.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
)
1616
from baybe.objectives.base import Objective
1717
from baybe.searchspace.core import SearchSpace
18-
from baybe.serialization.core import (
19-
converter,
20-
get_base_structure_hook,
21-
unstructure_base,
22-
)
2318
from baybe.serialization.mixin import SerialMixin
2419
from baybe.surrogates.base import SurrogateProtocol
2520
from baybe.utils.basic import classproperty
@@ -181,11 +176,5 @@ def _get_botorch_acqf_class(
181176
)
182177

183178

184-
# Register (un-)structure hooks
185-
converter.register_structure_hook(
186-
AcquisitionFunction, get_base_structure_hook(AcquisitionFunction)
187-
)
188-
converter.register_unstructure_hook(AcquisitionFunction, unstructure_base)
189-
190179
# Collect leftover original slotted classes processed by `attrs.define`
191180
gc.collect()

baybe/campaign.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ def _drop_version(dict_: dict) -> dict:
845845
Campaign, converter, _cattrs_include_init_false=True
846846
)
847847
structure_hook = cattrs.gen.make_dict_structure_fn(
848-
Campaign, converter, _cattrs_include_init_false=True, _cattrs_forbid_extra_keys=True
848+
Campaign, converter, _cattrs_include_init_false=True
849849
)
850850
converter.register_unstructure_hook(
851851
Campaign, lambda x: _add_version(unstructure_hook(x))

baybe/constraints/base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from attrs import define, field
1111
from attrs.validators import ge, instance_of, min_len
1212

13-
from baybe.constraints.deprecation import structure_constraints
13+
from baybe.constraints.deprecation import (
14+
ContinuousLinearEqualityConstraint,
15+
ContinuousLinearInequalityConstraint,
16+
)
1417
from baybe.serialization import (
1518
SerialMixin,
19+
)
20+
from baybe.serialization.core import (
1621
converter,
17-
unstructure_base,
1822
)
1923

2024
if TYPE_CHECKING:
@@ -211,12 +215,26 @@ class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
211215
"""Abstract base class for continuous nonlinear constraints."""
212216

213217

214-
# Register (un-)structure hooks
215-
converter.register_unstructure_hook(Constraint, unstructure_base)
218+
# >>>>> Deprecation handling
219+
_hook = converter.get_structure_hook(Constraint)
220+
221+
222+
def _deprecate_legacy_classes(dct: dict[str, Any], _) -> Constraint:
223+
"""Enable constraint configs using legacy class names."""
224+
if dct["type"] == "ContinuousLinearEqualityConstraint":
225+
dct.pop("type")
226+
return ContinuousLinearEqualityConstraint(**dct)
227+
elif dct["type"] == "ContinuousLinearInequalityConstraint":
228+
dct.pop("type")
229+
return ContinuousLinearInequalityConstraint(**dct)
230+
return _hook(dct, _)
231+
232+
233+
converter.register_structure_hook_func(
234+
lambda c: c is Constraint, _deprecate_legacy_classes
235+
)
236+
# <<<<< Deprecation handling
216237

217-
# Currently affected by a deprecation
218-
# converter.register_structure_hook(Constraint, get_base_structure_hook(Constraint))
219-
converter.register_structure_hook(Constraint, structure_constraints)
220238

221239
# Collect leftover original slotted classes processed by `attrs.define`
222240
gc.collect()

baybe/constraints/conditions.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from functools import partial
1010
from typing import TYPE_CHECKING, Any
1111

12-
import cattrs
1312
import numpy as np
1413
import pandas as pd
1514
from attrs import define, field
@@ -22,9 +21,6 @@
2221
from baybe.parameters.validation import validate_unique_values
2322
from baybe.serialization import (
2423
SerialMixin,
25-
converter,
26-
get_base_structure_hook,
27-
unstructure_base,
2824
)
2925
from baybe.utils.basic import to_tuple
3026
from baybe.utils.numerical import DTypeFloatNumpy
@@ -230,18 +226,5 @@ def to_polars(self, expr: pl.Expr, /) -> pl.Expr:
230226
return expr.is_in(self.selection)
231227

232228

233-
# Register (un-)structure hooks
234-
_overrides = {
235-
"_selection": cattrs.override(rename="selection"),
236-
}
237-
# FIXME[typing]: https://github.com/python/mypy/issues/4717
238-
converter.register_structure_hook(
239-
Condition,
240-
get_base_structure_hook(Condition, overrides=_overrides), # type: ignore
241-
)
242-
converter.register_unstructure_hook(
243-
Condition, partial(unstructure_base, overrides=_overrides)
244-
)
245-
246229
# Collect leftover original slotted classes processed by `attrs.define`
247230
gc.collect()

baybe/constraints/deprecation.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,7 @@
33
from __future__ import annotations
44

55
import warnings
6-
from collections.abc import Callable
7-
from typing import TYPE_CHECKING, Any
8-
9-
from cattrs.gen import make_dict_structure_fn
10-
11-
from baybe.serialization import converter
12-
from baybe.utils.basic import find_subclass, refers_to
13-
from baybe.utils.boolean import is_abstract
14-
15-
if TYPE_CHECKING:
16-
from baybe.constraints.base import Constraint
6+
from typing import Any
177

188

199
def ContinuousLinearEqualityConstraint(
@@ -62,36 +52,3 @@ def ContinuousLinearInequalityConstraint(
6252
kwargs["rhs"] = rhs
6353

6454
return ContinuousLinearConstraint(**kwargs)
65-
66-
67-
def structure_constraints(val: dict, cls: type) -> Constraint:
68-
"""A structure hook taking care of deprecations.""" # noqa: D401 (imperative mood)
69-
from baybe.constraints.base import Constraint
70-
71-
# If the given class can be instantiated, only ensure there is no conflict with
72-
# a potentially specified type field
73-
if not is_abstract(cls):
74-
if (type_ := val.pop("type", None)) and not refers_to(cls, type_):
75-
raise ValueError(
76-
f"The class '{cls.__name__}' specified for deserialization "
77-
f"does not match with the given type information '{type_}'."
78-
)
79-
concrete_cls = cls
80-
81-
# Otherwise, extract the type information from the given input and find
82-
# the corresponding class in the hierarchy
83-
else:
84-
type_ = val if isinstance(val, str) else val.pop("type")
85-
86-
if type_ == "ContinuousLinearEqualityConstraint":
87-
return ContinuousLinearEqualityConstraint(**val)
88-
elif type_ == "ContinuousLinearInequalityConstraint":
89-
return ContinuousLinearInequalityConstraint(**val)
90-
91-
concrete_cls = find_subclass(Constraint, type_)
92-
93-
# Create the structuring function for the class and call it
94-
fn: Callable = make_dict_structure_fn(
95-
concrete_cls, converter, _cattrs_forbid_extra_keys=True
96-
)
97-
return fn({} if isinstance(val, str) else val, concrete_cls)

baybe/kernels/base.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@
1010

1111
from baybe.exceptions import UnmatchedAttributeError
1212
from baybe.priors.base import Prior
13-
from baybe.serialization.core import (
14-
converter,
15-
get_base_structure_hook,
16-
unstructure_base,
17-
)
1813
from baybe.serialization.mixin import SerialMixin
1914
from baybe.utils.basic import get_baseclasses, match_attributes
2015

@@ -125,9 +120,5 @@ class CompositeKernel(Kernel, ABC):
125120
"""Abstract base class for all composite kernels."""
126121

127122

128-
# Register (un-)structure hooks
129-
converter.register_structure_hook(Kernel, get_base_structure_hook(Kernel))
130-
converter.register_unstructure_hook(Kernel, unstructure_base)
131-
132123
# Collect leftover original slotted classes processed by `attrs.define`
133124
gc.collect()

baybe/objectives/base.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,9 @@
44
from abc import ABC, abstractmethod
55
from typing import ClassVar
66

7-
import cattrs
87
import pandas as pd
98
from attrs import define, field
109

11-
from baybe.serialization.core import (
12-
converter,
13-
get_base_structure_hook,
14-
unstructure_base,
15-
)
1610
from baybe.serialization.mixin import SerialMixin
1711
from baybe.targets.base import Target
1812
from baybe.utils.metadata import Metadata, to_metadata
@@ -82,26 +76,5 @@ def to_objective(x: Target | Objective, /) -> Objective:
8276
return x if isinstance(x, Objective) else x.to_objective()
8377

8478

85-
# Register (un-)structure hooks
86-
converter.register_structure_hook(
87-
Objective,
88-
get_base_structure_hook(
89-
Objective,
90-
overrides=dict(
91-
_target=cattrs.override(rename="target"),
92-
_targets=cattrs.override(rename="targets"),
93-
),
94-
),
95-
)
96-
converter.register_unstructure_hook(
97-
Objective,
98-
lambda x: unstructure_base(
99-
x,
100-
overrides=dict(
101-
_target=cattrs.override(rename="target"),
102-
_targets=cattrs.override(rename="targets"),
103-
),
104-
),
105-
)
10679
# Collect leftover original slotted classes processed by `attrs.define`
10780
gc.collect()

baybe/parameters/base.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
import gc
66
from abc import ABC, abstractmethod
7-
from functools import cached_property, partial
7+
from functools import cached_property
88
from typing import TYPE_CHECKING, Any, ClassVar
99

10-
import cattrs
1110
import pandas as pd
1211
from attrs import define, field
1312
from attrs.converters import optional as optional_c
@@ -17,9 +16,6 @@
1716
from baybe.parameters.enum import ParameterEncoding
1817
from baybe.serialization import (
1918
SerialMixin,
20-
converter,
21-
get_base_structure_hook,
22-
unstructure_base,
2319
)
2420
from baybe.utils.basic import to_tuple
2521
from baybe.utils.metadata import MeasurableMetadata, to_metadata
@@ -255,19 +251,5 @@ def to_subspace(self) -> SubspaceContinuous:
255251
return SubspaceContinuous.from_parameter(self)
256252

257253

258-
# Register (un-)structure hooks
259-
_overrides = {
260-
"_values": cattrs.override(rename="values"),
261-
"_active_values": cattrs.override(rename="active_values"),
262-
}
263-
# FIXME[typing]: https://github.com/python/mypy/issues/4717
264-
converter.register_structure_hook(
265-
Parameter,
266-
get_base_structure_hook(Parameter, overrides=_overrides), # type: ignore
267-
)
268-
converter.register_unstructure_hook(
269-
Parameter, partial(unstructure_base, overrides=_overrides)
270-
)
271-
272254
# Collect leftover original slotted classes processed by `attrs.define`
273255
gc.collect()

baybe/priors/base.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55

66
from attrs import define
77

8-
from baybe.serialization.core import (
9-
converter,
10-
get_base_structure_hook,
11-
unstructure_base,
12-
)
138
from baybe.serialization.mixin import SerialMixin
149
from baybe.utils.basic import match_attributes
1510

@@ -39,10 +34,5 @@ def to_gpytorch(self, *args, **kwargs):
3934
return prior_cls(*args, **kwargs)
4035

4136

42-
# Register (un-)structure hooks
43-
converter.register_structure_hook(Prior, get_base_structure_hook(Prior))
44-
converter.register_unstructure_hook(Prior, unstructure_base)
45-
46-
4737
# Collect leftover original slotted classes processed by `attrs.define`
4838
gc.collect()

0 commit comments

Comments
 (0)