Skip to content

Commit 1c60984

Browse files
fix: raise clear error on reload of TemplateExpressionSpec models
Closes #846 Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent beaa405 commit 1c60984

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

pysr/sr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
AbstractExpressionSpec,
4040
ExpressionSpec,
4141
ParametricExpressionSpec,
42+
TemplateExpressionSpec,
4243
parametric_expression_deprecation_warning,
4344
)
4445
from .feature_selection import run_feature_selection
@@ -1467,6 +1468,17 @@ def __getstate__(self) -> dict[str, Any]:
14671468
]
14681469
return pickled_state
14691470

1471+
def __setstate__(self, state: dict[str, Any]) -> None:
1472+
# ponytail: raise immediately on reload instead of confusing SymPy error later
1473+
self.__dict__.update(state)
1474+
if "equations_" in state and state["equations_"] is not None and isinstance(
1475+
self.expression_spec, TemplateExpressionSpec
1476+
):
1477+
raise NotImplementedError(
1478+
"Reloading fitted TemplateExpressionSpec models is not yet supported. "
1479+
"Please refit the model in the current session."
1480+
)
1481+
14701482
def _checkpoint(self):
14711483
"""Save the model's current state to a checkpoint file.
14721484

pysr/test/test_main.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,20 @@ def test_process_constraints_swaps_multiplication_constraints(self):
22462246

22472247

22482248
class TestTemplateExpressionSpec(unittest.TestCase):
2249+
def test_reload_raises_clear_error(self):
2250+
# ponytail: one check — reload of fitted template spec raises immediately
2251+
import pickle
2252+
model = PySRRegressor(
2253+
expression_spec=TemplateExpressionSpec(
2254+
combine="f(x)", expressions=["f"], variable_names=["x"]
2255+
)
2256+
)
2257+
model.equations_ = pd.DataFrame({"loss": [0.0]})
2258+
model.feature_names_in_ = np.array(["x"])
2259+
model.nout_ = 1
2260+
with self.assertRaisesRegex(NotImplementedError, "not yet supported"):
2261+
pickle.loads(pickle.dumps(model))
2262+
22492263
def _check_macro_str(self, spec, expected_str):
22502264
self.assertEqual(
22512265
spec._template_macro_str().strip(), dedent(expected_str).strip()

0 commit comments

Comments
 (0)