-
Notifications
You must be signed in to change notification settings - Fork 79
Add GaussianProcessSurrogate.posterior_mean property #823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e65ca98
240d81d
755a80b
cff91db
7098fe3
3f19c10
688bd1a
45b82b6
37bb6a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| from attrs.validators import instance_of, is_callable | ||
| from typing_extensions import Self, override | ||
|
|
||
| from baybe.exceptions import DeprecationError | ||
| from baybe.exceptions import DeprecationError, ModelNotTrainedError | ||
| from baybe.kernels.base import Kernel | ||
| from baybe.objectives.base import Objective | ||
| from baybe.parameters.base import Parameter | ||
|
|
@@ -48,11 +48,12 @@ | |
| ) | ||
| from baybe.utils.boolean import strtobool | ||
| from baybe.utils.conversion import to_string | ||
| from baybe.utils.dataframe import to_tensor | ||
|
|
||
| if TYPE_CHECKING: | ||
| from botorch.models.gpytorch import GPyTorchModel | ||
| from botorch.models.transforms.input import InputTransform | ||
| from botorch.models.transforms.outcome import OutcomeTransform | ||
| from botorch.models.transforms.input import InputTransform, Normalize | ||
| from botorch.models.transforms.outcome import OutcomeTransform, Standardize | ||
| from botorch.posteriors import Posterior | ||
| from gpytorch.kernels import Kernel as GPyTorchKernel | ||
| from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood | ||
|
|
@@ -211,6 +212,31 @@ class GaussianProcessSurrogate(Surrogate): | |
| _model = field(init=False, default=None, eq=False) | ||
| """The actual model.""" | ||
|
|
||
| @staticmethod | ||
| def _make_input_transform(context: _ModelContext) -> Normalize: | ||
| """Create the input transform for the Gaussian process.""" | ||
| from botorch.models.transforms.input import Normalize | ||
|
|
||
| return Normalize( | ||
| len(context.searchspace.comp_rep_columns), | ||
| bounds=context.parameter_bounds, | ||
| indices=context.numerical_indices, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _make_outcome_transform(context: _ModelContext) -> Standardize: | ||
| """Create the outcome transform for the Gaussian process.""" | ||
| from botorch.models.transforms.outcome import Standardize | ||
|
|
||
| train_y = to_tensor( | ||
| context.objective._pre_transform(context.measurements, allow_extra=True) | ||
| ) | ||
| if train_y.ndim == 1: | ||
| train_y = train_y.unsqueeze(-1) | ||
| transform = Standardize(m=train_y.shape[-1]) | ||
| transform(train_y) # fits means/stdvs; GP will re-fit in train mode | ||
| return transform | ||
|
|
||
| @classmethod | ||
| def from_preset( | ||
| cls, | ||
|
|
@@ -246,6 +272,82 @@ def from_preset( | |
| gp._custom_kernel = False # preset are first-party features | ||
| return gp | ||
|
|
||
| def posterior_mean_function( | ||
| self, | ||
| searchspace: SearchSpace, | ||
| objective: Objective, | ||
| measurements: pd.DataFrame, | ||
| ) -> GPyTorchMean: | ||
| """Return a GPyTorch mean module representing the surrogate's posterior mean. | ||
|
|
||
| The bound method satisfies | ||
| :class:`~baybe.surrogates.gaussian_process.components.mean.MeanFactoryProtocol` | ||
| and can be passed directly to a new :class:`GaussianProcessSurrogate`. | ||
|
|
||
| Args: | ||
| searchspace: The search space of the new GP being fitted. | ||
| objective: The objective of the new GP being fitted. | ||
| measurements: The training data of the new GP being fitted. | ||
|
|
||
| Returns: | ||
| The posterior mean. | ||
|
|
||
| Raises: | ||
| ModelNotTrainedError: If the surrogate has not been fitted yet. | ||
| """ | ||
| from copy import deepcopy | ||
|
|
||
| import gpytorch | ||
|
|
||
| if self._model is None: | ||
| raise ModelNotTrainedError( | ||
| f"'{self.__class__.__name__}' must be fitted before its " | ||
| f"'{self.posterior_mean_function.__name__}' can be used as a " | ||
| f"mean function." | ||
| ) | ||
|
|
||
| context = _ModelContext(searchspace, objective, measurements) | ||
|
|
||
| # Undo the new GP's input normalization before querying the prior GP | ||
| input_transform = self._make_input_transform(context) | ||
| input_transform.eval() | ||
|
|
||
| # Match the new GP's outcome standardization | ||
| outcome_transform = self._make_outcome_transform(context) | ||
| outcome_transform.eval() | ||
|
|
||
| class _PosteriorMean(gpytorch.means.Mean): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure if relevant, but Claude raised some concerns with this design here that I at least want to comment on and open for discussion:
|
||
| """GPyTorch mean wrapping a trained GP's posterior. | ||
|
|
||
| Overrides ``train`` to keep all children in eval mode, preventing optimizers | ||
| from corrupting learned transform parameters. | ||
| """ | ||
|
|
||
| def __init__(self, gp: GPyTorchModel) -> None: | ||
| super().__init__() | ||
| self.gp = deepcopy(gp) | ||
| for param in self.gp.parameters(): | ||
| param.requires_grad = False | ||
| self.gp.eval() | ||
| self.gp.likelihood.eval() | ||
|
|
||
| @override | ||
| def train(self, mode: bool = True) -> _PosteriorMean: | ||
| """Set training mode without propagating to children.""" | ||
| self.training = mode | ||
| return self | ||
|
|
||
| @override | ||
| def forward(self, x: Tensor) -> Tensor: | ||
| """Compute the mean using the wrapped GP's posterior.""" | ||
| with gpytorch.settings.fast_pred_var(): | ||
| x_raw = input_transform.untransform(x) | ||
| posterior_mean = self.gp.posterior(x_raw).mean | ||
| standardized, _ = outcome_transform(posterior_mean) | ||
| return standardized.squeeze(-1) | ||
|
|
||
| return _PosteriorMean(self._model) | ||
|
|
||
| @override | ||
| def to_botorch(self) -> GPyTorchModel: | ||
| return self._model | ||
|
|
@@ -271,7 +373,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior: | |
| @override | ||
| def _fit(self, train_x: Tensor, train_y: Tensor) -> None: | ||
| import botorch | ||
| from botorch.models.transforms import Normalize, Standardize | ||
|
|
||
| assert self._searchspace is not None # provided by base class | ||
| assert self._objective is not None # provided by base class | ||
|
|
@@ -298,12 +399,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: | |
|
|
||
| ### Input/output scaling | ||
| # NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above) | ||
| input_transform = Normalize( | ||
| train_x.shape[-1], | ||
| bounds=context.parameter_bounds, | ||
| indices=context.numerical_indices, | ||
| ) | ||
| outcome_transform = Standardize(train_y.shape[-1]) | ||
| input_transform = self._make_input_transform(context) | ||
| outcome_transform = self._make_outcome_transform(context) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given the change here, is the outcome transform now coupled to the |
||
|
|
||
| ### Mean | ||
| mean = self.mean_factory( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -218,3 +218,89 @@ def test_botorch_preset(multitask: bool, preset: str): | |
| posterior2 = _posterior_stats_botorch(sp, data) | ||
|
|
||
| assert_frame_equal(posterior1, posterior2) | ||
|
|
||
|
|
||
| def test_get_posterior_mean_correct_under_different_bounds(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Couldn't we combine both tests and parametrize them? |
||
| """Posterior mean evaluates at correct physical points when bounds differ.""" | ||
| from baybe.parameters.numerical import NumericalDiscreteParameter | ||
|
|
||
| # Train a surrogate on a narrow search space [0, 5] | ||
| prior_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])] | ||
| prior_ss = SearchSpace.from_product(prior_params) | ||
| prior_obj = NumericalTarget(name="y").to_objective() | ||
|
|
||
| prior_surrogate = GaussianProcessSurrogate() | ||
| prior_meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]}) | ||
| prior_surrogate.fit(prior_ss, prior_obj, prior_meas) | ||
|
|
||
| # Get the surrogate's prediction at x1=2.5 | ||
| expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() | ||
|
|
||
| # New GP on a WIDER search space [0, 10], using the get_posterior_mean method | ||
| new_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0, 7.5, 10.0])] | ||
| new_ss = SearchSpace.from_product(new_params) | ||
|
|
||
| new_surrogate = GaussianProcessSurrogate( | ||
| mean_or_factory=prior_surrogate.posterior_mean_function | ||
| ) | ||
| # Train on data that lies exactly on the prior mean to avoid kernel effects | ||
| training_points = pd.DataFrame({"x1": [0.0, 10.0]}) | ||
| with torch.no_grad(): | ||
| training_targets = prior_surrogate.posterior(training_points).mean | ||
| new_meas = pd.DataFrame( | ||
| { | ||
| "x1": training_points["x1"], | ||
| "y": training_targets.numpy().ravel(), | ||
| } | ||
| ) | ||
| new_surrogate.fit(new_ss, prior_obj, new_meas) | ||
|
|
||
| # Test end-to-end: the posterior should match the prior mean | ||
| actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() | ||
|
|
||
| assert abs(actual_mean - expected_mean) < 1e-4 | ||
|
|
||
|
|
||
| def test_get_posterior_mean_same_bounds(): | ||
| """Posterior mean is correct when both search spaces have the same bounds.""" | ||
| from baybe.parameters.numerical import NumericalDiscreteParameter | ||
|
|
||
| params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])] | ||
| ss = SearchSpace.from_product(params) | ||
| obj = NumericalTarget(name="y").to_objective() | ||
|
|
||
| prior_surrogate = GaussianProcessSurrogate() | ||
| meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]}) | ||
| prior_surrogate.fit(ss, obj, meas) | ||
|
|
||
| expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() | ||
|
|
||
| new_surrogate = GaussianProcessSurrogate( | ||
| mean_or_factory=prior_surrogate.posterior_mean_function | ||
| ) | ||
| # Train on data that lies exactly on the prior mean | ||
| training_points = pd.DataFrame({"x1": [0.0, 5.0]}) | ||
| with torch.no_grad(): | ||
| training_targets = prior_surrogate.posterior(training_points).mean | ||
| new_meas = pd.DataFrame( | ||
| { | ||
| "x1": training_points["x1"], | ||
| "y": training_targets.numpy().ravel(), | ||
| } | ||
| ) | ||
| new_surrogate.fit(ss, obj, new_meas) | ||
|
|
||
| # Test end-to-end: the posterior should match the prior mean | ||
| actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item() | ||
|
|
||
| assert abs(actual_mean - expected_mean) < 1e-4 | ||
|
|
||
|
|
||
| def test_get_posterior_mean_raises_if_not_fitted(): | ||
| """Calling get_posterior_mean raises if the surrogate has not been fitted.""" | ||
| from baybe.exceptions import ModelNotTrainedError | ||
|
|
||
| with pytest.raises(ModelNotTrainedError, match="must be fitted"): | ||
| GaussianProcessSurrogate().posterior_mean_function( | ||
| searchspace, objective, measurements | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused by the wording here. I think the "new" relates to how the created
GPyTorchMean is intended to be used, right? Because in principle, this function has nothing to do with a "new" GP, so maybe just have something like "for which the mean is defined" or similar? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is, however, intended to only ever be used for the construction of a new GP, then I think there should be an additional explanatory paragraph in the docstring.