Skip to content

Commit 883cc6c

Browse files
refactor(progress_decorator): Make sure self_obj is a Report (#1536)
This originates from a bug when implementing comparison of CrossValidationReport in #1512 Because CrossValidationReports hold a `_parent_progress`, when a ComparisonReport creates a progress bar to iterate over CrossValidationReports, the outer progress bar conflicts with the inner progress bars, and rich refuses to proceed. The solution is for the ComparisonReport to explicitly set its inner CrossValidationReports' progress instance, so that in total there is only one progress instance. But before this change, the progress instance was sometimes owned by a `CrossValidationReport.metrics` accessor. This is a problem because accessors are re-instantiated whenever they are accessed, so their state cannot be modified from the parent. The solution this change implements is to remove all `progress`-related attributes from all accessors, and to ensure that the progress instance is only owned by the Report object, not by any of its accessors.
1 parent f3922e7 commit 883cc6c

File tree

3 files changed

+16
-20
lines changed

3 files changed

+16
-20
lines changed

skore/src/skore/sklearn/_comparison/metrics_accessor.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import pandas as pd
66
from numpy.typing import ArrayLike
7-
from rich.progress import Progress
87
from sklearn.metrics import make_scorer
98
from sklearn.metrics._scorer import _BaseScorer as SKLearnScorer
109
from sklearn.utils.metaestimators import available_if
@@ -48,9 +47,6 @@ class _MetricsAccessor(_BaseAccessor, DirNamesMixin):
4847
def __init__(self, parent: ComparisonReport) -> None:
4948
super().__init__(parent)
5049

51-
self._progress_info: Optional[dict[str, Any]] = None
52-
self._parent_progress: Optional[Progress] = None
53-
5450
def report_metrics(
5551
self,
5652
*,
@@ -198,9 +194,9 @@ def _compute_metric_scores(
198194

199195
cache_key = tuple(cache_key_parts)
200196

201-
assert self._progress_info is not None, "Progress info not set"
202-
progress = self._progress_info["current_progress"]
203-
main_task = self._progress_info["current_task"]
197+
assert self._parent._progress_info is not None, "Progress info not set"
198+
progress = self._parent._progress_info["current_progress"]
199+
main_task = self._parent._progress_info["current_task"]
204200

205201
total_estimators = len(self._parent.estimator_reports_)
206202
progress.update(main_task, total=total_estimators)
@@ -1272,9 +1268,9 @@ def _get_display(
12721268
cache_key_parts.append(data_source)
12731269
cache_key = tuple(cache_key_parts)
12741270

1275-
assert self._progress_info is not None, "Progress info not set"
1276-
progress = self._progress_info["current_progress"]
1277-
main_task = self._progress_info["current_task"]
1271+
assert self._parent._progress_info is not None, "Progress info not set"
1272+
progress = self._parent._progress_info["current_progress"]
1273+
main_task = self._parent._progress_info["current_task"]
12781274
total_estimators = len(self._parent.estimator_reports_)
12791275
progress.update(main_task, total=total_estimators)
12801276

skore/src/skore/sklearn/_cross_validation/metrics_accessor.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import joblib
44
import numpy as np
55
import pandas as pd
6-
from rich.progress import Progress
76
from sklearn.metrics import make_scorer
87
from sklearn.metrics._scorer import _BaseScorer as SKLearnScorer
98
from sklearn.utils.metaestimators import available_if
@@ -48,9 +47,6 @@ class _MetricsAccessor(_BaseAccessor["CrossValidationReport"], DirNamesMixin):
4847
def __init__(self, parent: CrossValidationReport) -> None:
4948
super().__init__(parent)
5049

51-
self._progress_info: Optional[dict[str, Any]] = None
52-
self._parent_progress: Optional[Progress] = None
53-
5450
def report_metrics(
5551
self,
5652
*,
@@ -173,9 +169,9 @@ def _compute_metric_scores(
173169
cache_key_parts.append(metric_kwargs[key])
174170
cache_key = tuple(cache_key_parts)
175171

176-
assert self._progress_info is not None, "Progress info not set"
177-
progress = self._progress_info["current_progress"]
178-
main_task = self._progress_info["current_task"]
172+
assert self._parent._progress_info is not None, "Progress info not set"
173+
progress = self._parent._progress_info["current_progress"]
174+
main_task = self._parent._progress_info["current_task"]
179175

180176
total_estimators = len(self._parent.estimator_reports_)
181177
progress.update(main_task, total=total_estimators)
@@ -979,9 +975,9 @@ def _get_display(
979975
cache_key_parts.append(data_source)
980976
cache_key = tuple(cache_key_parts)
981977

982-
assert self._progress_info is not None, "Progress info not set"
983-
progress = self._progress_info["current_progress"]
984-
main_task = self._progress_info["current_task"]
978+
assert self._parent._progress_info is not None, "Progress info not set"
979+
progress = self._parent._progress_info["current_progress"]
980+
main_task = self._parent._progress_info["current_task"]
985981
total_estimators = len(self._parent.estimator_reports_)
986982
progress.update(main_task, total=total_estimators)
987983

skore/src/skore/utils/_progress_bar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def decorator(func: Callable[..., T]) -> Callable[..., T]:
3939
def wrapper(*args: Any, **kwargs: Any) -> T:
4040
self_obj: Any = args[0]
4141

42+
if hasattr(self_obj, "_parent"):
43+
# self_obj is an accessor
44+
self_obj = self_obj._parent
45+
4246
desc = description(self_obj) if callable(description) else description
4347

4448
if getattr(self_obj, "_parent_progress", None) is not None:

0 commit comments

Comments
 (0)