Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions skore-hub-project/src/skore_hub_project/media/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,36 @@

from functools import cached_property
from inspect import signature
from json import loads
from typing import Literal

import numpy as np
from pydantic import Field, computed_field

from skore_hub_project import switch_mpl_backend
from skore_hub_project.protocol import EstimatorReport

from .media import Media, Representation


def _to_native(obj):
"""Walk an object and cast all numpy types to native type.

Useful for json serialization.
"""
if isinstance(obj, dict):
return {k: _to_native(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_to_native(v) for v in obj]
elif isinstance(obj, tuple):
return tuple(_to_native(v) for v in obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.generic):
return obj.item()
else:
return obj


class TableReport(Media): # noqa: D101
report: EstimatorReport = Field(repr=False, exclude=True)
key: str = "table_report"
Expand All @@ -21,20 +41,27 @@ class TableReport(Media): # noqa: D101
@computed_field # type: ignore[prop-decorator]
@cached_property
def representation(self) -> Representation: # noqa: D102
function = self.report.data.analyze
function_parameters = signature(function).parameters
function_kwargs = {
k: v for k, v in self.attributes.items() if k in function_parameters
}

table_report_display = function(**function_kwargs)
table_report_json_str = table_report_display._to_json()
table_report = loads(table_report_json_str)

return Representation(
media_type="application/vnd.skrub.table-report.v1+json",
value=table_report,
)
with switch_mpl_backend():
function = self.report.data.analyze
function_parameters = signature(function).parameters
function_kwargs = {
k: v for k, v in self.attributes.items() if k in function_parameters
}

table_report_display = function(**function_kwargs)
table_report = table_report_display.summary

table_report["extract"] = (
table_report["dataframe"].head(3).to_dict(orient="split")
)

del table_report["dataframe"]
del table_report["sample_table"]

return Representation(
media_type="application/vnd.skrub.table-report.v1+json",
value=_to_native(table_report),
)


class TableReportTrain(TableReport): # noqa: D101
Expand Down
23 changes: 14 additions & 9 deletions skore-hub-project/tests/unit/media/test_data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from pydantic import ValidationError
from pytest import fixture, mark, param, raises
from pytest import mark, param, raises
from skore_hub_project.media import TableReportTest, TableReportTrain


@fixture(autouse=True)
def monkeypatch_to_json(monkeypatch):
monkeypatch.setattr(
"skore._sklearn._plot.TableReportDisplay._to_json", lambda self: "[0,1]"
)


@mark.parametrize(
"Media,data_source",
(
Expand All @@ -21,6 +14,8 @@ def test_table_report(binary_classification, Media, data_source):
media = Media(report=binary_classification)
media_dict = media.model_dump()

representation_value = media_dict["representation"]["value"]
media_dict["representation"]["value"] = {}
assert media_dict == {
"key": "table_report",
"verbose_name": "Table report",
Expand All @@ -29,9 +24,19 @@ def test_table_report(binary_classification, Media, data_source):
"parameters": {},
"representation": {
"media_type": "application/vnd.skrub.table-report.v1+json",
"value": [0, 1],
"value": {},
},
}
assert set(
[
"n_rows",
"n_columns",
"n_constant_columns",
"extract",
"columns",
"top_associations",
]
).issubset(representation_value.keys())

# wrong type
with raises(ValidationError):
Expand Down
4 changes: 2 additions & 2 deletions skore-hub-project/tests/unit/project/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def monkeypatch_permutation(monkeypatch):


@fixture(autouse=True)
def monkeypatch_to_json(monkeypatch):
def monkeypatch_table_report_representation(monkeypatch):
monkeypatch.setattr(
"skore._sklearn._plot.TableReportDisplay._to_json", lambda self: "[0,1]"
"skore_hub_project.media.data.TableReport.representation", lambda self: {}
)


Expand Down
11 changes: 0 additions & 11 deletions skore/src/skore/_sklearn/_plot/data/table_report.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import Any, Literal

import numpy as np
Expand All @@ -11,7 +10,6 @@
from skrub._reporting._html import to_html
from skrub._reporting._summarize import summarize_dataframe
from skrub._reporting._utils import (
JSONEncoder,
duration_to_numeric,
ellide_string,
top_k_value_counts,
Expand Down Expand Up @@ -689,12 +687,3 @@ def _html_repr(self) -> str:

def __repr__(self) -> str:
return f"<{self.__class__.__name__}(...)>"

def _to_json(self) -> str:
"""Serialize the data of this report in JSON format.

It is the serialization chosen to be sent to skore-hub.
"""
to_remove = ["dataframe", "sample_table"]
data = {k: v for k, v in self.summary.items() if k not in to_remove}
return json.dumps(data, cls=JSONEncoder)
8 changes: 0 additions & 8 deletions skore/tests/unit/displays/table_report/test_estimator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -380,12 +378,6 @@ def test_corr_plot(pyplot, estimator_report):
assert display.ax_.title.get_text() == "Cramer's V Correlation"


def test_json_dump(display):
"""Check the JSON serialization of the `TableReportDisplay`."""
json_dict = json.loads(display._to_json())
assert isinstance(json_dict, dict)


def test_repr(display):
"""Check the string representation of the `TableReportDisplay`."""
repr = display.__repr__()
Expand Down
Loading