Skip to content

Commit 3cad00e

Browse files
Provide helper method to initialize a MetaLearner based on another MetaLearner (#71)
* Provide helper method to initialize MetaLearner. * Fix logic revolving around pre-fitted models. * Add changelog entry. * Expand on docstring. * Compare attributes. * Update metalearners/metalearner.py Co-authored-by: Francesc Martí Escofet <[email protected]> * Update metalearners/metalearner.py Co-authored-by: Francesc Martí Escofet <[email protected]> --------- Co-authored-by: Francesc Martí Escofet <[email protected]>
1 parent 2f428fb commit 3cad00e

File tree

4 files changed

+82
-2
lines changed

4 files changed

+82
-2
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Changelog
88
=========
99

10+
0.9.0 (2024-07-xx)
11+
------------------
12+
13+
**New features**
14+
15+
* Added :meth:`metalearners.metalearner.MetaLearner.init_args`.
16+
17+
1018
0.8.0 (2024-07-22)
1119
------------------
1220

metalearners/drlearner.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
from joblib import Parallel, delayed
7-
from typing_extensions import Self
7+
from typing_extensions import Any, Self
88

99
from metalearners._typing import (
1010
Features,
@@ -398,3 +398,7 @@ def _pseudo_outcome(
398398
)
399399

400400
return pseudo_outcome
401+
402+
@property
403+
def init_args(self) -> dict[str, Any]:
404+
return super().init_args | {"adaptive_clipping": self.adaptive_clipping}

metalearners/metalearner.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable, Collection, Sequence
66
from copy import deepcopy
77
from dataclasses import dataclass
8-
from typing import TypedDict
8+
from typing import Any, TypedDict
99

1010
import numpy as np
1111
import pandas as pd
@@ -1123,6 +1123,55 @@ def _default_scoring() -> Scoring:
11231123
return default_scoring
11241124
return dict(default_scoring) | dict(scoring)
11251125

1126+
@property
1127+
def init_args(self) -> dict[str, Any]:
1128+
"""Create initiliazation parameters for a new MetaLearner.
1129+
1130+
Importantly, this does not copy further internal state, such as the weights or
1131+
parameters of trained base models.
1132+
"""
1133+
return {
1134+
"is_classification": self.is_classification,
1135+
"n_variants": self.n_variants,
1136+
"nuisance_model_factory": {
1137+
k: v
1138+
for k, v in self.nuisance_model_factory.items()
1139+
if k != PROPENSITY_MODEL
1140+
if k not in self._prefitted_nuisance_models
1141+
},
1142+
"treatment_model_factory": self.treatment_model_factory,
1143+
"propensity_model_factory": (
1144+
self.nuisance_model_factory.get(PROPENSITY_MODEL)
1145+
if PROPENSITY_MODEL not in self._prefitted_nuisance_models
1146+
else None
1147+
),
1148+
"nuisance_model_params": {
1149+
k: v
1150+
for k, v in self.nuisance_model_params.items()
1151+
if k != PROPENSITY_MODEL
1152+
if k not in self._prefitted_nuisance_models
1153+
},
1154+
"treatment_model_params": self.treatment_model_params,
1155+
"propensity_model_params": (
1156+
self.nuisance_model_params.get(PROPENSITY_MODEL)
1157+
if PROPENSITY_MODEL not in self._prefitted_nuisance_models
1158+
else None
1159+
),
1160+
"fitted_nuisance_models": {
1161+
k: deepcopy(v)
1162+
for k, v in self._nuisance_models.items()
1163+
if k in self._prefitted_nuisance_models and k != PROPENSITY_MODEL
1164+
},
1165+
"fitted_propensity_model": (
1166+
deepcopy(self._nuisance_models.get(PROPENSITY_MODEL))
1167+
if PROPENSITY_MODEL in self._prefitted_nuisance_models
1168+
else None
1169+
),
1170+
"feature_set": self.feature_set,
1171+
"n_folds": self.n_folds,
1172+
"random_state": self.random_state,
1173+
}
1174+
11261175

11271176
class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):
11281177

tests/test_metalearner.py

+19
Original file line numberDiff line numberDiff line change
@@ -1119,3 +1119,22 @@ def test_validate_outcome_different_classes(implementation, use_pandas, rng):
11191119
ValueError, match="have seen different sets of classification outcomes."
11201120
):
11211121
ml.fit(X, y, w)
1122+
1123+
1124+
@pytest.mark.parametrize(
1125+
"implementation",
1126+
[TLearner, SLearner, XLearner, RLearner, DRLearner],
1127+
)
1128+
def test_init_args(implementation):
1129+
ml = implementation(
1130+
True,
1131+
2,
1132+
LogisticRegression,
1133+
LinearRegression,
1134+
LogisticRegression,
1135+
)
1136+
ml2 = implementation(**ml.init_args)
1137+
1138+
assert set(ml.__dict__.keys()) == set(ml2.__dict__.keys())
1139+
for key in ml.__dict__:
1140+
assert ml.__dict__[key] == ml2.__dict__[key]

0 commit comments

Comments
 (0)