@@ -35,6 +35,9 @@ class CIMissTest:
3535 Extra keyword arguments forwarded to the imputer.
3636 variance_method : str
3737 ``'mi_crossfit'`` (default) or ``'legacy_fold'``.
38+ subsample_cap : int or None
39+ Maximum number of rows to subsample for testing. Set to ``None``
40+ to disable subsampling (default ``2000``).
3841 """
3942
4043 def __init__ (
@@ -49,6 +52,7 @@ def __init__(
4952 random_state : int = 42 ,
5053 target_level : str = "variable" ,
5154 variance_method : str = "mi_crossfit" ,
55+ subsample_cap : int = 2000 ,
5256 ):
5357 self .dataset = dataset
5458 self .imputer = imputer
@@ -61,6 +65,7 @@ def __init__(
6165 self .rng = np .random .default_rng (random_state )
6266 self .target_level = target_level
6367 self .variance_method = variance_method
68+ self .subsample_cap = subsample_cap
6469
6570 def __repr__ (self ):
6671 return (
@@ -98,9 +103,9 @@ def run(self):
98103
99104 cv = self ._get_cv ()
100105
101- if self .dataset .miss_data .shape [0 ] > 2000 :
106+ if self .subsample_cap and self . dataset .miss_data .shape [0 ] > self . subsample_cap :
102107 sample_idxs = self .rng .choice (
103- self .dataset .miss_data .shape [0 ], size = 2000 , replace = False
108+ self .dataset .miss_data .shape [0 ], size = self . subsample_cap , replace = False
104109 )
105110 else :
106111 sample_idxs = np .arange (self .dataset .miss_data .shape [0 ])
@@ -461,9 +466,9 @@ def imputer_r2(self, mask_frac: float = 0.2, m_eval: int = 1) -> dict:
461466 """
462467 cv = self ._get_cv ()
463468
464- if self .dataset .miss_data .shape [0 ] > 2000 :
469+ if self .subsample_cap and self . dataset .miss_data .shape [0 ] > self . subsample_cap :
465470 sample_idxs = self .rng .choice (
466- self .dataset .miss_data .shape [0 ], size = 2000 , replace = False
471+ self .dataset .miss_data .shape [0 ], size = self . subsample_cap , replace = False
467472 )
468473 else :
469474 sample_idxs = np .arange (self .dataset .miss_data .shape [0 ])
0 commit comments