Skip to content

plots: support transposed data #8561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions dvc/render/converter/vega.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def _lists(dictionary: Dict):
yield value


def _dicts(dictionary: Dict):
"""Walk through a dictionary and yield all dicts to be able to extract
nested dicts."""
yield dictionary
for _, value in dictionary.items():
if isinstance(value, dict):
yield from _dicts(value)


def _find_first_list(
data: Union[Dict, List], fields: Set, **kwargs
) -> List[Dict]:
Expand All @@ -67,6 +76,35 @@ def _find_first_list(
if not isinstance(data, dict):
return data

# Check whether to transpose the data. Otherwise, we assume that the
# data is stored as list of dicts. Transposed data is when we have a
# dict with fields as keys and the values are lists of numerals.
for dt in _dicts(data):
matched_fields = set(dt.keys())
if not fields <= matched_fields:
continue
# If fields is non-empty, we check only the fields we care for.
# In that case, we check that all fields are there.
if fields:
matched_fields = fields

if (
# Check that all fields are lists.
all(isinstance(dt.get(field), list) for field in matched_fields)
# Check that all list elements are numerals.
and all(
all(isinstance(x, (int, float)) for x in dt.get(field))
for field in matched_fields
)
# Check that all lists for the fields have the same length.
and len({len(dt.get(field)) for field in matched_fields}) == 1
):
# Transpose the data for fields.
return [
{field: dt[field][i] for field in matched_fields}
for i in range(len(dt[first(matched_fields)]))
]

for lst in _lists(data):
if (
all(isinstance(dp, dict) for dp in lst)
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/render/test_vega_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ def test_find_first_list_in_dict():
_find_first_list(dmetric, fields={"foo"})


def test_find_first_list_in_dict_transposed():
    """Check that dicts of equal-length numeric lists come back as rows."""
    # Two column-oriented metric dicts, nested under separate top-level keys.
    columns_t1 = {"accuracy": [1, 3], "loss": [2, 4]}
    columns_t2 = {"x": [1, 2]}
    dmetric = OrderedDict()
    dmetric["t1"] = columns_t1
    dmetric["t2"] = columns_t2

    # The same data expressed as one dict per data point.
    rows_t1 = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}]
    rows_t2 = [{"x": 1}, {"x": 2}]

    # No requested fields: the result is the transposed "t1" data.
    no_fields_result = _find_first_list(dmetric, fields=set())
    assert no_fields_result == rows_t1

    # Requesting "x" selects the nested dict that has that key.
    x_result = _find_first_list(dmetric, fields={"x"})
    assert x_result == rows_t2

    # A field that appears nowhere is reported as a structure error.
    with pytest.raises(PlotDataStructureError):
        _find_first_list(dmetric, fields={"foo"})


def test_filter_fields():
m = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}]

Expand Down