Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from functools import reduce
from typing import ClassVar, Literal, cast

import orjson

from skore_hub_project.artifact.media.media import Media, Report
from skore_hub_project.protocol import EstimatorReport

Expand All @@ -14,8 +16,6 @@ class FeatureImportance(Media[Report], ABC): # noqa: D101
content_type: Literal["application/vnd.dataframe"] = "application/vnd.dataframe"

def content_to_upload(self) -> bytes | None: # noqa: D102
import orjson

try:
function = cast(
Callable,
Expand All @@ -24,11 +24,7 @@ def content_to_upload(self) -> bytes | None: # noqa: D102
except AttributeError:
return None

result = (
function()
if self.data_source is None
else function(data_source=self.data_source)
)
result = function()

if hasattr(result, "frame"):
result = result.frame()
Expand All @@ -43,6 +39,29 @@ class Permutation(FeatureImportance[EstimatorReport], ABC): # noqa: D101
accessor: ClassVar[str] = "feature_importance.permutation"
name: Literal["permutation"] = "permutation"

def content_to_upload(self) -> bytes | None: # noqa: D102
for key, obj in reversed(self.report._cache.items()):
if len(key) < 7:
continue

if len(key) == 7:
parent_hash, metric, data_source, scoring, *_ = key
else:
parent_hash, metric, data_source, data_source_hash, scoring, *_ = key

if (
parent_hash == self.report._hash
and metric == "permutation_importance"
and data_source == self.data_source
and scoring is None
):
return orjson.dumps(
obj.fillna("NaN").to_dict(orient="tight"),
option=(orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY),
)

return None


class PermutationTrain(Permutation): # noqa: D101
data_source: Literal["train"] = "train"
Expand Down
2 changes: 2 additions & 0 deletions skore-hub-project/src/skore_hub_project/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class EstimatorReport(Protocol):
"""Protocol equivalent to ``skore.EstimatorReport``."""

_hash: int
clear_cache: Any
_cache: Any
metrics: Any
Expand All @@ -25,6 +26,7 @@ class EstimatorReport(Protocol):
class CrossValidationReport(Protocol):
"""Protocol equivalent to ``skore.CrossValidationReport``."""

_hash: int
clear_cache: Any
_cache: Any
metrics: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def test_feature_importance(
}

# unavailable accessor
report.clear_cache()
monkeypatch.delattr(report.feature_importance.__class__, accessor)
upload_mock.reset_mock()

Expand Down
4 changes: 4 additions & 0 deletions skore-hub-project/tests/unit/report/test_estimator_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def project():

@fixture
def payload(project, binary_classification):
# Force the compute of the permutations
binary_classification.feature_importance.permutation(data_source="train", seed=42)
binary_classification.feature_importance.permutation(data_source="test", seed=42)

return EstimatorReportPayload(
project=project,
report=binary_classification,
Expand Down