Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `posterior_mean_function` method to `GaussianProcessSurrogate`

## [0.15.0] - 2026-06-11
### Breaking Changes
- `GaussianProcessSurrogate` no longer automatically adds a task kernel in multi-task
Expand Down
5 changes: 5 additions & 0 deletions baybe/surrogates/gaussian_process/components/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
from typing import TYPE_CHECKING, Any

import pandas as pd
Expand Down Expand Up @@ -40,3 +41,7 @@ def __call__(
from gpytorch.means import ConstantMean

return ConstantMean()


# Collect leftover original slotted classes created by attrs
gc.collect()
117 changes: 107 additions & 10 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from attrs.validators import instance_of, is_callable
from typing_extensions import Self, override

from baybe.exceptions import DeprecationError
from baybe.exceptions import DeprecationError, ModelNotTrainedError
from baybe.kernels.base import Kernel
from baybe.objectives.base import Objective
from baybe.parameters.base import Parameter
Expand Down Expand Up @@ -48,11 +48,12 @@
)
from baybe.utils.boolean import strtobool
from baybe.utils.conversion import to_string
from baybe.utils.dataframe import to_tensor

if TYPE_CHECKING:
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.posteriors import Posterior
from gpytorch.kernels import Kernel as GPyTorchKernel
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
Expand Down Expand Up @@ -211,6 +212,31 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
"""The actual model."""

@staticmethod
def _make_input_transform(context: _ModelContext) -> Normalize:
"""Create the input transform for the Gaussian process."""
from botorch.models.transforms.input import Normalize

return Normalize(
len(context.searchspace.comp_rep_columns),
bounds=context.parameter_bounds,
indices=context.numerical_indices,
)

@staticmethod
def _make_outcome_transform(context: _ModelContext) -> Standardize:
"""Create the outcome transform for the Gaussian process."""
from botorch.models.transforms.outcome import Standardize

train_y = to_tensor(
context.objective._pre_transform(context.measurements, allow_extra=True)
)
if train_y.ndim == 1:
train_y = train_y.unsqueeze(-1)
transform = Standardize(m=train_y.shape[-1])
transform(train_y) # fits means/stdvs; GP will re-fit in train mode
return transform

@classmethod
def from_preset(
cls,
Expand Down Expand Up @@ -246,6 +272,82 @@ def from_preset(
gp._custom_kernel = False # preset are first-party features
return gp

def posterior_mean_function(
self,
searchspace: SearchSpace,
objective: Objective,
measurements: pd.DataFrame,
) -> GPyTorchMean:
"""Return a GPyTorch mean module representing the surrogate's posterior mean.

The bound method satisfies
:class:`~baybe.surrogates.gaussian_process.components.mean.MeanFactoryProtocol`
and can be passed directly to a new :class:`GaussianProcessSurrogate`.

Args:
searchspace: The search space of the new GP being fitted.
objective: The objective of the new GP being fitted.
measurements: The training data of the new GP being fitted.
Comment on lines +288 to +290

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by the wording here. I think the "new" relates to how the created GPyTorchMean is intended to be used, right? Because in principle, this function has nothing to do with a "new" GP, so maybe just have something "for the which the mean is defined" or similar?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is however intended to only ever be used for the construction of a new GP, then I think there should be an additional explanatory paragraph in the docstring


Returns:
The posterior mean.

Raises:
ModelNotTrainedError: If the surrogate has not been fitted yet.
"""
from copy import deepcopy

import gpytorch

if self._model is None:
raise ModelNotTrainedError(
f"'{self.__class__.__name__}' must be fitted before its "
f"'{self.posterior_mean_function.__name__}' can be used as a "
f"mean function."
)

context = _ModelContext(searchspace, objective, measurements)

# Undo the new GP's input normalization before querying the prior GP
input_transform = self._make_input_transform(context)
input_transform.eval()

# Match the new GP's outcome standardization
outcome_transform = self._make_outcome_transform(context)
outcome_transform.eval()

class _PosteriorMean(gpytorch.means.Mean):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if relevant, but Claude raised some concerns with this design here that I at least want to comment on and open for discussion:

  1. Whenever there is a call of posterior_mean_function, we create a new object: I think this is not that much of an issue, but could we potentially cache it or is this not worth the additional trouble/code?
  2. It seems like due to the way that the internal class is handled, "input_transform and output_transform are invisible to PyTorch's module system". Consequences for this seem to be that they are not properly registered and are e.g. excluded from .state_dict() or .to(device) calls. Claude thus proposed to register those two functions as proper submodules and adjusting train accordingly. Tbh, I do not think that this is necessary - the only real issue I see with this is point is the one about the device when we move to GPUs, but we can also fix this then.
  3. The usage of gpytorch.settings.fast_pred_var() in forward has no effect since it "only affects variance computation via Lanczos estimates. Since we only access .mean, this context manager does nothing and is misleading." - not sure if true, not sure if relevant not sure if we need to care. I think maybe one test with and without the context manager and seeing whether or not it makes a difference would be sufficient.

"""GPyTorch mean wrapping a trained GP's posterior.

Overrides ``train`` to keep all children in eval mode, preventing optimizers
from corrupting learned transform parameters.
"""

def __init__(self, gp: GPyTorchModel) -> None:
super().__init__()
self.gp = deepcopy(gp)
for param in self.gp.parameters():
param.requires_grad = False
self.gp.eval()
self.gp.likelihood.eval()

@override
def train(self, mode: bool = True) -> _PosteriorMean:
"""Set training mode without propagating to children."""
self.training = mode
return self

@override
def forward(self, x: Tensor) -> Tensor:
"""Compute the mean using the wrapped GP's posterior."""
with gpytorch.settings.fast_pred_var():
x_raw = input_transform.untransform(x)
posterior_mean = self.gp.posterior(x_raw).mean
standardized, _ = outcome_transform(posterior_mean)
return standardized.squeeze(-1)

return _PosteriorMean(self._model)

@override
def to_botorch(self) -> GPyTorchModel:
return self._model
Expand All @@ -271,7 +373,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
import botorch
from botorch.models.transforms import Normalize, Standardize

assert self._searchspace is not None # provided by base class
assert self._objective is not None # provided by base class
Expand All @@ -298,12 +399,8 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:

### Input/output scaling
# NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above)
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=context.numerical_indices,
)
outcome_transform = Standardize(train_y.shape[-1])
input_transform = self._make_input_transform(context)
outcome_transform = self._make_outcome_transform(context)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the change here, does the outcome transform is now coupled to the context and does not use the provided train_y but instead creates its own train_y internally that it uses. This feels a bit weird, and error-prone, so I would advocate for not using those transforms here


### Mean
mean = self.mean_factory(
Expand Down
86 changes: 86 additions & 0 deletions tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,89 @@ def test_botorch_preset(multitask: bool, preset: str):
posterior2 = _posterior_stats_botorch(sp, data)

assert_frame_equal(posterior1, posterior2)


def test_get_posterior_mean_correct_under_different_bounds():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we combine both tests and parametrize them?

"""Posterior mean evaluates at correct physical points when bounds differ."""
from baybe.parameters.numerical import NumericalDiscreteParameter

# Train a surrogate on a narrow search space [0, 5]
prior_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])]
prior_ss = SearchSpace.from_product(prior_params)
prior_obj = NumericalTarget(name="y").to_objective()

prior_surrogate = GaussianProcessSurrogate()
prior_meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]})
prior_surrogate.fit(prior_ss, prior_obj, prior_meas)

# Get the surrogate's prediction at x1=2.5
expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

# New GP on a WIDER search space [0, 10], using the get_posterior_mean method
new_params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0, 7.5, 10.0])]
new_ss = SearchSpace.from_product(new_params)

new_surrogate = GaussianProcessSurrogate(
mean_or_factory=prior_surrogate.posterior_mean_function
)
# Train on data that lies exactly on the prior mean to avoid kernel effects
training_points = pd.DataFrame({"x1": [0.0, 10.0]})
with torch.no_grad():
training_targets = prior_surrogate.posterior(training_points).mean
new_meas = pd.DataFrame(
{
"x1": training_points["x1"],
"y": training_targets.numpy().ravel(),
}
)
new_surrogate.fit(new_ss, prior_obj, new_meas)

# Test end-to-end: the posterior should match the prior mean
actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

assert abs(actual_mean - expected_mean) < 1e-4


def test_get_posterior_mean_same_bounds():
"""Posterior mean is correct when both search spaces have the same bounds."""
from baybe.parameters.numerical import NumericalDiscreteParameter

params = [NumericalDiscreteParameter("x1", values=[0.0, 2.5, 5.0])]
ss = SearchSpace.from_product(params)
obj = NumericalTarget(name="y").to_objective()

prior_surrogate = GaussianProcessSurrogate()
meas = pd.DataFrame({"x1": [0.0, 2.5, 5.0], "y": [0.0, 5.0, 10.0]})
prior_surrogate.fit(ss, obj, meas)

expected_mean = prior_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

new_surrogate = GaussianProcessSurrogate(
mean_or_factory=prior_surrogate.posterior_mean_function
)
# Train on data that lies exactly on the prior mean
training_points = pd.DataFrame({"x1": [0.0, 5.0]})
with torch.no_grad():
training_targets = prior_surrogate.posterior(training_points).mean
new_meas = pd.DataFrame(
{
"x1": training_points["x1"],
"y": training_targets.numpy().ravel(),
}
)
new_surrogate.fit(ss, obj, new_meas)

# Test end-to-end: the posterior should match the prior mean
actual_mean = new_surrogate.posterior(pd.DataFrame({"x1": [2.5]})).mean.item()

assert abs(actual_mean - expected_mean) < 1e-4


def test_get_posterior_mean_raises_if_not_fitted():
"""Calling get_posterior_mean raises if the surrogate has not been fitted."""
from baybe.exceptions import ModelNotTrainedError

with pytest.raises(ModelNotTrainedError, match="must be fitted"):
GaussianProcessSurrogate().posterior_mean_function(
searchspace, objective, measurements
)
Loading