|
39 | 39 | AbstractExpressionSpec, |
40 | 40 | ExpressionSpec, |
41 | 41 | ParametricExpressionSpec, |
| 42 | + TemplateExpressionSpec, |
42 | 43 | parametric_expression_deprecation_warning, |
43 | 44 | ) |
44 | 45 | from .feature_selection import run_feature_selection |
@@ -1465,8 +1466,23 @@ def __getstate__(self) -> dict[str, Any]: |
1465 | 1466 | pickled_state["equations_"], pickled_columns |
1466 | 1467 | ) |
1467 | 1468 | ] |
| 1469 | + pickled_state["_loaded_from_pickle_"] = False |
1468 | 1470 | return pickled_state |
1469 | 1471 |
|
| 1472 | + def __setstate__(self, state: dict[str, Any]) -> None: |
| 1473 | + self.__dict__.update(state) |
| 1474 | + self._loaded_from_pickle_ = True |
| 1475 | + |
| 1476 | + def _raise_for_reloaded_template_expression_spec(self) -> None: |
| 1477 | + if getattr(self, "_loaded_from_pickle_", False) and isinstance( |
| 1478 | + self.expression_spec, TemplateExpressionSpec |
| 1479 | + ): |
| 1480 | + raise NotImplementedError( |
| 1481 | + "Loading models fitted with TemplateExpressionSpec is not yet supported. " |
| 1482 | + "Please refit the model in the current Python/Julia session before " |
| 1483 | + "calling predict or export methods." |
| 1484 | + ) |
| 1485 | + |
1470 | 1486 | def _checkpoint(self): |
1471 | 1487 | """Save the model's current state to a checkpoint file. |
1472 | 1488 |
|
@@ -2640,6 +2656,7 @@ def predict( |
2640 | 2656 | check_is_fitted( |
2641 | 2657 | self, attributes=["selection_mask_", "feature_names_in_", "nout_"] |
2642 | 2658 | ) |
| 2659 | + self._raise_for_reloaded_template_expression_spec() |
2643 | 2660 | best_equation = self.get_best(index=index) |
2644 | 2661 |
|
2645 | 2662 | # When X is an numpy array or a pandas dataframe with a RangeIndex, |
@@ -2730,6 +2747,7 @@ def sympy(self, index: int | list[int] | None = None): |
2730 | 2747 | best_equation : str, list[str] of length nout_ |
2731 | 2748 | SymPy representation of the best equation. |
2732 | 2749 | """ |
| 2750 | + self._raise_for_reloaded_template_expression_spec() |
2733 | 2751 | if not self.expression_spec_.supports_sympy: |
2734 | 2752 | raise ValueError( |
2735 | 2753 | f"`expression_spec={self.expression_spec_}` does not support sympy export." |
@@ -2766,6 +2784,7 @@ def latex( |
2766 | 2784 | best_equation : str or list[str] of length nout_ |
2767 | 2785 | LaTeX expression of the best equation. |
2768 | 2786 | """ |
| 2787 | + self._raise_for_reloaded_template_expression_spec() |
2769 | 2788 | if not self.expression_spec_.supports_latex: |
2770 | 2789 | raise ValueError( |
2771 | 2790 | f"`expression_spec={self.expression_spec_}` does not support latex export." |
@@ -2803,6 +2822,7 @@ def jax(self, index=None): |
2803 | 2822 | Dictionary of callable jax function in "callable" key, |
2804 | 2823 | and jax array of parameters as "parameters" key. |
2805 | 2824 | """ |
| 2825 | + self._raise_for_reloaded_template_expression_spec() |
2806 | 2826 | if not self.expression_spec_.supports_jax: |
2807 | 2827 | raise ValueError( |
2808 | 2828 | f"`expression_spec={self.expression_spec_}` does not support jax export." |
@@ -2839,6 +2859,7 @@ def pytorch(self, index=None): |
2839 | 2859 | best_equation : torch.nn.Module |
2840 | 2860 | PyTorch module representing the expression. |
2841 | 2861 | """ |
| 2862 | + self._raise_for_reloaded_template_expression_spec() |
2842 | 2863 | if not self.expression_spec_.supports_torch: |
2843 | 2864 | raise ValueError( |
2844 | 2865 | f"`expression_spec={self.expression_spec_}` does not support torch export." |
@@ -2983,6 +3004,7 @@ def latex_table( |
2983 | 3004 | latex_table_str : str |
2984 | 3005 | A string that will render a table in LaTeX of the equations. |
2985 | 3006 | """ |
| 3007 | + self._raise_for_reloaded_template_expression_spec() |
2986 | 3008 | if not self.expression_spec_.supports_latex: |
2987 | 3009 | raise ValueError( |
2988 | 3010 | f"`expression_spec={self.expression_spec_}` does not support latex export." |
|
0 commit comments