Skip to content

Commit 73f3c52

Browse files
authored
feat(skore-hub-project): Update how table report is sent to the hub (#1999)
This PR moves `TableReportDisplay._to_json` to `skore-hub-project` where it belongs. Also plotting during serialization is done with an offscreen backend and an extract of the data is serilaized.
1 parent 3cabb19 commit 73f3c52

File tree

5 files changed

+58
-45
lines changed

5 files changed

+58
-45
lines changed

โ€Žskore-hub-project/src/skore_hub_project/media/data.pyโ€Ž

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,36 @@
22

33
from functools import cached_property
44
from inspect import signature
5-
from json import loads
65
from typing import Literal
76

7+
import numpy as np
88
from pydantic import Field, computed_field
99

10+
from skore_hub_project import switch_mpl_backend
1011
from skore_hub_project.protocol import EstimatorReport
1112

1213
from .media import Media, Representation
1314

1415

16+
def _to_native(obj):
17+
"""Walk an object and cast all numpy types to native type.
18+
19+
Useful for json serialization.
20+
"""
21+
if isinstance(obj, dict):
22+
return {k: _to_native(v) for k, v in obj.items()}
23+
elif isinstance(obj, list):
24+
return [_to_native(v) for v in obj]
25+
elif isinstance(obj, tuple):
26+
return tuple(_to_native(v) for v in obj)
27+
elif isinstance(obj, np.ndarray):
28+
return obj.tolist()
29+
elif isinstance(obj, np.generic):
30+
return obj.item()
31+
else:
32+
return obj
33+
34+
1535
class TableReport(Media): # noqa: D101
1636
report: EstimatorReport = Field(repr=False, exclude=True)
1737
key: str = "table_report"
@@ -21,20 +41,27 @@ class TableReport(Media): # noqa: D101
2141
@computed_field # type: ignore[prop-decorator]
2242
@cached_property
2343
def representation(self) -> Representation: # noqa: D102
24-
function = self.report.data.analyze
25-
function_parameters = signature(function).parameters
26-
function_kwargs = {
27-
k: v for k, v in self.attributes.items() if k in function_parameters
28-
}
29-
30-
table_report_display = function(**function_kwargs)
31-
table_report_json_str = table_report_display._to_json()
32-
table_report = loads(table_report_json_str)
33-
34-
return Representation(
35-
media_type="application/vnd.skrub.table-report.v1+json",
36-
value=table_report,
37-
)
44+
with switch_mpl_backend():
45+
function = self.report.data.analyze
46+
function_parameters = signature(function).parameters
47+
function_kwargs = {
48+
k: v for k, v in self.attributes.items() if k in function_parameters
49+
}
50+
51+
table_report_display = function(**function_kwargs)
52+
table_report = table_report_display.summary
53+
54+
table_report["extract"] = (
55+
table_report["dataframe"].head(3).to_dict(orient="split")
56+
)
57+
58+
del table_report["dataframe"]
59+
del table_report["sample_table"]
60+
61+
return Representation(
62+
media_type="application/vnd.skrub.table-report.v1+json",
63+
value=_to_native(table_report),
64+
)
3865

3966

4067
class TableReportTrain(TableReport): # noqa: D101

โ€Žskore-hub-project/tests/unit/media/test_data.pyโ€Ž

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
from pydantic import ValidationError
2-
from pytest import fixture, mark, param, raises
2+
from pytest import mark, param, raises
33
from skore_hub_project.media import TableReportTest, TableReportTrain
44

55

6-
@fixture(autouse=True)
7-
def monkeypatch_to_json(monkeypatch):
8-
monkeypatch.setattr(
9-
"skore._sklearn._plot.TableReportDisplay._to_json", lambda self: "[0,1]"
10-
)
11-
12-
136
@mark.parametrize(
147
"Media,data_source",
158
(
@@ -21,6 +14,8 @@ def test_table_report(binary_classification, Media, data_source):
2114
media = Media(report=binary_classification)
2215
media_dict = media.model_dump()
2316

17+
representation_value = media_dict["representation"]["value"]
18+
media_dict["representation"]["value"] = {}
2419
assert media_dict == {
2520
"key": "table_report",
2621
"verbose_name": "Table report",
@@ -29,9 +24,19 @@ def test_table_report(binary_classification, Media, data_source):
2924
"parameters": {},
3025
"representation": {
3126
"media_type": "application/vnd.skrub.table-report.v1+json",
32-
"value": [0, 1],
27+
"value": {},
3328
},
3429
}
30+
assert set(
31+
[
32+
"n_rows",
33+
"n_columns",
34+
"n_constant_columns",
35+
"extract",
36+
"columns",
37+
"top_associations",
38+
]
39+
).issubset(representation_value.keys())
3540

3641
# wrong type
3742
with raises(ValidationError):

โ€Žskore-hub-project/tests/unit/project/test_project.pyโ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def monkeypatch_permutation(monkeypatch):
7171

7272

7373
@fixture(autouse=True)
74-
def monkeypatch_to_json(monkeypatch):
74+
def monkeypatch_table_report_representation(monkeypatch):
7575
monkeypatch.setattr(
76-
"skore._sklearn._plot.TableReportDisplay._to_json", lambda self: "[0,1]"
76+
"skore_hub_project.media.data.TableReport.representation", lambda self: {}
7777
)
7878

7979

โ€Žskore/src/skore/_sklearn/_plot/data/table_report.pyโ€Ž

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from typing import Any, Literal
32

43
import numpy as np
@@ -11,7 +10,6 @@
1110
from skrub._reporting._html import to_html
1211
from skrub._reporting._summarize import summarize_dataframe
1312
from skrub._reporting._utils import (
14-
JSONEncoder,
1513
duration_to_numeric,
1614
ellide_string,
1715
top_k_value_counts,
@@ -689,12 +687,3 @@ def _html_repr(self) -> str:
689687

690688
def __repr__(self) -> str:
691689
return f"<{self.__class__.__name__}(...)>"
692-
693-
def _to_json(self) -> str:
694-
"""Serialize the data of this report in JSON format.
695-
696-
It is the serialization chosen to be sent to skore-hub.
697-
"""
698-
to_remove = ["dataframe", "sample_table"]
699-
data = {k: v for k, v in self.summary.items() if k not in to_remove}
700-
return json.dumps(data, cls=JSONEncoder)

โ€Žskore/tests/unit/displays/table_report/test_estimator.pyโ€Ž

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import json
2-
31
import numpy as np
42
import pandas as pd
53
import pytest
@@ -380,12 +378,6 @@ def test_corr_plot(pyplot, estimator_report):
380378
assert display.ax_.title.get_text() == "Cramer's V Correlation"
381379

382380

383-
def test_json_dump(display):
384-
"""Check the JSON serialization of the `TableReportDisplay`."""
385-
json_dict = json.loads(display._to_json())
386-
assert isinstance(json_dict, dict)
387-
388-
389381
def test_repr(display):
390382
"""Check the string representation of the `TableReportDisplay`."""
391383
repr = display.__repr__()

0 commit comments

Comments
ย (0)