Skip to content

Commit 067b1b2

Browse files
fix: clarify TemplateExpressionSpec reload errors
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent beaa405 commit 067b1b2

2 files changed

Lines changed: 42 additions & 0 deletions

File tree

pysr/sr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
AbstractExpressionSpec,
4040
ExpressionSpec,
4141
ParametricExpressionSpec,
42+
TemplateExpressionSpec,
4243
parametric_expression_deprecation_warning,
4344
)
4445
from .feature_selection import run_feature_selection
@@ -1465,8 +1466,23 @@ def __getstate__(self) -> dict[str, Any]:
14651466
pickled_state["equations_"], pickled_columns
14661467
)
14671468
]
1469+
pickled_state["_loaded_from_pickle_"] = False
14681470
return pickled_state
14691471

1472+
def __setstate__(self, state: dict[str, Any]) -> None:
1473+
self.__dict__.update(state)
1474+
self._loaded_from_pickle_ = True
1475+
1476+
def _raise_for_reloaded_template_expression_spec(self) -> None:
1477+
if getattr(self, "_loaded_from_pickle_", False) and isinstance(
1478+
self.expression_spec, TemplateExpressionSpec
1479+
):
1480+
raise NotImplementedError(
1481+
"Loading models fitted with TemplateExpressionSpec is not yet supported. "
1482+
"Please refit the model in the current Python/Julia session before "
1483+
"calling predict or export methods."
1484+
)
1485+
14701486
def _checkpoint(self):
14711487
"""Save the model's current state to a checkpoint file.
14721488
@@ -2640,6 +2656,7 @@ def predict(
26402656
check_is_fitted(
26412657
self, attributes=["selection_mask_", "feature_names_in_", "nout_"]
26422658
)
2659+
self._raise_for_reloaded_template_expression_spec()
26432660
best_equation = self.get_best(index=index)
26442661

26452662
# When X is an numpy array or a pandas dataframe with a RangeIndex,
@@ -2730,6 +2747,7 @@ def sympy(self, index: int | list[int] | None = None):
27302747
best_equation : str, list[str] of length nout_
27312748
SymPy representation of the best equation.
27322749
"""
2750+
self._raise_for_reloaded_template_expression_spec()
27332751
if not self.expression_spec_.supports_sympy:
27342752
raise ValueError(
27352753
f"`expression_spec={self.expression_spec_}` does not support sympy export."
@@ -2766,6 +2784,7 @@ def latex(
27662784
best_equation : str or list[str] of length nout_
27672785
LaTeX expression of the best equation.
27682786
"""
2787+
self._raise_for_reloaded_template_expression_spec()
27692788
if not self.expression_spec_.supports_latex:
27702789
raise ValueError(
27712790
f"`expression_spec={self.expression_spec_}` does not support latex export."
@@ -2803,6 +2822,7 @@ def jax(self, index=None):
28032822
Dictionary of callable jax function in "callable" key,
28042823
and jax array of parameters as "parameters" key.
28052824
"""
2825+
self._raise_for_reloaded_template_expression_spec()
28062826
if not self.expression_spec_.supports_jax:
28072827
raise ValueError(
28082828
f"`expression_spec={self.expression_spec_}` does not support jax export."
@@ -2839,6 +2859,7 @@ def pytorch(self, index=None):
28392859
best_equation : torch.nn.Module
28402860
PyTorch module representing the expression.
28412861
"""
2862+
self._raise_for_reloaded_template_expression_spec()
28422863
if not self.expression_spec_.supports_torch:
28432864
raise ValueError(
28442865
f"`expression_spec={self.expression_spec_}` does not support torch export."
@@ -2983,6 +3004,7 @@ def latex_table(
29833004
latex_table_str : str
29843005
A string that will render a table in LaTeX of the equations.
29853006
"""
3007+
self._raise_for_reloaded_template_expression_spec()
29863008
if not self.expression_spec_.supports_latex:
29873009
raise ValueError(
29883010
f"`expression_spec={self.expression_spec_}` does not support latex export."

pysr/test/test_main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,26 @@ def test_process_constraints_swaps_multiplication_constraints(self):
22462246

22472247

22482248
class TestTemplateExpressionSpec(unittest.TestCase):
2249+
def test_reloaded_template_expression_spec_raises_targeted_error(self):
2250+
model = PySRRegressor(
2251+
expression_spec=TemplateExpressionSpec(
2252+
combine="f(x)", expressions=["f"], variable_names=["x"]
2253+
)
2254+
)
2255+
model.selection_mask_ = np.array([True])
2256+
model.feature_names_in_ = np.array(["x"])
2257+
model.nout_ = 1
2258+
2259+
loaded = pkl.loads(pkl.dumps(model))
2260+
2261+
message = (
2262+
"Loading models fitted with TemplateExpressionSpec is not yet supported"
2263+
)
2264+
with self.assertRaisesRegex(NotImplementedError, message):
2265+
loaded.predict(np.ones((1, 1)))
2266+
with self.assertRaisesRegex(NotImplementedError, message):
2267+
loaded.sympy()
2268+
22492269
def _check_macro_str(self, spec, expected_str):
22502270
self.assertEqual(
22512271
spec._template_macro_str().strip(), dedent(expected_str).strip()

0 commit comments

Comments
 (0)