import shutil
import tempfile
from pathlib import Path
+from typing import Any

import catboost
import pandas as pd
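A side note on the new "X | None" annotations used throughout the signatures below: evaluated function annotations with the PEP 604 union syntax require Python 3.10+. If the package still has to import on older interpreters, the usual escape hatch is postponed evaluation of annotations; whether the module already enables it is not visible in this hunk, so the line below is only an assumption:

    # Assumption: not already present among the file's first lines (outside this hunk).
    # With this future import, signature annotations such as "dict[str, str] | None"
    # are kept as strings and never evaluated, so the module imports on Python < 3.10.
    from __future__ import annotations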
@@ -117,8 +118,11 @@ def load(cls, path):

        """
        dict_to_load = load_pickle(path)
-        sd = cls()
        if isinstance(dict_to_load, dict):
+            df_current = dict_to_load["df_current"]
+            df_baseline = dict_to_load["df_baseline"]
+            sd = cls(df_current, df_baseline)
+
            for attr, val in dict_to_load.items():
                if isinstance(val, io.BytesIO):
                    setattr(sd, attr, pickle.load(val.seek(0)))
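In this version, load() pulls the two required dataframes out of the pickled dictionary and hands them to the constructor before restoring the remaining attributes, instead of starting from an empty object. A minimal round-trip sketch (dataframes and file name are illustrative):

    from eurybia import SmartDrift

    sd = SmartDrift(df_current=df_current, df_baseline=df_baseline)
    sd.compile()
    sd.save("smartdrift.pkl")                 # illustrative path
    sd2 = SmartDrift.load("smartdrift.pkl")   # dataframes are read back from the pickled dict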
@@ -132,15 +136,16 @@ def load(cls, path):
            raise ValueError("pickle file must contain dictionary")
        return sd

+    # FIXME: we should explicitly declare the type of supported deployed_model and encoding
    def __init__(
        self,
-        df_current=None,
-        df_baseline=None,
-        dataset_names={"df_current": "Current dataset", "df_baseline": "Baseline dataset"},
-        deployed_model=None,
-        encoding=None,
-        palette_name="eurybia",
-        colors_dict=None,
+        df_current: pd.DataFrame,
+        df_baseline: pd.DataFrame,
+        dataset_names: dict[str, str] | None = None,
+        deployed_model: Any | None = None,
+        encoding: Any = None,
+        palette_name: str = "eurybia",
+        colors_dict: dict | None = None,
    ):
        """Parameters
        ----------
@@ -168,20 +173,23 @@ def __init__(
168173 """
169174 self .df_current = df_current
170175 self .df_baseline = df_baseline
171- self .xpl = None
172- self .df_predict = None
173- self .feature_importance = None
174- self .pb_cols , self .err_mods = None , None
175- self .auc = None
176- self .js_divergence = None
177- self .historical_auc = None
178- self .data_modeldrift = None
179- self .ignore_cols = list ()
180- self .datadrift_stat_test = None
181- if "df_current" not in dataset_names .keys () or "df_baseline" not in dataset_names .keys ():
176+ self .xpl : SmartExplainer | None = None
177+ self .df_predict : pd .DataFrame | None = None
178+ self .feature_importance : pd .DataFrame | None = None
179+ self .pb_cols : dict [str , list [str ]] = dict ()
180+ self .err_mods : dict [str , dict ] = dict ()
181+ self .auc : float | None = None
182+ self .js_divergence : float | None = None
183+ self .historical_auc : pd .DataFrame | None = None
184+ self .data_modeldrift : pd .DataFrame | None = None
185+ self .ignore_cols : list [str ] = list ()
186+ self .datadrift_stat_test : pd .DataFrame | None = None
187+ if dataset_names is None :
188+ dataset_names = {"df_current" : "Current dataset" , "df_baseline" : "Baseline dataset" }
189+ elif "df_current" not in dataset_names .keys () or "df_baseline" not in dataset_names .keys ():
182190 raise ValueError ("dataset_names must be a dictionnary with keys 'df_current' and 'df_baseline'" )
183191 self .dataset_names = pd .DataFrame (dataset_names , index = [0 ])
184- self ._df_concat = None
192+ self ._df_concat : pd . DataFrame | None = None
185193 self ._datadrift_target = "target"
186194 self .plot = SmartPlotter (self )
187195 self .deployed_model = deployed_model
@@ -191,18 +199,18 @@ def __init__(
        if colors_dict is not None:
            self.colors_dict.update(colors_dict)
        self.plot.define_style_attributes(colors_dict=self.colors_dict)
-        self.datadrift_file = None
+        self.datadrift_file: str | None = None

    def compile(
        self,
-        full_validation=False,
+        full_validation: bool = False,
        ignore_cols: list[str] | None = None,
-        sampling=True,
-        sample_size=100000,
-        datadrift_file=None,
-        date_compile_auc=None,
-        hyperparameter: dict = catboost_hyperparameter_init.copy(),
-        attr_importance="feature_importances_",
+        sampling: bool = True,
+        sample_size: int = 100000,
+        datadrift_file: str | None = None,
+        date_compile_auc: str | None = None,
+        hyperparameter: dict | None = None,
+        attr_importance: str = "feature_importances_",
    ):
        r"""The compile method is the first step to compute data drift.
        It allows to calculate data drift between 2 datasets using a data drift classification model.
@@ -319,7 +327,7 @@ def compile(
        self.feature_importance = self._feature_importance(
            deployed_model=self.deployed_model, attr_importance=attr_importance
        )
-        self.plot.feature_importance = self.feature_importance
+        # self.plot.feature_importance = self.feature_importance  # FIXME: is this necessary?
        self.pb_cols, self.err_mods = pb_cols, err_mods
        if self.deployed_model is not None:
            self.js_divergence = compute_js_divergence(
@@ -341,7 +349,12 @@ def compile(
        self.datadrift_stat_test = self._compute_datadrift_stat_test()

    def generate_report(
-        self, output_file, project_info_file=None, title_story="Drift Report", title_description="", working_dir=None
+        self,
+        output_file: str,
+        project_info_file: str | None = None,
+        title_story: str = "Drift Report",
+        title_description: str = "",
+        working_dir: str | None = None,
    ):
        """This method will generate an HTML report containing different information about the project.
        It allows the information compiled to be rendered.
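Under the new signatures, a typical end-to-end run could look like the sketch below; the model, encoder, date and output path are illustrative, and hyperparameter=None is presumably replaced by the catboost defaults inside compile:

    from eurybia import SmartDrift

    sd = SmartDrift(
        df_current=df_current,       # required dataframes
        df_baseline=df_baseline,
        deployed_model=my_model,     # illustrative: a fitted model exposing feature_importances_
        encoding=my_encoder,         # illustrative: the encoder used by the deployed model
    )
    sd.compile(full_validation=True, date_compile_auc="01/01/2024")               # illustrative date
    sd.generate_report(output_file="report.html", title_story="Drift Report")     # illustrative path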
@@ -390,7 +403,7 @@ def generate_report(
        if rm_working_dir:
            shutil.rmtree(working_dir)

-    def _check_dataset(self, ignore_cols: list = list()):
+    def _check_dataset(self, ignore_cols: list | None = None):
        """Method to check if datasets are correct before to be analysed and if
        it's not, try to modify them and informs the user. In worse case raise
        an error.
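The ignore_cols default removed here is the usual shared-mutable-default pitfall: list() is evaluated once at definition time, so every call that mutates the argument sees the leftovers of previous calls. A small self-contained illustration of the behaviour the None-default idiom avoids (hypothetical functions, not part of the module):

    def buggy(ignore_cols: list = list()):
        ignore_cols.append("x")
        return ignore_cols

    buggy()   # ['x']
    buggy()   # ['x', 'x'] -- the same list object is reused across calls

    def fixed(ignore_cols: list | None = None):
        if ignore_cols is None:
            ignore_cols = []
        ignore_cols.append("x")
        return ignore_cols

    fixed()   # ['x'] on every call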
@@ -403,6 +416,9 @@ def _check_dataset(self, ignore_cols: list = list()):
            list of feature to ignore in compute

        """
+        if ignore_cols is None:
+            ignore_cols = []
+
        if len([column for column in self.df_current.columns if is_datetime(self.df_current[column])]) > 0:
            if self.deployed_model is None:
                for col in [column for column in self.df_current.columns if is_datetime(self.df_current[column])]:
@@ -553,7 +569,7 @@ def _predict(self, deployed_model=None, encoding=None):
            ]
        ).reset_index(drop=True)

-    def _feature_importance(self, deployed_model=None, attr_importance="feature_importances_"):
+    def _feature_importance(self, deployed_model: Any | None = None, attr_importance: str = "feature_importances_"):
        """Create an attributes feature_importance with the computed score on both datasets

        Parameters
@@ -582,6 +598,10 @@ def _feature_importance(self, deployed_model=None, attr_importance="feature_impo
582598 """
583599 + str (error )
584600 )
601+
602+ if self .xpl is None :
603+ raise RuntimeError ("SmartExplainer should be set at this point." )
604+
585605 feature_importance_drift = pd .DataFrame (
586606 self .xpl .features_imp [0 ].values , index = self .xpl .features_imp [0 ].index , columns = ["datadrift_classifier" ]
587607 )
@@ -602,7 +622,7 @@ def _feature_importance(self, deployed_model=None, attr_importance="feature_impo
        feature_importance["deployed_model"] = base_100(feature_importance["deployed_model"])
        return feature_importance

-    def _sampling(self, sampling, sample_size, dataset):
+    def _sampling(self, sampling: bool, sample_size: int, dataset: pd.DataFrame):
        """Return a sampling from the original dataframe

        Parameters
@@ -628,7 +648,8 @@ def _sampling(self, sampling, sample_size, dataset):
        else:
            return dataset

-    def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
+    # FIXME: date_compile_auc should be of date format
+    def _histo_datadrift_metric(self, datadrift_file: str | None = None, date_compile_auc: str | None = None):
        """Method which computes datadrift metrics (AUC, and Jensen Shannon prediction divergence if the deployed_model
        is filled in) and append it into a dataframe that will be exported during the generate_report method

@@ -648,6 +669,7 @@ def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s %(levelname)s %(module)s: %(message)s", datefmt="%y/%m/%d %H:%M:%S"
        )
+        # FIXME: is the use of the instance attribute instead of the datadrift_file parameter an error?
        if self.datadrift_file is None and date_compile_auc is None:
            return None
        elif self.datadrift_file is None and date_compile_auc is not None:
@@ -698,7 +720,12 @@ def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
        return df_auc

    def add_data_modeldrift(
-        self, dataset, metric="performance", reference_columns=[], year_col="annee", month_col="mois"
+        self,
+        dataset: pd.DataFrame,
+        metric: str = "performance",
+        reference_columns: list[str] | None = None,
+        year_col: str = "annee",
+        month_col: str = "mois",
    ):
        """When method drift is specified, It will display in the report
        the several plots from a dataframe to analyse drift model from the deployed model.
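For reference, a hedged sketch of feeding add_data_modeldrift under the new signature, with an illustrative performance dataframe using the default "annee"/"mois" column names:

    import pandas as pd

    df_performance = pd.DataFrame(
        {
            "annee": [2023, 2023, 2024],
            "mois": [11, 12, 1],
            "performance": [0.81, 0.79, 0.76],   # illustrative metric values
        }
    )
    sd.add_data_modeldrift(dataset=df_performance, metric="performance")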
@@ -719,6 +746,8 @@ def add_data_modeldrift(
            The column name of the month where the metric has been computed

        """
+        if reference_columns is None:
+            reference_columns = []
        try:
            df_modeldrift = dataset.copy()
            df_modeldrift[month_col] = df_modeldrift[month_col].apply(lambda row: str(row).split(".")[0])
@@ -741,7 +770,7 @@ def add_data_modeldrift(
                + str(error)
            )

-    def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
+    def _compute_datadrift_stat_test(self, max_size: int = 50000, categ_max: int = 20):
        """Calculates all statistical tests to analyze the drift of each feature

        Parameters
@@ -763,6 +792,9 @@ def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
        current = self.df_current.sample(n=max_size) if self.df_current.shape[0] > max_size else self.df_current
        test_results = {}

+        if self.xpl is None:
+            raise RuntimeError("SmartExplainer should be set at this point.")
+
        # compute test for each feature
        for features, count in self.xpl.features_desc.items():
            try:
@@ -782,10 +814,9 @@ def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):

        return pd.DataFrame.from_dict(test_results, orient="index")

-    def define_style(self, palette_name=None, colors_dict=None):
+    def define_style(self, palette_name: str | None = None, colors_dict: dict | None = None):
        """The define_style function is a function that uses a palette or a dict
-        to define the different styles used in the different outputs
-        of eurybia
+        to define the different styles used in the different outputs of Eurybia

        Parameters
        ----------
@@ -803,9 +834,13 @@ def define_style(self, palette_name=None, colors_dict=None):
            new_colors_dict.update(colors_dict)
        self.colors_dict.update(new_colors_dict)
        self.plot.define_style_attributes(colors_dict=self.colors_dict)
+
+        if self.xpl is None:
+            raise RuntimeError("SmartExplainer should be set at this point.")
+
        self.xpl.define_style(colors_dict=self.colors_dict)

-    def save(self, path):
+    def save(self, path: str):
        """Save method allows user to save SmartDrift object on disk
        using a pickle file.
        Save method can be useful: you don't have to recompile to display