Skip to content

Commit da3c09d

Browse files
committed
add subsample parameter
1 parent ea965e8 commit da3c09d

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

citest/test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class CIMissTest:
3535
Extra keyword arguments forwarded to the imputer.
3636
variance_method : str
3737
``'mi_crossfit'`` (default) or ``'legacy_fold'``.
38+
subsample_cap : int or None
39+
Maximum number of rows to subsample for testing. Set to ``None``
40+
to disable subsampling (default ``2000``).
3841
"""
3942

4043
def __init__(
@@ -49,6 +52,7 @@ def __init__(
4952
random_state: int = 42,
5053
target_level: str = "variable",
5154
variance_method: str = "mi_crossfit",
55+
subsample_cap: int = 2000,
5256
):
5357
self.dataset = dataset
5458
self.imputer = imputer
@@ -61,6 +65,7 @@ def __init__(
6165
self.rng = np.random.default_rng(random_state)
6266
self.target_level = target_level
6367
self.variance_method = variance_method
68+
self.subsample_cap = subsample_cap
6469

6570
def __repr__(self):
6671
return (
@@ -98,9 +103,9 @@ def run(self):
98103

99104
cv = self._get_cv()
100105

101-
if self.dataset.miss_data.shape[0] > 2000:
106+
if self.subsample_cap and self.dataset.miss_data.shape[0] > self.subsample_cap:
102107
sample_idxs = self.rng.choice(
103-
self.dataset.miss_data.shape[0], size=2000, replace=False
108+
self.dataset.miss_data.shape[0], size=self.subsample_cap, replace=False
104109
)
105110
else:
106111
sample_idxs = np.arange(self.dataset.miss_data.shape[0])
@@ -461,9 +466,9 @@ def imputer_r2(self, mask_frac: float = 0.2, m_eval: int = 1) -> dict:
461466
"""
462467
cv = self._get_cv()
463468

464-
if self.dataset.miss_data.shape[0] > 2000:
469+
if self.subsample_cap and self.dataset.miss_data.shape[0] > self.subsample_cap:
465470
sample_idxs = self.rng.choice(
466-
self.dataset.miss_data.shape[0], size=2000, replace=False
471+
self.dataset.miss_data.shape[0], size=self.subsample_cap, replace=False
467472
)
468473
else:
469474
sample_idxs = np.arange(self.dataset.miss_data.shape[0])

0 commit comments

Comments
 (0)