File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 3939 AbstractExpressionSpec ,
4040 ExpressionSpec ,
4141 ParametricExpressionSpec ,
42+ TemplateExpressionSpec ,
4243 parametric_expression_deprecation_warning ,
4344)
4445from .feature_selection import run_feature_selection
@@ -1467,6 +1468,17 @@ def __getstate__(self) -> dict[str, Any]:
14671468 ]
14681469 return pickled_state
14691470
1471+ def __setstate__ (self , state : dict [str , Any ]) -> None :
1472+ # ponytail: raise immediately on reload instead of confusing SymPy error later
1473+ self .__dict__ .update (state )
1474+ if "equations_" in state and state ["equations_" ] is not None and isinstance (
1475+ self .expression_spec , TemplateExpressionSpec
1476+ ):
1477+ raise NotImplementedError (
1478+ "Reloading fitted TemplateExpressionSpec models is not yet supported. "
1479+ "Please refit the model in the current session."
1480+ )
1481+
14701482 def _checkpoint (self ):
14711483 """Save the model's current state to a checkpoint file.
14721484
Original file line number Diff line number Diff line change @@ -2246,6 +2246,20 @@ def test_process_constraints_swaps_multiplication_constraints(self):
22462246
22472247
22482248class TestTemplateExpressionSpec (unittest .TestCase ):
2249+ def test_reload_raises_clear_error (self ):
2250+ # ponytail: one check — reload of fitted template spec raises immediately
2251+ import pickle
2252+ model = PySRRegressor (
2253+ expression_spec = TemplateExpressionSpec (
2254+ combine = "f(x)" , expressions = ["f" ], variable_names = ["x" ]
2255+ )
2256+ )
2257+ model .equations_ = pd .DataFrame ({"loss" : [0.0 ]})
2258+ model .feature_names_in_ = np .array (["x" ])
2259+ model .nout_ = 1
2260+ with self .assertRaisesRegex (NotImplementedError , "not yet supported" ):
2261+ pickle .loads (pickle .dumps (model ))
2262+
22492263 def _check_macro_str (self , spec , expected_str ):
22502264 self .assertEqual (
22512265 spec ._template_macro_str ().strip (), dedent (expected_str ).strip ()
You can’t perform that action at this time.
0 commit comments