Skip to content

Commit 1e18b9a

Browse files
MilesCranmerBot, pre-commit-ci[bot], MilesCranmer
authored
feat: raise friendly error when loss functions have bad signatures (#1138)
* fix: validate loss_function arity and raise friendly error
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* test: cover loss_function arity validation branches
* test: add docstrings for #982 validation tests
* refactor: extract loss_function validation helper
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor: move loss_function validator near class
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor: place loss_function validator near other checks
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* refactor: avoid custom closures
* fix: improve loss function validation and errors
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix: make loss validators beartype-safe
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix: DRY loss validators and revert docstring edits (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix: remove blanket exception in jl.isnothing helper (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* fix: validate elementwise_loss arity based on weights (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* test: cover loss validation branches (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* fix: apply review suggestions for loss validation (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* fix: remove _jl_is_nothing helper (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* fix: clarify loss validator arity + error text (Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>)
* test: cover signature error branch in loss validators
* test: avoid julia function name collisions in loss validation
* refactor: hardcode elementwise_loss suggestion
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent 3bffc62 commit 1e18b9a

2 files changed

Lines changed: 282 additions & 0 deletions

File tree

pysr/sr.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,84 @@ def _check_assertions(
235235
)
236236

237237

238+
def _validate_elementwise_loss(custom_loss, *, has_weights: bool) -> None:
    """Check the positional arity of a Julia `elementwise_loss` function.

    The check asks Julia whether `custom_loss` is applicable to two float
    arguments, or to three float arguments when `has_weights` is true.

    Parameters
    ----------
    custom_loss
        The evaluated Julia object for `elementwise_loss`. It may be a
        LossFunctions.jl object (e.g. `L2DistLoss()`) or a Julia function.
    has_weights : bool
        Whether `weights` was passed to `fit`; selects the 3-argument check.

    Raises
    ------
    ValueError
        If `custom_loss` is a Julia function that is not applicable to the
        expected number of float arguments.
    """
    # Arity is only checked for actual Julia functions; anything else
    # (such as a LossFunctions.jl object) is passed through untouched.
    if not jl_is_function(custom_loss):
        return

    if has_weights:
        sample_args = (1.0, 1.0, 1.0)
        complaint = "`elementwise_loss` must accept (prediction, target, weight) when `weights` is passed to `fit`."
    else:
        sample_args = (1.0, 1.0)
        complaint = (
            "`elementwise_loss` must accept (prediction, target). If you intended a full objective, use "
            "`loss_function` or `loss_function_expression`."
        )

    if not bool(jl.applicable(custom_loss, *sample_args)):
        raise ValueError(complaint)
263+
264+
265+
def _validate_custom_objective(
    custom_objective,
    *,
    knob,
    signature,
    other_alternative=None,
) -> None:
    """Validate that a Julia full-objective function can take three arguments.

    Only the number of positional arguments of each method is inspected;
    argument types are not checked.

    Parameters
    ----------
    custom_objective
        The evaluated Julia object to validate.
    knob
        Name of the user-facing option being validated; used in error text.
    signature
        Human-readable expected signature; used in error text.
    other_alternative
        Optional name of another full-objective option to suggest when the
        function looks like a 2-argument (elementwise-style) loss.

    Raises
    ------
    ValueError
        If `custom_objective` is not a Julia function, or if none of its
        methods can be called with three positional arguments.
    """
    if not jl_is_function(custom_objective):
        raise ValueError(f"`{knob}` must evaluate to a callable Julia function.")

    methods = jl.collect(jl.methods(custom_objective))

    def _accepts_npos(m, npos: int) -> bool:
        # `nargs` counts the function itself as the first slot.
        declared_npos = int(m.nargs) - 1
        if bool(m.isva):
            # The trailing vararg slot is included in `nargs` but may bind
            # zero arguments, so it must not count as a required argument.
            # E.g. `f(tree, dataset, options, extra...)` is callable with 3.
            # NOTE(review): a fixed-length `Vararg{T,N}` is still treated as
            # open-ended here; this stays a coarse arity heuristic.
            return declared_npos - 1 <= npos
        return declared_npos == npos

    if any(_accepts_npos(m, 3) for m in methods):
        return

    # No 3-argument method: if a 2-argument method exists the user most
    # likely supplied an elementwise loss, so point them to the right option.
    if any(_accepts_npos(m, 2) for m in methods):
        msg = (
            f"`{knob}` must have signature like {signature}. "
            "If you intended an elementwise loss, use `elementwise_loss`."
        )
        if other_alternative is not None:
            msg += f" If you intended the other full-objective mode, use `{other_alternative}`."
        raise ValueError(msg)

    raise ValueError(f"`{knob}` must have signature {signature}.")
297+
298+
299+
def _validate_custom_full_objective(custom_full_objective) -> None:
    """Validate the evaluated `loss_function` objective.

    Delegates to `_validate_custom_objective`, expecting a signature of
    `(tree, dataset, options)` and suggesting `loss_function_expression`
    as the alternative full-objective option.
    """
    knob_name = "loss_function"
    expected_signature = "(tree, dataset, options)"
    _validate_custom_objective(
        custom_full_objective,
        knob=knob_name,
        signature=expected_signature,
        other_alternative="loss_function_expression",
    )
306+
307+
308+
def _validate_custom_expression_objective(custom_loss_expression) -> None:
    """Validate the evaluated `loss_function_expression` objective.

    Delegates to `_validate_custom_objective`, expecting a signature of
    `(expression, dataset, options)`. No alternative option is suggested.
    """
    knob_name = "loss_function_expression"
    expected_signature = "(expression, dataset, options)"
    _validate_custom_objective(
        custom_loss_expression,
        knob=knob_name,
        signature=expected_signature,
    )
314+
315+
238316
def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings):
239317
# It is expected extra_jax/torch_mappings will be updated after fit.
240318
# Thus, validation is performed here instead of in _validate_init_params
@@ -2036,14 +2114,22 @@ def _run(
20362114
if self.elementwise_loss is not None
20372115
else "nothing"
20382116
)
2117+
if self.elementwise_loss is not None:
2118+
_validate_elementwise_loss(custom_loss, has_weights=weights is not None)
2119+
20392120
custom_full_objective = jl.seval(
20402121
str(self.loss_function) if self.loss_function is not None else "nothing"
20412122
)
2123+
if self.loss_function is not None:
2124+
_validate_custom_full_objective(custom_full_objective)
2125+
20422126
custom_loss_expression = jl.seval(
20432127
str(self.loss_function_expression)
20442128
if self.loss_function_expression is not None
20452129
else "nothing"
20462130
)
2131+
if self.loss_function_expression is not None:
2132+
_validate_custom_expression_objective(custom_loss_expression)
20472133

20482134
early_stop_condition = jl.seval(
20492135
str(self.early_stop_condition)

pysr/test/test_main.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
_check_assertions,
4242
_process_constraints,
4343
_suggest_keywords,
44+
_validate_elementwise_loss,
45+
_validate_export_mappings,
4446
idx_model_selection,
4547
)
4648

@@ -183,6 +185,200 @@ def test_high_precision_search_custom_loss(self):
183185
jl.seval("((::Val{x}) where x) -> x")(model.julia_options_.turbo), False
184186
)
185187

188+
def test_loss_function_with_elementwise_signature_errors_early(self):
    """Issue #982: a (prediction, target) function given as `loss_function` is rejected.

    `fit` must raise ValueError and the message must mention `elementwise_loss`.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function="bad_full_objective(prediction, target) = (prediction - target)^2",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("elementwise_loss", message)
205+
206+
def test_loss_function_noncallable_errors_early(self):
    """Issue #982: a `loss_function` that evaluates to a non-function ('1.0') is rejected.

    `fit` must raise ValueError and the message must contain "callable".
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function="1.0",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("callable", message)
223+
224+
def test_loss_function_valid_full_objective_runs(self):
    """Issue #982: a (tree, dataset, options) objective passes validation and `fit` completes."""
    julia_objective = """
begin
goodloss(tree, dataset, options) = zero(eltype(dataset.y))
goodloss
end
"""
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function=julia_objective,
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    model.fit(X, y)
244+
245+
def test_loss_function_varargs_objective_runs(self):
    """A `loss_function` whose last parameter is a vararg (`options...`) is accepted by `fit`."""
    julia_objective = """
begin
goodvarloss(tree, dataset, options...) = zero(eltype(dataset.y))
goodvarloss
end
"""
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function=julia_objective,
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    model.fit(X, y)
264+
265+
def test_elementwise_loss_wrong_signature_errors_early(self):
    """A one-argument `elementwise_loss` is rejected.

    `fit` must raise ValueError and the message must mention `elementwise_loss`.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        elementwise_loss="myloss_bad_arity(a) = a",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("elementwise_loss", message)
282+
283+
def test_elementwise_loss_with_weights_requires_three_args(self):
    """A two-argument `elementwise_loss` is rejected when `weights` is passed to `fit`.

    The ValueError message must mention both `elementwise_loss` and `weights`.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        elementwise_loss="myloss2(prediction, target) = (prediction - target)^2",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    weights = np.array([1.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y, weights=weights)
    message = str(caught.exception)
    self.assertIn("elementwise_loss", message)
    self.assertIn("weights", message)
301+
302+
def test_elementwise_loss_with_weights_accepts_three_args(self):
    """A three-argument `elementwise_loss` is accepted when `weights` is passed to `fit`."""
    weighted_loss = "myloss3(prediction, target, weights) = weights * (prediction - target)^2"
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        elementwise_loss=weighted_loss,
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    weights = np.array([1.0, 1.0])
    model.fit(X, y, weights=weights)
319+
320+
def test_validation_helpers_skip_nonfunction(self):
    """`_validate_elementwise_loss` returns without raising for a non-function Julia value."""
    julia_number = jl.seval("1.0")
    _validate_elementwise_loss(julia_number, has_weights=False)
322+
323+
def test_validate_export_mappings_typechecks(self):
    """`_validate_export_mappings` raises ValueError for the jax mapping `{"a": 1}`."""
    bad_jax_mappings = {"a": 1}
    with self.assertRaises(ValueError):
        _validate_export_mappings(bad_jax_mappings, None)
326+
327+
def test_loss_function_expression_elementwise_signature_errors_early(self):
    """A (prediction, target) function given as `loss_function_expression` is rejected.

    The ValueError message must name `loss_function_expression` and suggest
    `elementwise_loss`.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function_expression="bad_expr_objective(prediction, target) = (prediction - target)^2",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("loss_function_expression", message)
    self.assertIn("elementwise_loss", message)
345+
346+
def test_loss_function_wrong_signature_errors_early(self):
    """A one-argument `loss_function` is rejected.

    The ValueError message must name `loss_function` and show the expected
    `(tree, dataset, options)` signature.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function="badloss(tree) = 0.0",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("loss_function", message)
    self.assertIn("(tree, dataset, options)", message)
363+
364+
def test_loss_function_expression_wrong_signature_errors_early(self):
    """A one-argument `loss_function_expression` is rejected.

    The ValueError message must name `loss_function_expression` and show the
    expected `(expression, dataset, options)` signature.
    """
    settings = dict(
        niterations=1,
        populations=1,
        procs=0,
        progress=False,
        verbosity=0,
        temp_equation_file=True,
        binary_operators=["+"],
        loss_function_expression="badexprloss(expression) = 0.0",
    )
    model = PySRRegressor(**settings)
    X = np.array([[0.0], [1.0]])
    y = np.array([0.0, 1.0])
    with self.assertRaises(ValueError) as caught:
        model.fit(X, y)
    message = str(caught.exception)
    self.assertIn("loss_function_expression", message)
    self.assertIn("(expression, dataset, options)", message)
381+
186382
def test_operator_conflict_error(self):
187383
regressor = PySRRegressor(
188384
operators={1: ["sin"]},

0 commit comments

Comments
 (0)