@@ -69,7 +69,7 @@ def _validate_framework(self):
6969 except :
7070 message = """Provided model is from unsupported framework.
7171 Lens behavior has not been tested or assured with unsupported modeling frameworks."""
72- global_logger .warning (message , message )
72+ global_logger .warning (message )
7373
7474 def __post_init__ (self ):
7575 """Conditionally updates functionality based on framework"""
@@ -92,9 +92,15 @@ def __post_init__(self):
9292 if self .model_like .layers [- 1 ].output_shape == (None , 1 ):
9393 # Assumes sigmoid -> probabilities need to be rounded
9494 self .__dict__ ["predict" ] = lambda x : pred_func (x ).round ()
95+ # Single-output sigmoid is binary by definition
96+ self .type = "BINARY_CLASSIFICATION"
9597 else :
9698 # Assumes softmax -> probabilities need to be argmaxed
9799 self .__dict__ ["predict" ] = lambda x : np .argmax (pred_func (x ), axis = 1 )
100+ if self .model_like .layers [- 1 ].output_shape [1 ] == 2 :
101+ self .type = "BINARY_CLASSIFICATION"
102+ else :
103+ self .type = "MULTICLASS_CLASSIFICATION"
98104
99105 if self .model_like .layers [- 1 ].output_shape == (None , 2 ):
100106 self .__dict__ ["predict_proba" ] = lambda x : pred_func (x )[:, 1 ]
@@ -117,11 +123,16 @@ def __post_init__(self):
117123
118124 elif self .model_info ["framework" ] == "credoai" :
119125 # Functionality for DummyClassifier
120- self .model_like = getattr (self .model_like , "model_like" , None )
126+ if self .model_like .model_like is not None :
127+ self .model_like = self .model_like .model_like
121128 # If the dummy model has a model_like specified, reassign
122129 # the classifier's model_like attribute to match the dummy's
123130 # so that downstream evaluators (ModelProfiler) can use it
124131
132+ self .type = self .model_like .type
133+ # DummyClassifier model type is set in the constructor based on whether it
134+ # is binary or multiclass
135+
125136 # Predict and Predict_Proba should already be specified
126137
127138
@@ -141,6 +152,13 @@ class DummyClassifier:
141152 model_like : model_like, optional
142153 While predictions are pre-computed, the model object, itself, may be of use for
143154 some evaluations (e.g. ModelProfiler).
155+ binary_clf : bool, optional, default = True
156+ Type of classification model.
157+ Used when wrapping with ClassificationModel.
158+ If binary == True, ClassificationModel.type will be set to `BINARY_CLASSIFICATION',
159+ which enables use of binary metrics.
160+ If binary == False, ClassificationModel.type will be set to 'MULTICLASS_CLASSIFICATION',
161+ and use those metrics.
144162 predict_output : array, optional
145163 Array containing per-sample class labels
146164 Corresponds to sklearn-like `predict` output
@@ -158,6 +176,7 @@ def __init__(
158176 self ,
159177 name : str ,
160178 model_like = None ,
179+ binary_clf = True ,
161180 predict_output = None ,
162181 predict_proba_output = None ,
163182 tags = None ,
@@ -167,6 +186,9 @@ def __init__(
167186 self ._build_functionality ("predict_proba" , predict_proba_output )
168187 self .name = name
169188 self .tags = tags
189+ self .type = (
190+ "BINARY_CLASSIFICATION" if binary_clf else "MULTICLASS_CLASSIFICATION"
191+ )
170192
171193 def _wrap_array (self , array ):
172194 return lambda X = None : array
0 commit comments