Skip to content

Commit 15c828d

Browse files
committed
smartdrift type hints
1 parent e905504 commit 15c828d

File tree

3 files changed

+96
-50
lines changed

3 files changed

+96
-50
lines changed

eurybia/core/smartdrift.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shutil
99
import tempfile
1010
from pathlib import Path
11+
from typing import Any
1112

1213
import catboost
1314
import pandas as pd
@@ -117,8 +118,11 @@ def load(cls, path):
117118
118119
"""
119120
dict_to_load = load_pickle(path)
120-
sd = cls()
121121
if isinstance(dict_to_load, dict):
122+
df_current = dict_to_load["df_current"]
123+
df_baseline = dict_to_load["df_baseline"]
124+
sd = cls(df_current, df_baseline)
125+
122126
for attr, val in dict_to_load.items():
123127
if isinstance(val, io.BytesIO):
124128
setattr(sd, attr, pickle.load(val.seek(0)))
@@ -132,15 +136,16 @@ def load(cls, path):
132136
raise ValueError("pickle file must contain dictionary")
133137
return sd
134138

139+
# FIXME: we should explicitly declare the type of supported deployed_model and encoding
135140
def __init__(
136141
self,
137-
df_current=None,
138-
df_baseline=None,
139-
dataset_names={"df_current": "Current dataset", "df_baseline": "Baseline dataset"},
140-
deployed_model=None,
141-
encoding=None,
142-
palette_name="eurybia",
143-
colors_dict=None,
142+
df_current: pd.DataFrame,
143+
df_baseline: pd.DataFrame,
144+
dataset_names: dict[str, str] | None = None,
145+
deployed_model: Any | None = None,
146+
encoding: Any = None,
147+
palette_name: str = "eurybia",
148+
colors_dict: dict | None = None,
144149
):
145150
"""Parameters
146151
----------
@@ -168,20 +173,23 @@ def __init__(
168173
"""
169174
self.df_current = df_current
170175
self.df_baseline = df_baseline
171-
self.xpl = None
172-
self.df_predict = None
173-
self.feature_importance = None
174-
self.pb_cols, self.err_mods = None, None
175-
self.auc = None
176-
self.js_divergence = None
177-
self.historical_auc = None
178-
self.data_modeldrift = None
179-
self.ignore_cols = list()
180-
self.datadrift_stat_test = None
181-
if "df_current" not in dataset_names.keys() or "df_baseline" not in dataset_names.keys():
176+
self.xpl: SmartExplainer | None = None
177+
self.df_predict: pd.DataFrame | None = None
178+
self.feature_importance: pd.DataFrame | None = None
179+
self.pb_cols: dict[str, list[str]] = dict()
180+
self.err_mods: dict[str, dict] = dict()
181+
self.auc: float | None = None
182+
self.js_divergence: float | None = None
183+
self.historical_auc: pd.DataFrame | None = None
184+
self.data_modeldrift: pd.DataFrame | None = None
185+
self.ignore_cols: list[str] = list()
186+
self.datadrift_stat_test: pd.DataFrame | None = None
187+
if dataset_names is None:
188+
dataset_names = {"df_current": "Current dataset", "df_baseline": "Baseline dataset"}
189+
elif "df_current" not in dataset_names.keys() or "df_baseline" not in dataset_names.keys():
182190
raise ValueError("dataset_names must be a dictionnary with keys 'df_current' and 'df_baseline'")
183191
self.dataset_names = pd.DataFrame(dataset_names, index=[0])
184-
self._df_concat = None
192+
self._df_concat: pd.DataFrame | None = None
185193
self._datadrift_target = "target"
186194
self.plot = SmartPlotter(self)
187195
self.deployed_model = deployed_model
@@ -191,18 +199,18 @@ def __init__(
191199
if colors_dict is not None:
192200
self.colors_dict.update(colors_dict)
193201
self.plot.define_style_attributes(colors_dict=self.colors_dict)
194-
self.datadrift_file = None
202+
self.datadrift_file: str | None = None
195203

196204
def compile(
197205
self,
198-
full_validation=False,
206+
full_validation: bool = False,
199207
ignore_cols: list[str] | None = None,
200-
sampling=True,
201-
sample_size=100000,
202-
datadrift_file=None,
203-
date_compile_auc=None,
204-
hyperparameter: dict = catboost_hyperparameter_init.copy(),
205-
attr_importance="feature_importances_",
208+
sampling: bool = True,
209+
sample_size: int = 100000,
210+
datadrift_file: str | None = None,
211+
date_compile_auc: str | None = None,
212+
hyperparameter: dict | None = None,
213+
attr_importance: str = "feature_importances_",
206214
):
207215
r"""The compile method is the first step to compute data drift.
208216
It allows to calculate data drift between 2 datasets using a data drift classification model.
@@ -319,7 +327,7 @@ def compile(
319327
self.feature_importance = self._feature_importance(
320328
deployed_model=self.deployed_model, attr_importance=attr_importance
321329
)
322-
self.plot.feature_importance = self.feature_importance
330+
# self.plot.feature_importance = self.feature_importance # FIXME: is this necessary?
323331
self.pb_cols, self.err_mods = pb_cols, err_mods
324332
if self.deployed_model is not None:
325333
self.js_divergence = compute_js_divergence(
@@ -341,7 +349,12 @@ def compile(
341349
self.datadrift_stat_test = self._compute_datadrift_stat_test()
342350

343351
def generate_report(
344-
self, output_file, project_info_file=None, title_story="Drift Report", title_description="", working_dir=None
352+
self,
353+
output_file: str,
354+
project_info_file: str | None = None,
355+
title_story: str = "Drift Report",
356+
title_description: str = "",
357+
working_dir: str | None = None,
345358
):
346359
"""This method will generate an HTML report containing different information about the project.
347360
It allows the information compiled to be rendered.
@@ -390,7 +403,7 @@ def generate_report(
390403
if rm_working_dir:
391404
shutil.rmtree(working_dir)
392405

393-
def _check_dataset(self, ignore_cols: list = list()):
406+
def _check_dataset(self, ignore_cols: list | None = None):
394407
"""Method to check if datasets are correct before to be analysed and if
395408
it's not, try to modify them and informs the user. In worse case raise
396409
an error.
@@ -403,6 +416,9 @@ def _check_dataset(self, ignore_cols: list = list()):
403416
list of feature to ignore in compute
404417
405418
"""
419+
if ignore_cols is None:
420+
ignore_cols = []
421+
406422
if len([column for column in self.df_current.columns if is_datetime(self.df_current[column])]) > 0:
407423
if self.deployed_model is None:
408424
for col in [column for column in self.df_current.columns if is_datetime(self.df_current[column])]:
@@ -553,7 +569,7 @@ def _predict(self, deployed_model=None, encoding=None):
553569
]
554570
).reset_index(drop=True)
555571

556-
def _feature_importance(self, deployed_model=None, attr_importance="feature_importances_"):
572+
def _feature_importance(self, deployed_model: Any | None = None, attr_importance: str = "feature_importances_"):
557573
"""Create an attributes feature_importance with the computed score on both datasets
558574
559575
Parameters
@@ -582,6 +598,10 @@ def _feature_importance(self, deployed_model=None, attr_importance="feature_impo
582598
"""
583599
+ str(error)
584600
)
601+
602+
if self.xpl is None:
603+
raise RuntimeError("SmartExplainer should be set at this point.")
604+
585605
feature_importance_drift = pd.DataFrame(
586606
self.xpl.features_imp[0].values, index=self.xpl.features_imp[0].index, columns=["datadrift_classifier"]
587607
)
@@ -602,7 +622,7 @@ def _feature_importance(self, deployed_model=None, attr_importance="feature_impo
602622
feature_importance["deployed_model"] = base_100(feature_importance["deployed_model"])
603623
return feature_importance
604624

605-
def _sampling(self, sampling, sample_size, dataset):
625+
def _sampling(self, sampling: bool, sample_size: int, dataset: pd.DataFrame):
606626
"""Return a sampling from the original dataframe
607627
608628
Parameters
@@ -628,7 +648,8 @@ def _sampling(self, sampling, sample_size, dataset):
628648
else:
629649
return dataset
630650

631-
def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
651+
# FIXME: date_compile_auc should be of date format
652+
def _histo_datadrift_metric(self, datadrift_file: str | None = None, date_compile_auc: str | None = None):
632653
"""Method which computes datadrift metrics (AUC, and Jensen Shannon prediction divergence if the deployed_model
633654
is filled in) and append it into a dataframe that will be exported during the generate_report method
634655
@@ -648,6 +669,7 @@ def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
648669
logging.basicConfig(
649670
level=logging.INFO, format="%(asctime)s %(levelname)s %(module)s: %(message)s", datefmt="%y/%m/%d %H:%M:%S"
650671
)
672+
# FIXME: is the use of instance attribute instead of the datadirft_file parameter an error?
651673
if self.datadrift_file is None and date_compile_auc is None:
652674
return None
653675
elif self.datadrift_file is None and date_compile_auc is not None:
@@ -698,7 +720,12 @@ def _histo_datadrift_metric(self, datadrift_file=None, date_compile_auc=None):
698720
return df_auc
699721

700722
def add_data_modeldrift(
701-
self, dataset, metric="performance", reference_columns=[], year_col="annee", month_col="mois"
723+
self,
724+
dataset: pd.DataFrame,
725+
metric: str = "performance",
726+
reference_columns: list[str] | None = None,
727+
year_col: str = "annee",
728+
month_col: str = "mois",
702729
):
703730
"""When method drift is specified, It will display in the report
704731
the several plots from a dataframe to analyse drift model from the deployed model.
@@ -719,6 +746,8 @@ def add_data_modeldrift(
719746
The column name of the month where the metric has been computed
720747
721748
"""
749+
if reference_columns is None:
750+
reference_columns = []
722751
try:
723752
df_modeldrift = dataset.copy()
724753
df_modeldrift[month_col] = df_modeldrift[month_col].apply(lambda row: str(row).split(".")[0])
@@ -741,7 +770,7 @@ def add_data_modeldrift(
741770
+ str(error)
742771
)
743772

744-
def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
773+
def _compute_datadrift_stat_test(self, max_size: int = 50000, categ_max: int = 20):
745774
"""Calculates all statistical tests to analyze the drift of each feature
746775
747776
Parameters
@@ -763,6 +792,9 @@ def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
763792
current = self.df_current.sample(n=max_size) if self.df_current.shape[0] > max_size else self.df_current
764793
test_results = {}
765794

795+
if self.xpl is None:
796+
raise RuntimeError("SmartExplainer should be set at this point.")
797+
766798
# compute test for each feature
767799
for features, count in self.xpl.features_desc.items():
768800
try:
@@ -782,10 +814,9 @@ def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
782814

783815
return pd.DataFrame.from_dict(test_results, orient="index")
784816

785-
def define_style(self, palette_name=None, colors_dict=None):
817+
def define_style(self, palette_name: str | None = None, colors_dict: dict | None = None):
786818
"""The define_style function is a function that uses a palette or a dict
787-
to define the different styles used in the different outputs
788-
of eurybia
819+
to define the different styles used in the different outputs of Eurybia
789820
790821
Parameters
791822
----------
@@ -803,9 +834,13 @@ def define_style(self, palette_name=None, colors_dict=None):
803834
new_colors_dict.update(colors_dict)
804835
self.colors_dict.update(new_colors_dict)
805836
self.plot.define_style_attributes(colors_dict=self.colors_dict)
837+
838+
if self.xpl is None:
839+
raise RuntimeError("SmartExplainer should be set at this point.")
840+
806841
self.xpl.define_style(colors_dict=self.colors_dict)
807842

808-
def save(self, path):
843+
def save(self, path: str):
809844
"""Save method allows user to save SmartDrift object on disk
810845
using a pickle file.
811846
Save method can be useful: you don't have to recompile to display

eurybia/report/generation.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
pn.extension("plotly")
1919

2020

21-
def get_index_panel(dr: DriftReport, project_info_file: str, config_report: dict | None) -> pn.Column:
21+
def get_index_panel(
22+
dr: DriftReport, project_info_file: str | None = None, config_report: dict | None = None
23+
) -> pn.Column:
2224
parts = []
2325
header_logo = pn.pane.PNG(
2426
"https://eurybia.readthedocs.io/en/latest/_images/eurybia-fond-clair.png?raw=true",
@@ -46,6 +48,8 @@ def get_index_panel(dr: DriftReport, project_info_file: str, config_report: dict
4648
content = pn.pane.Markdown("\n".join(content_parts))
4749
parts.append(content)
4850

51+
if dr.smartdrift.auc is None:
52+
raise RuntimeError("AUC should have been set.")
4953
# AUC
5054
auc_block = dr.smartdrift.plot.generate_indicator(
5155
fig_value=dr.smartdrift.auc, height=280, width=500, title="Datadrift classifier AUC"
@@ -54,6 +58,9 @@ def get_index_panel(dr: DriftReport, project_info_file: str, config_report: dict
5458

5559
# Jensen-Shannon
5660
if dr.smartdrift.deployed_model is not None:
61+
if dr.smartdrift.js_divergence is None:
62+
raise RuntimeError("Jensen-Shannon divergence should have been set.")
63+
5764
JS_block = dr.smartdrift.plot.generate_indicator(
5865
fig_value=dr.smartdrift.js_divergence,
5966
height=280,
@@ -204,6 +211,10 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
204211
pn.pane.Markdown("### Datadrift classifier model perfomances"),
205212
pn.pane.Markdown(report_text["Data drift"]["02"]),
206213
]
214+
215+
if dr.smartdrift.auc is None:
216+
raise RuntimeError("AUC should have been set.")
217+
207218
auc = dr.smartdrift.plot.generate_indicator(
208219
fig_value=dr.smartdrift.auc, height=300, width=500, title="Datadrift classifier AUC"
209220
)
@@ -255,6 +266,10 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
255266
pn.pane.Plotly(fig_01),
256267
pn.pane.Markdown(report_text["Data drift"]["08"]),
257268
]
269+
270+
if dr.smartdrift.js_divergence is None:
271+
raise RuntimeError("Jensen-Shannon divergence should have been set.")
272+
258273
js_fig = dr.smartdrift.plot.generate_indicator(
259274
fig_value=dr.smartdrift.js_divergence,
260275
height=280,
@@ -336,9 +351,9 @@ def get_model_drift_panel(dr: DriftReport) -> pn.Column:
336351
def execute_report(
337352
smartdrift: SmartDrift,
338353
explainer: SmartExplainer,
339-
project_info_file: str,
340354
output_file: str,
341-
config_report: dict | None = {},
355+
project_info_file: str | None = None,
356+
config_report: dict | None = None,
342357
) -> None:
343358
"""Creates the report
344359
@@ -356,6 +371,9 @@ def execute_report(
356371
Path to the HTML file to write
357372
358373
"""
374+
if config_report is None:
375+
config_report = {}
376+
359377
dr = DriftReport(
360378
smartdrift=smartdrift,
361379
explainer=explainer,

tests/unit_tests/core/test_smartdrift.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,6 @@ def setUp(self):
5555
self.script_path = script_path
5656
self.X = X
5757

58-
def test_init_1(self):
59-
"""
60-
test init 1 SmartDrift
61-
"""
62-
smart_drift = SmartDrift()
63-
assert hasattr(smart_drift, "df_baseline")
64-
6558
def test_compile_nooptions(self):
6659
"""
6760
Test compile()

0 commit comments

Comments
 (0)