def _to_native(obj):
    """Recursively cast numpy values found in ``obj`` to native Python types.

    Useful for json serialization.

    Parameters
    ----------
    obj : object
        Any value. ``dict``, ``list`` and ``tuple`` containers are walked
        recursively; numpy arrays and numpy scalars are converted; every
        other object is returned unchanged.

    Returns
    -------
    object
        A new structure in which dictionaries (keys and values), lists and
        tuples are rebuilt, numpy arrays are replaced by (nested) lists and
        numpy scalars by the result of their ``item()`` method.
    """
    if isinstance(obj, dict):
        # Keys are converted as well: numpy scalars are hashable and can be
        # used as keys, but e.g. ``np.int64`` keys are refused by `json.dumps`.
        return {_to_native(k): _to_native(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_native(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_to_native(v) for v in obj)
    if isinstance(obj, np.ndarray):
        values = obj.tolist()
        # `tolist` hands back the elements of object-dtype arrays untouched,
        # so those need a second walk; plain numeric arrays are already native.
        return _to_native(values) if obj.dtype.hasobject else values
    if isinstance(obj, np.generic):
        return obj.item()
    return obj
@computed_field  # type: ignore[prop-decorator]
@cached_property
def representation(self) -> Representation:  # noqa: D102
    # Builds the table-report payload: the display summary without its
    # "dataframe" and "sample_table" entries, plus an "extract" entry made
    # from the first three rows of the dataframe.
    # NOTE(review): `switch_mpl_backend` presumably swaps the matplotlib
    # backend while the report is computed -- confirm against its definition.
    with switch_mpl_backend():
        analyze = self.report.data.analyze

        # Forward only the attributes that `analyze` actually accepts.
        accepted = signature(analyze).parameters
        kwargs = {k: v for k, v in self.attributes.items() if k in accepted}

        summary = analyze(**kwargs).summary

        # Work on a filtered copy rather than deleting keys from
        # `display.summary`: the display object is not owned by this method
        # and must stay usable after the representation has been computed.
        table_report = {
            k: v
            for k, v in summary.items()
            if k not in ("dataframe", "sample_table")
        }
        # assumes `summary["dataframe"]` is a pandas-like dataframe -- TODO confirm
        table_report["extract"] = (
            summary["dataframe"].head(3).to_dict(orient="split")
        )

        return Representation(
            media_type="application/vnd.skrub.table-report.v1+json",
            value=_to_native(table_report),
        )
@fixture(autouse=True)
def monkeypatch_table_report_representation(monkeypatch):
    """Swap `TableReport.representation` for a stub that returns an empty dict.

    Declared with ``autouse=True``, so it is applied without being requested.
    """
    target = "skore_hub_project.media.data.TableReport.representation"

    def empty_representation(self):
        return {}

    monkeypatch.setattr(target, empty_representation)
- """ - to_remove = ["dataframe", "sample_table"] - data = {k: v for k, v in self.summary.items() if k not in to_remove} - return json.dumps(data, cls=JSONEncoder) diff --git a/skore/tests/unit/displays/table_report/test_estimator.py b/skore/tests/unit/displays/table_report/test_estimator.py index 4c1223e730..442f9b193c 100644 --- a/skore/tests/unit/displays/table_report/test_estimator.py +++ b/skore/tests/unit/displays/table_report/test_estimator.py @@ -1,5 +1,3 @@ -import json - import numpy as np import pandas as pd import pytest @@ -380,12 +378,6 @@ def test_corr_plot(pyplot, estimator_report): assert display.ax_.title.get_text() == "Cramer's V Correlation" -def test_json_dump(display): - """Check the JSON serialization of the `TableReportDisplay`.""" - json_dict = json.loads(display._to_json()) - assert isinstance(json_dict, dict) - - def test_repr(display): """Check the string representation of the `TableReportDisplay`.""" repr = display.__repr__()