Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skore-hub-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ exclude_also = [
src = ["skore_hub_project"]

[tool.ruff.lint]
select = ["E", "F", "UP", "B", "C4", "SIM", "T", "I", "D"]
select = ["E", "F", "UP", "B", "C4", "SIM", "T", "I", "D", "RUF010", "RUF015"]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
Expand Down
2 changes: 1 addition & 1 deletion skore-local-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ exclude_also = [
src = ["skore_local_project"]

[tool.ruff.lint]
select = ["E", "F", "UP", "B", "C4", "SIM", "T", "I", "D"]
select = ["E", "F", "UP", "B", "C4", "SIM", "T", "I", "D", "RUF010", "RUF015"]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
Expand Down
2 changes: 1 addition & 1 deletion skore-local-project/src/skore_local_project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def ensure_project_is_not_deleted(method):
def wrapper(self, *args, **kwargs):
if self.name not in self._Project__projects_storage:
raise RuntimeError(
f"Skore could not proceed because {repr(self)} does not exist anymore."
f"Skore could not proceed because {self!r} does not exist anymore."
)

return method(self, *args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions skore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ select = [
"I",
# pydocstyle
"D",
# Ruff-specific rules
"RUF010", "RUF015",
]

[tool.ruff.lint.pydocstyle]
Expand Down
2 changes: 1 addition & 1 deletion skore/src/skore/_sklearn/_comparison/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max
)
self._cache: dict[tuple[Any, ...], Any] = {}
self._ml_task = list(self.reports_.values())[0]._ml_task # type: ignore
self._ml_task = next(iter(self.reports_.values()))._ml_task # type: ignore

def clear_cache(self) -> None:
"""Clear the cache.
Expand Down
2 changes: 1 addition & 1 deletion skore/src/skore/_sklearn/_plot/data/table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def _plot_distribution_2d(
ellide_string(s) for s in contingency_table.columns
]

if max_value := contingency_table.max(axis=None) < 100_000: # noqa: SIM108
if max_value := contingency_table.max(axis=None) < 100_000:
# avoid scientific notation for small numbers
annotation_format = (
".0f"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_data_source_external(
pd.testing.assert_index_equal(result.columns, expected_columns)

assert len(report._cache) == 1
cached_result = list(report._cache.values())[0]
cached_result = next(iter(report._cache.values()))
pd.testing.assert_index_equal(cached_result.index, expected_index)
pd.testing.assert_index_equal(cached_result.columns, expected_columns)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_binary_classification(

pos_label = 1
n_reports = len(report.reports_)
n_splits = len(list(report.reports_.values())[0].estimator_reports_)
n_splits = len(next(iter(report.reports_.values())).estimator_reports_)

display.plot()
assert isinstance(display.lines_, list)
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_multiclass_classification(

labels = display.precision_recall["label"].cat.categories
n_reports = len(report.reports_)
n_splits = len(list(report.reports_.values())[0].estimator_reports_)
n_splits = len(next(iter(report.reports_.values())).estimator_reports_)

display.plot()
assert isinstance(display.lines_, list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_multiclass_classification(
assert isinstance(display, PrecisionRecallCurveDisplay)
check_display_data(display)

class_labels = list(report.reports_.values())[0].estimator_.classes_
class_labels = next(iter(report.reports_.values())).estimator_.classes_

display.plot()
assert isinstance(display.lines_, list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_binary_classification(

pos_label = 1
n_reports = len(report.reports_)
n_splits = list(report.reports_.values())[0]._splitter.n_splits
n_splits = next(iter(report.reports_.values()))._splitter.n_splits

display.plot()
assert isinstance(display.lines_, list)
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_multiclass_classification(

labels = display.roc_curve["label"].unique()
n_reports = len(report.reports_)
n_splits = list(report.reports_.values())[0]._splitter.n_splits
n_splits = next(iter(report.reports_.values()))._splitter.n_splits

display.plot()
assert isinstance(display.lines_, list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_multiclass_classification(

assert isinstance(display, RocCurveDisplay)
check_display_data(display)
class_labels = list(report.reports_.values())[0].estimator_.classes_
class_labels = next(iter(report.reports_.values())).estimator_.classes_
assert (
list(display.roc_curve["label"].unique())
== list(display.roc_auc["label"].unique())
Expand Down
4 changes: 2 additions & 2 deletions skore/tests/unit/displays/table_report/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ def test_hue_plots_1d(pyplot, estimator_report):

def test_plot_duration_data_1d(pyplot, display):
"""Check the plot output with duration data in 1-d."""
## 1D - timedelta as x
## 1D - timedelta as x
display.plot(x="timedelta_hired")
assert display.ax_.get_xlabel() == "Years"

## 1D - timedelta as y
## 1D - timedelta as y
display.plot(y="timedelta_hired")
assert display.ax_.get_ylabel() == "Years"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_binary_classification(
):
"""Check the metrics work."""
report = comparison_estimator_reports_binary_classification
sub_report = list(report.reports_.values())[0]
sub_report = next(iter(report.reports_.values()))
X_test, y_test = sub_report.X_test, sub_report.y_test

# ensure metric is valid
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_regression(
):
"""Check the metrics work."""
comp = comparison_estimator_reports_regression
sub_report = list(comp.reports_.values())[0]
sub_report = next(iter(comp.reports_.values()))
X_test, y_test = sub_report.X_test, sub_report.y_test

# ensure metric is valid
Expand Down
2 changes: 1 addition & 1 deletion skore/tests/unit/reports/comparison/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_cross_validation_report_cleaned_up(report):
https://github.com/probabl-ai/skore/pull/1512
"""
report.metrics.summarize()
sub_report = list(report.reports_.values())[0]
sub_report = next(iter(report.reports_.values()))

with BytesIO() as stream:
joblib.dump(sub_report, stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_cache_seed_none(regression_data):
pd.testing.assert_frame_equal(importance_first_call, importance_second_call)
# the cache should contain the last result
assert len(report._cache) == 1
key = list(report._cache.keys())[0]
key = next(iter(report._cache.keys()))
pd.testing.assert_frame_equal(report._cache[key], importance_second_call)


Expand All @@ -265,7 +265,7 @@ def test_cache_seed_int(regression_data):
pd.testing.assert_frame_equal(importance_first_call, importance_second_call)
# the cache should contain the last result
assert len(report._cache) == 1
key = list(report._cache.keys())[0]
key = next(iter(report._cache.keys()))
pd.testing.assert_frame_equal(report._cache[key], importance_second_call)


Expand Down