Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ six==1.17.0
# via python-dateutil
skops==0.13.0
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore==0.10.1
skore==0.10.2
# via skore-hub-project (skore-hub-project/pyproject.toml)
skore-local-project==0.0.3
# via skore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from inspect import signature
from typing import ClassVar, Literal, cast

from pandas import DataFrame
from pydantic import Field, computed_field
from skore import EstimatorReport
from skore import CrossValidationReport, EstimatorReport

from .media import Media, Representation


class FeatureImportance(Media): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
report: EstimatorReport | CrossValidationReport = Field(repr=False, exclude=True)
accessor: ClassVar[str]
category: Literal["feature_importance"] = "feature_importance"

Expand All @@ -32,15 +33,18 @@ def representation(self) -> Representation | None: # noqa: D102
k: v for k, v in self.attributes.items() if k in function_parameters
}

dataframe = function(**function_kwargs)
result = function(**function_kwargs)

if not isinstance(result, DataFrame):
result = result.frame()

serialized = result.fillna("NaN").to_dict(orient="tight")

return Representation(
media_type="application/vnd.dataframe",
value=dataframe.fillna("NaN").to_dict(orient="tight"),
)
return Representation(media_type="application/vnd.dataframe", value=serialized)


class Permutation(FeatureImportance): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
accessor: ClassVar[str] = "feature_importance.permutation"
key: str = "permutation"
verbose_name: str = "Feature importance - Permutation"
Expand All @@ -55,6 +59,7 @@ class PermutationTest(Permutation): # noqa: D101


class MeanDecreaseImpurity(FeatureImportance): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
accessor: ClassVar[str] = "feature_importance.mean_decrease_impurity"
key: str = "mean_decrease_impurity"
verbose_name: str = "Feature importance - Mean Decrease Impurity (MDI)"
Expand Down
24 changes: 17 additions & 7 deletions skore-hub-project/tests/unit/media/test_feature_importance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partialmethod
from inspect import signature

from pandas import DataFrame
from pydantic import ValidationError
from pytest import fixture, mark, param, raises
from skore_hub_project.media import (
Expand All @@ -11,10 +12,6 @@
)


def serialize(dataframe):
return dataframe.fillna("NaN").to_dict(orient="tight")


@fixture(autouse=True)
def monkeypatch_permutation(monkeypatch):
import skore
Expand Down Expand Up @@ -63,6 +60,14 @@ def monkeypatch_permutation(monkeypatch):
{"method": "coefficients"},
id="Coefficients",
),
param(
Coefficients,
"cv_regression",
"coefficients",
"Feature importance - Coefficients",
{"method": "coefficients"},
id="Coefficients",
),
),
)
def test_feature_importance(
Expand All @@ -79,8 +84,13 @@ def test_feature_importance(
function = getattr(report.feature_importance, accessor)
function_parameters = signature(function).parameters
function_kwargs = {k: v for k, v in attributes.items() if k in function_parameters}
dataframe = function(**function_kwargs)
dataframe_serialized = serialize(dataframe)

result = function(**function_kwargs)

if not isinstance(result, DataFrame):
result = result.frame()

serialized = result.fillna("NaN").to_dict(orient="tight")

# available accessor
assert Media(report=report).model_dump() == {
Expand All @@ -91,7 +101,7 @@ def test_feature_importance(
"parameters": {},
"representation": {
"media_type": "application/vnd.dataframe",
"value": dataframe_serialized,
"value": serialized,
},
}

Expand Down
Loading