Skip to content
Merged
30 changes: 21 additions & 9 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,26 +235,34 @@ def _check_assertions(
)


def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
"""Validate that a Julia `elementwise_loss` is callable.
def _validate_elementwise_loss(
custom_loss, *, has_weights: bool, probe_dtype: Callable[[float], Any] = float
Comment thread
MilesCranmerBot marked this conversation as resolved.
Outdated
) -> None:
"""Check whether a Julia `elementwise_loss` accepts the expected inputs.

We require exactly 2 args unless the user passed `weights=` to fit,
in which case we require 3 args.
The function probes the loss with two or three arguments, depending on
whether weights are present, using the dtype that fitting will use. If the
probe fails, it raises a `ValueError` describing the expected signature.
"""

# This can be either a LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
# Only validate arity when the evaluated object is actually a function.
if not jl_is_function(custom_loss):
return

probe_value = probe_dtype(1.0)
probe_args = (
(probe_value, probe_value, probe_value)
if has_weights
else (probe_value, probe_value)
)
ok = bool(jl.applicable(custom_loss, *probe_args))
if has_weights:
ok = bool(jl.applicable(custom_loss, 1.0, 1.0, 1.0))
if not ok:
raise ValueError(
"`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
)
else:
ok = bool(jl.applicable(custom_loss, 1.0, 1.0))
if not ok:
raise ValueError(
"`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
Expand Down Expand Up @@ -2109,13 +2117,19 @@ def _run(
if isinstance(complexity_of_variables, list):
complexity_of_variables = jl_array(complexity_of_variables)

np_dtype = self._get_precision_mapped_dtype(np.array(X))

custom_loss = jl.seval(
str(self.elementwise_loss)
if self.elementwise_loss is not None
else "nothing"
)
if self.elementwise_loss is not None:
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
_validate_elementwise_loss(
custom_loss,
has_weights=weights is not None,
probe_dtype=np_dtype,
)

custom_full_objective = jl.seval(
str(self.loss_function) if self.loss_function is not None else "nothing"
Expand Down Expand Up @@ -2304,8 +2318,6 @@ def _run(
self.julia_options_stream_ = jl_serialize(options)

# Convert data to desired precision
test_X = np.array(X)
np_dtype = self._get_precision_mapped_dtype(test_X)

# This converts the data into a Julia array:
jl_X = jl_array(np.array(X, dtype=np_dtype).T)
Expand Down
28 changes: 28 additions & 0 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,34 @@ def test_elementwise_loss_with_weights_accepts_three_args(self):
weights = np.array([1.0, 1.0])
model.fit(X, y, weights=weights)

def test_elementwise_loss_float32_probe_accepts_strictly_typed_loss(self):
    """Validate a Julia loss whose two arguments are annotated `Float32`.

    The loss is probed without weights and with `np.float32` as the probe
    dtype; the test passes as long as the validator does not raise.
    """
    strict_float32_loss = jl.seval(
        "(prediction::Float32, target::Float32) -> (prediction - target)^2"
    )
    _validate_elementwise_loss(
        strict_float32_loss, has_weights=False, probe_dtype=np.float32
    )

def test_elementwise_loss_float32_fit_accepts_strictly_typed_loss(self):
    """Fit with `precision=32` using a loss annotated for `Float32` only.

    A two-sample float32 dataset is fitted for a single iteration; the test
    passes as long as `fit` completes without raising.
    """
    features = np.array([[0.0], [1.0]], dtype=np.float32)
    targets = np.array([0.0, 1.0], dtype=np.float32)

    # Julia source for a loss that only has a method for Float32 arguments.
    strict_float32_loss = (
        "(prediction::Float32, target::Float32) -> (prediction - target)^2"
    )

    regressor = PySRRegressor(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        precision=32,
        temp_equation_file=True,
        binary_operators=["+"],
        elementwise_loss=strict_float32_loss,
    )
    regressor.fit(features, targets)

def test_validation_helpers_skip_nonfunction(self):
_validate_elementwise_loss(jl.seval("1.0"), has_weights=False)

Expand Down
Loading