11# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33
4- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
55
66import numpy as np
77import torch
@@ -24,10 +24,9 @@ def __init__(
2424 seed : int = 1 ,
2525 num_folds : int = 1 ,
2626 num_ensemble : int = 1 ,
27- classifier : str = "mlp" ,
27+ classifier : Union [ str , Type [ BaseEstimator ]] = MLPClassifier ,
2828 z_score : bool = False ,
29- clf_class : Optional [Any ] = None ,
30- clf_kwargs : Optional [Dict [str , Any ]] = None ,
29+ classifier_kwargs : Optional [Dict [str , Any ]] = None ,
3130 num_trials_null : int = 100 ,
3231 permutation : bool = True ,
3332 ) -> None :
@@ -71,10 +70,11 @@ def __init__(
7170 num_ensemble: Number of classifiers for ensembling, defaults to 1.
7271 This is useful to reduce variance coming from the classifier.
7372 z_score: Whether to z-score to normalize the data, defaults to False.
74- classifier: Classification architecture to use,
75- possible values: "random_forest" or "mlp", defaults to "mlp".
76- clf_class: Custom sklearn classifier class, defaults to None.
77- clf_kwargs: Custom kwargs for the sklearn classifier, defaults to None.
73+ classifier: Classification architecture to use, defaults to MLPClassifier.
74+ Can be either the string "random_forest" or "mlp", or
75+ a classifier class (e.g., RandomForestClassifier, MLPClassifier).
76+ classifier_kwargs: Custom kwargs for the sklearn classifier,
77+ defaults to None.
7878 num_trials_null: Number of trials to estimate the null distribution,
7979 defaults to 100.
8080 permutation: Whether to use the permutation method for the null hypothesis,
@@ -111,10 +111,26 @@ def __init__(
111111 self .num_ensemble = num_ensemble
112112
113113 # initialize classifier
114- if "mlp" in classifier .lower ():
115- ndim = thetas .shape [- 1 ]
116- self .clf_class = MLPClassifier
117- if clf_kwargs is None :
114+ if isinstance (classifier , str ):
115+ if classifier .lower () == 'mlp' :
116+ classifier = MLPClassifier
117+ elif classifier .lower () == 'random_forest' :
118+ classifier = RandomForestClassifier
119+ else :
120+ raise ValueError (
121+ f'Invalid classifier: "{ classifier } ".'
122+ 'Expected "mlp", "random_forest", '
123+ 'or a valid scikit-learn classifier class.'
124+ )
125+ assert issubclass (classifier , BaseEstimator ), (
126+ "classifier must be a subclass of sklearn's BaseEstimator"
127+ )
128+ self .clf_class = classifier
129+
130+ self .clf_kwargs = classifier_kwargs
131+ if self .clf_kwargs is None :
132+ if self .clf_class == MLPClassifier :
133+ ndim = thetas .shape [- 1 ]
118134 self .clf_kwargs = {
119135 "activation" : "relu" ,
120136 "hidden_layer_sizes" : (10 * ndim , 10 * ndim ),
@@ -123,19 +139,8 @@ def __init__(
123139 "early_stopping" : True ,
124140 "n_iter_no_change" : 50 ,
125141 }
126- elif "random_forest" in classifier .lower ():
127- self .clf_class = RandomForestClassifier
128- if clf_kwargs is None :
142+ else :
129143 self .clf_kwargs = {}
130- elif "custom" :
131- if clf_class is None or clf_kwargs is None :
132- raise ValueError (
133- "Please provide a valid sklearn classifier class and kwargs."
134- )
135- self .clf_class = clf_class
136- self .clf_kwargs = clf_kwargs
137- else :
138- raise NotImplementedError
139144
140145 # initialize classifiers, will be set after training
141146 self .trained_clfs = None
0 commit comments