WIP: Add kwargs to InferenceData.to_netcdf() #2410

Open
wants to merge 3 commits into base: main
37 changes: 34 additions & 3 deletions arviz/data/inference_data.py
@@ -462,6 +462,7 @@ def to_netcdf(
engine: str = "h5netcdf",
base_group: str = "/",
overwrite_existing: bool = True,
**kwargs
) -> str:
"""Write InferenceData to netcdf4 file.

@@ -481,7 +482,11 @@ def to_netcdf(
By default, will write to the root of the netCDF file
overwrite_existing : bool, default True
Whether to overwrite the existing file or append to it.


Other keyword arguments will be passed to `xarray.Dataset.to_netcdf()`. If
provided these will serve to override dict items that relate to `compress` and
`engine` parameters described above.
Comment on lines +486 to +488
Member

This would be the parameter description according to numpydoc, so indented under a **kwargs parameter with no type: https://numpydoc.readthedocs.io/en/latest/format.html#parameters (the last paragraph of this section).

Also, if you use :meth:`xarray.Dataset.to_netcdf` or even `xarray.Dataset.to_netcdf` (given our sphinx configuration) it will be rendered as a link to the respective docs in the xarray website. You can check the rendered docstring preview from your PR at https://arviz--2410.org.readthedocs.build/en/2410/api/generated/arviz.InferenceData.to_netcdf.html
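As a sketch of the numpydoc layout the comment points to (the signature below is hypothetical and trimmed, not the actual arviz one):

```python
# Hypothetical, trimmed signature for illustration only; the point is that
# **kwargs is listed as a parameter with no type, and its description is
# indented underneath it, per the numpydoc style guide.
def to_netcdf(filename, engine="h5netcdf", **kwargs):
    """Write InferenceData to netcdf4 file.

    Parameters
    ----------
    filename : str
        Location to write to.
    engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
        Library used to write the netCDF file.
    **kwargs
        Passed to :meth:`xarray.Dataset.to_netcdf`. Entries given here
        override the ``engine`` and ``encoding`` values derived from the
        parameters above.
    """
    return filename  # stub body so the sketch is runnable
```

With the arviz sphinx configuration, the `:meth:` role in the `**kwargs` description renders as a link to the xarray docs.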


Returns
-------
str
@@ -501,6 +506,15 @@
)
mode = "a"

# add items to kwargs corresponding directly to parameters of this method
kwargs["engine"] = engine

# get encoding dict that may have been passed in
try:
encoding_kw2 = kwargs["encoding"]
except KeyError:
encoding_kw2 = {}
Comment on lines +513 to +516
Member

we generally use .get for these kinds of operations: encoding_kw2 = kwargs.get("encoding", {})


if self._groups_all: # check's whether a group is present or not.
if groups is None:
groups = self._groups_all
@@ -509,13 +523,30 @@

for group in groups:
data = getattr(self, group)
kwargs = {"engine": engine}

# define encoding kwargs according to compress
# but only for compressible dtypes
if compress:
kwargs["encoding"] = {
encoding_kw1 = {
var_name: {"zlib": True}
for var_name, values in data.variables.items()
if _compressible_dtype(values.dtype)
}
else:
encoding_kw1 = {}

# merge the two dicts-of-dicts
encoding_kw_merged = {}
for var_name, kw1 in encoding_kw1.items():
try:
kw2 = encoding_kw2[var_name]
except KeyError:
kw2 = {}
encoding_kw_merged[var_name] = {**kw1,**kw2}
Comment on lines +540 to +545
Member

I think the logic here would not work as expected when there are non-compressible types. My line of thought / duck debugging:

  1. encoding_kw2 is full and has elements for all variables
  2. encoding_kw_merged is empty
  3. We loop only over the variable names in encoding_kw1 which will only contain compressible variables. Then for each of these variables only:
    • We merge the respective variable specifics encoding_kw1 and encoding_kw2
  4. encoding_kw_merged has the same keys as encoding_kw1 and the merged dicts as values.
    • If there were no compressible variables, encoding_kw_merged would be empty even with encoding_kw2 being full

Potential proposal:

Suggested change
for var_name, kw1 in encoding_kw1.items():
try:
kw2 = encoding_kw2[var_name]
except KeyError:
kw2 = {}
encoding_kw_merged[var_name] = {**kw1,**kw2}
for var_name in data.data_vars:
kw1 = encoding_kw1.get(var_name, {})
kw2 = encoding_kw2.get(var_name, {})
encoding_kw_merged[var_name] = kw1 | kw2
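The failure mode described above can be reproduced with plain dicts; the variable names and encoding entries below are made up for illustration:

```python
# Hypothetical setup: "mu" is a compressible float variable, "idx" is a
# non-compressible (e.g. integer) variable that only the user's **kwargs
# encoding mentions.
data_vars = ["mu", "idx"]
encoding_kw1 = {"mu": {"zlib": True}}       # derived from compress=True
encoding_kw2 = {"idx": {"dtype": "int32"}}  # user-supplied via **kwargs

# PR version: iterates only over encoding_kw1, so "idx" is silently dropped.
merged_pr = {}
for var_name, kw1 in encoding_kw1.items():
    kw2 = encoding_kw2.get(var_name, {})
    merged_pr[var_name] = {**kw1, **kw2}

# Reviewer's proposal: iterate over all data variables instead.
# (dict union with `|` requires Python 3.9+.)
merged_fix = {}
for var_name in data_vars:
    kw1 = encoding_kw1.get(var_name, {})
    kw2 = encoding_kw2.get(var_name, {})
    merged_fix[var_name] = kw1 | kw2

assert "idx" not in merged_pr                       # user's encoding lost
assert merged_fix["idx"] == {"dtype": "int32"}      # fix keeps it
```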

# note: entries passed in via kwargs will overwrite
# those that may have been created due to other parameters
kwargs["encoding"] = encoding_kw_merged

data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
data.close()
mode = "a"
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_data.py
@@ -1483,6 +1483,15 @@ def test_empty_inference_data_object(self):
os.remove(filepath)
assert not os.path.exists(filepath)

def test_to_netcdf_kwargs(self):
"""Tests to verify that passing kwargs to `InferenceData.to_netcdf()`
works as intended"""
True # TODO
# 1) define an InferenceData object (e.g. from file)
# 2) define different sets of `**kwargs` to pass
# 3) use inference_data.to_netcdf(filepath,**kwargs)
# 4) test these make it through to `data.to_netcdf()` as intended - TODO how?
Comment on lines +1490 to +1493
Member

I think what you propose is about right. Pseudocode idea:

idata = load...
# store with encoding kwargs that mean small but non-negligible loss of precision
# and as previous test, check requested filename exists
idata_encoded = load...
for group in idata.groups:
    # use https://docs.xarray.dev/en/stable/generated/xarray.testing.assert_allclose.html#xarray.testing.assert_allclose
    # once as
    with pytest.raises(AssertionError):
        assert_allclose(..., tol=low/default)
    # then again as
    assert_allclose(..., tol=high)
# clean up files
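The two-tolerance assertion pattern in the pseudocode above can be sketched stdlib-only; the `lossy_roundtrip` helper below is a made-up stand-in for a lossy netCDF encoding, not arviz or xarray API:

```python
import math

def lossy_roundtrip(x, scale=1e-3):
    # Made-up stand-in for a lossy netCDF encoding: pack x as an integer
    # multiple of `scale`, then decode it back, losing some precision.
    return round(x / scale) * scale

original = 0.123456789
decoded = lossy_roundtrip(original)

# Tight tolerance: the round-trip is NOT exact, so closeness fails here...
assert not math.isclose(original, decoded, abs_tol=1e-9)
# ...but a loose tolerance accepts the small, expected loss of precision.
assert math.isclose(original, decoded, abs_tol=1e-2)
```

In the real test, `xarray.testing.assert_allclose` would play the role of `math.isclose`, with `pytest.raises(AssertionError)` wrapping the tight-tolerance check.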



class TestJSON:
def test_json_converters(self, models):