Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import reduce
from typing import ClassVar, Literal, cast

from pandas import DataFrame
from pydantic import Field

from skore_hub_project.artifact.media.media import Media
Expand Down Expand Up @@ -33,7 +32,7 @@ def content_to_upload(self) -> bytes | None: # noqa: D102
else function(data_source=self.data_source)
)

if not isinstance(result, DataFrame):
if hasattr(result, "frame"):
result = result.frame()

return orjson.dumps(
Expand Down
9 changes: 4 additions & 5 deletions skore-hub-project/src/skore_hub_project/metric/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import cached_property
from typing import ClassVar, Literal

from pandas import DataFrame, Series
from pydantic import Field, computed_field

from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
Expand Down Expand Up @@ -44,10 +43,10 @@ class FitTimeAggregate(Metric): # noqa: D101
@computed_field # type: ignore[prop-decorator]
@cached_property
def value(self) -> float | None: # noqa: D102
timings: DataFrame = self.report.metrics.timings(aggregate=self.aggregate)
timings = self.report.metrics.timings(aggregate=self.aggregate)

try:
fit_times: Series = timings.loc["Fit time (s)"]
fit_times = timings.loc["Fit time (s)"]
except KeyError:
return None

Expand Down Expand Up @@ -106,10 +105,10 @@ class PredictTimeAggregate(Metric): # noqa: D101
@computed_field # type: ignore[prop-decorator]
@cached_property
def value(self) -> float | None: # noqa: D102
timings: DataFrame = self.report.metrics.timings(aggregate=self.aggregate)
timings = self.report.metrics.timings(aggregate=self.aggregate)

try:
predict_times: Series = timings.loc[f"Predict time {self.data_source} (s)"]
predict_times = timings.loc[f"Predict time {self.data_source} (s)"]
except KeyError:
return None

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from functools import partialmethod

from pandas import DataFrame
from pydantic import ValidationError
from pytest import fixture, mark, param, raises
from skore_hub_project import Project
Expand All @@ -16,7 +15,7 @@
def serialize(result) -> bytes:
import orjson

if not isinstance(result, DataFrame):
if hasattr(result, "frame"):
result = result.frame()

return orjson.dumps(
Expand Down