Skip to content

Commit 4b3530c

Browse files
convert some columns to categories
1 parent 0cb804c commit 4b3530c

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ def _plot_single_estimator(
216216
)
217217

218218
else: # multiclass-classification
219-
labels = self.roc_curve["label"].unique()
219+
labels = self.roc_curve["label"].cat.categories
220220
class_colors = sample_mpl_colormap(
221221
colormaps.get_cmap("tab10"), 10 if len(labels) < 10 else len(labels)
222222
)
@@ -309,7 +309,7 @@ def _plot_cross_validated_estimator(
309309
line_kwargs: dict[str, Any] = {}
310310

311311
if self.ml_task == "binary-classification":
312-
for split_idx in self.roc_curve["split_index"].unique():
312+
for split_idx in self.roc_curve["split_index"].cat.categories:
313313
roc_curve = self.roc_curve.query(
314314
f"label == {self.pos_label} & split_index == {split_idx}"
315315
)
@@ -338,7 +338,7 @@ def _plot_cross_validated_estimator(
338338
)
339339
else: # multiclass-classification
340340
info_pos_label = None # irrelevant for multiclass
341-
labels = self.roc_curve["label"].unique()
341+
labels = self.roc_curve["label"].cat.categories
342342
class_colors = sample_mpl_colormap(
343343
colormaps.get_cmap("tab10"), 10 if len(labels) < 10 else len(labels)
344344
)
@@ -347,7 +347,7 @@ def _plot_cross_validated_estimator(
347347
roc_auc = self.roc_auc.query(f"label == {class_label}")["roc_auc"]
348348
roc_curve_kwargs_class = roc_curve_kwargs[class_idx]
349349

350-
for split_idx in self.roc_curve["split_index"].unique():
350+
for split_idx in self.roc_curve["split_index"].cat.categories:
351351
roc_curve_label = self.roc_curve.query(
352352
f"label == {class_label} & split_index == {split_idx}"
353353
)
@@ -461,7 +461,7 @@ def _plot_comparison_estimator(
461461
)
462462
else: # multiclass-classification
463463
info_pos_label = None # irrelevant for multiclass
464-
labels = self.roc_curve["label"].unique()
464+
labels = self.roc_curve["label"].cat.categories
465465
class_colors = sample_mpl_colormap(
466466
colormaps.get_cmap("tab10"), 10 if len(labels) < 10 else len(labels)
467467
)
@@ -555,7 +555,7 @@ def _plot_comparison_cross_validation(
555555
line_kwargs: dict[str, Any] = {}
556556

557557
if self.ml_task == "binary-classification":
558-
labels = self.roc_curve["label"].unique()
558+
labels = self.roc_curve["label"].cat.categories
559559
colors = sample_mpl_colormap(
560560
colormaps.get_cmap("tab10"),
561561
10 if len(estimator_names) < 10 else len(estimator_names),
@@ -575,7 +575,9 @@ def _plot_comparison_cross_validation(
575575
line_kwargs, roc_curve_kwargs[report_idx]
576576
)
577577

578-
for split_index, segment in roc_curve.groupby("split_index"):
578+
for split_index, segment in roc_curve.groupby(
579+
"split_index", observed=True
580+
):
579581
if split_index == 0:
580582
label_kwargs = {
581583
"label": (
@@ -616,7 +618,7 @@ def _plot_comparison_cross_validation(
616618

617619
else: # multiclass-classification
618620
info_pos_label = None # irrelevant for multiclass
619-
labels = self.roc_curve["label"].unique()
621+
labels = self.roc_curve["label"].cat.categories
620622
colors = sample_mpl_colormap(
621623
colormaps.get_cmap("tab10"),
622624
10 if len(estimator_names) < 10 else len(estimator_names),
@@ -635,7 +637,9 @@ def _plot_comparison_cross_validation(
635637
f"label == {label} & estimator_name == '{estimator_name}'"
636638
)["roc_auc"]
637639

638-
for split_index, segment in roc_curve.groupby("split_index"):
640+
for split_index, segment in roc_curve.groupby(
641+
"split_index", observed=True
642+
):
639643
if split_index == 0:
640644
label_kwargs = {
641645
"label": (
@@ -740,7 +744,7 @@ def plot(
740744
self.report_type == "comparison-cross-validation"
741745
and self.ml_task == "multiclass-classification"
742746
):
743-
n_labels = len(self.roc_auc["label"].unique())
747+
n_labels = len(self.roc_auc["label"].cat.categories)
744748
self.figure_, self.ax_ = plt.subplots(ncols=n_labels)
745749
else:
746750
self.figure_, self.ax_ = plt.subplots()
@@ -762,31 +766,37 @@ def plot(
762766

763767
if self.report_type == "estimator":
764768
self.ax_, self.lines_, info_pos_label = self._plot_single_estimator(
765-
estimator_name=estimator_name or self.roc_auc["estimator_name"][0],
769+
estimator_name=(
770+
estimator_name
771+
or self.roc_auc["estimator_name"].cat.categories.item()
772+
),
766773
roc_curve_kwargs=roc_curve_kwargs,
767774
plot_chance_level=plot_chance_level,
768775
chance_level_kwargs=chance_level_kwargs,
769776
)
770777
elif self.report_type == "cross-validation":
771778
self.ax_, self.lines_, info_pos_label = (
772779
self._plot_cross_validated_estimator(
773-
estimator_name=estimator_name or self.roc_auc["estimator_name"][0],
780+
estimator_name=(
781+
estimator_name
782+
or self.roc_auc["estimator_name"].cat.categories.item()
783+
),
774784
roc_curve_kwargs=roc_curve_kwargs,
775785
plot_chance_level=plot_chance_level,
776786
chance_level_kwargs=chance_level_kwargs,
777787
)
778788
)
779789
elif self.report_type == "comparison-estimator":
780790
self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
781-
estimator_names=self.roc_auc["estimator_name"].unique(),
791+
estimator_names=self.roc_auc["estimator_name"].cat.categories,
782792
roc_curve_kwargs=roc_curve_kwargs,
783793
plot_chance_level=plot_chance_level,
784794
chance_level_kwargs=chance_level_kwargs,
785795
)
786796
elif self.report_type == "comparison-cross-validation":
787797
self.ax_, self.lines_, info_pos_label = (
788798
self._plot_comparison_cross_validation(
789-
estimator_names=self.roc_auc["estimator_name"].unique(),
799+
estimator_names=self.roc_auc["estimator_name"].cat.categories,
790800
roc_curve_kwargs=roc_curve_kwargs,
791801
plot_chance_level=plot_chance_level,
792802
chance_level_kwargs=chance_level_kwargs,
@@ -943,9 +953,15 @@ def _compute_data_for_display(
943953
}
944954
)
945955

956+
dtypes = {
957+
"estimator_name": "category",
958+
"split_index": "category",
959+
"label": "category",
960+
}
961+
946962
return cls(
947-
roc_curve=DataFrame.from_records(roc_curve_records),
948-
roc_auc=DataFrame.from_records(roc_auc_records),
963+
roc_curve=DataFrame.from_records(roc_curve_records).astype(dtypes),
964+
roc_auc=DataFrame.from_records(roc_auc_records).astype(dtypes),
949965
pos_label=pos_label_validated,
950966
data_source=data_source,
951967
ml_task=ml_task,

0 commit comments

Comments
 (0)