|
5 | 5 | from collections.abc import Callable, Collection, Sequence
|
6 | 6 | from copy import deepcopy
|
7 | 7 | from dataclasses import dataclass
|
8 |
| -from typing import TypedDict |
| 8 | +from typing import Any, TypedDict |
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 | import pandas as pd
|
@@ -1123,6 +1123,55 @@ def _default_scoring() -> Scoring:
|
1123 | 1123 | return default_scoring
|
1124 | 1124 | return dict(default_scoring) | dict(scoring)
|
1125 | 1125 |
|
| 1126 | + @property |
| 1127 | + def init_args(self) -> dict[str, Any]: |
| 1128 | + """Create initiliazation parameters for a new MetaLearner. |
| 1129 | +
|
| 1130 | + Importantly, this does not copy further internal state, such as the weights or |
| 1131 | + parameters of trained base models. |
| 1132 | + """ |
| 1133 | + return { |
| 1134 | + "is_classification": self.is_classification, |
| 1135 | + "n_variants": self.n_variants, |
| 1136 | + "nuisance_model_factory": { |
| 1137 | + k: v |
| 1138 | + for k, v in self.nuisance_model_factory.items() |
| 1139 | + if k != PROPENSITY_MODEL |
| 1140 | + if k not in self._prefitted_nuisance_models |
| 1141 | + }, |
| 1142 | + "treatment_model_factory": self.treatment_model_factory, |
| 1143 | + "propensity_model_factory": ( |
| 1144 | + self.nuisance_model_factory.get(PROPENSITY_MODEL) |
| 1145 | + if PROPENSITY_MODEL not in self._prefitted_nuisance_models |
| 1146 | + else None |
| 1147 | + ), |
| 1148 | + "nuisance_model_params": { |
| 1149 | + k: v |
| 1150 | + for k, v in self.nuisance_model_params.items() |
| 1151 | + if k != PROPENSITY_MODEL |
| 1152 | + if k not in self._prefitted_nuisance_models |
| 1153 | + }, |
| 1154 | + "treatment_model_params": self.treatment_model_params, |
| 1155 | + "propensity_model_params": ( |
| 1156 | + self.nuisance_model_params.get(PROPENSITY_MODEL) |
| 1157 | + if PROPENSITY_MODEL not in self._prefitted_nuisance_models |
| 1158 | + else None |
| 1159 | + ), |
| 1160 | + "fitted_nuisance_models": { |
| 1161 | + k: deepcopy(v) |
| 1162 | + for k, v in self._nuisance_models.items() |
| 1163 | + if k in self._prefitted_nuisance_models and k != PROPENSITY_MODEL |
| 1164 | + }, |
| 1165 | + "fitted_propensity_model": ( |
| 1166 | + deepcopy(self._nuisance_models.get(PROPENSITY_MODEL)) |
| 1167 | + if PROPENSITY_MODEL in self._prefitted_nuisance_models |
| 1168 | + else None |
| 1169 | + ), |
| 1170 | + "feature_set": self.feature_set, |
| 1171 | + "n_folds": self.n_folds, |
| 1172 | + "random_state": self.random_state, |
| 1173 | + } |
| 1174 | + |
1126 | 1175 |
|
1127 | 1176 | class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):
|
1128 | 1177 |
|
|
0 commit comments