diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py b/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py index 9c094564e..a05706310 100644 --- a/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py +++ b/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py @@ -5,7 +5,6 @@ from functools import reduce from typing import ClassVar, Literal, cast -from pandas import DataFrame from pydantic import Field from skore_hub_project.artifact.media.media import Media @@ -33,7 +32,7 @@ def content_to_upload(self) -> bytes | None: # noqa: D102 else function(data_source=self.data_source) ) - if not isinstance(result, DataFrame): + if hasattr(result, "frame"): result = result.frame() return orjson.dumps( diff --git a/skore-hub-project/src/skore_hub_project/metric/timing.py b/skore-hub-project/src/skore_hub_project/metric/timing.py index 462932ac7..921ab7ce9 100644 --- a/skore-hub-project/src/skore_hub_project/metric/timing.py +++ b/skore-hub-project/src/skore_hub_project/metric/timing.py @@ -5,7 +5,6 @@ from functools import cached_property from typing import ClassVar, Literal -from pandas import DataFrame, Series from pydantic import Field, computed_field from skore_hub_project.protocol import CrossValidationReport, EstimatorReport @@ -44,10 +43,10 @@ class FitTimeAggregate(Metric): # noqa: D101 @computed_field # type: ignore[prop-decorator] @cached_property def value(self) -> float | None: # noqa: D102 - timings: DataFrame = self.report.metrics.timings(aggregate=self.aggregate) + timings = self.report.metrics.timings(aggregate=self.aggregate) try: - fit_times: Series = timings.loc["Fit time (s)"] + fit_times = timings.loc["Fit time (s)"] except KeyError: return None @@ -106,10 +105,10 @@ class PredictTimeAggregate(Metric): # noqa: D101 @computed_field # type: ignore[prop-decorator] @cached_property def value(self) -> float | None: # noqa: D102 - timings: DataFrame = self.report.metrics.timings(aggregate=self.aggregate) + timings = self.report.metrics.timings(aggregate=self.aggregate) try: - predict_times: Series = timings.loc[f"Predict time {self.data_source} (s)"] + predict_times = timings.loc[f"Predict time {self.data_source} (s)"] except KeyError: return None diff --git a/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py b/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py index ae7fa4f9f..63a02ad2c 100644 --- a/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py +++ b/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py @@ -1,6 +1,5 @@ from functools import partialmethod -from pandas import DataFrame from pydantic import ValidationError from pytest import fixture, mark, param, raises from skore_hub_project import Project @@ -16,7 +15,7 @@ def serialize(result) -> bytes: import orjson - if not isinstance(result, DataFrame): + if hasattr(result, "frame"): result = result.frame() return orjson.dumps(