Skip to content

WIP: Add kwargs to InferenceData.to_netcdf() #2410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def to_netcdf(
engine: str = "h5netcdf",
base_group: str = "/",
overwrite_existing: bool = True,
**kwargs
) -> str:
"""Write InferenceData to netcdf4 file.

Expand All @@ -481,7 +482,11 @@ def to_netcdf(
By default, will write to the root of the netCDF file
overwrite_existing : bool, default True
Whether to overwrite the existing file or append to it.


Other keyword arguments are passed on to `xarray.Dataset.to_netcdf()`. If an
`encoding` dict is provided, its per-variable entries override the settings
generated from `compress`; `engine` is always taken from the parameter above.

Returns
-------
str
Expand All @@ -501,6 +506,15 @@ def to_netcdf(
)
mode = "a"

# add items to kwargs corresponding directly to parameters of this method
kwargs["engine"] = engine

# get encoding dict that may have been passed in
try:
encoding_kw2 = kwargs["encoding"]
except KeyError:
encoding_kw2 = {}

if self._groups_all: # checks whether a group is present or not.
if groups is None:
groups = self._groups_all
Expand All @@ -509,13 +523,30 @@ def to_netcdf(

for group in groups:
data = getattr(self, group)
kwargs = {"engine": engine}

# define encoding kwargs according to compress
# but only for compressible dtypes
if compress:
kwargs["encoding"] = {
encoding_kw1 = {
var_name: {"zlib": True}
for var_name, values in data.variables.items()
if _compressible_dtype(values.dtype)
}
else:
encoding_kw1 = {}

# merge the two dicts-of-dicts
encoding_kw_merged = {}
for var_name, kw1 in encoding_kw1.items():
try:
kw2 = encoding_kw2[var_name]
except KeyError:
kw2 = {}
encoding_kw_merged[var_name] = {**kw1,**kw2}
# note: entries passed in via kwargs will overwrite
# those that may have been created due to other parameters
kwargs["encoding"] = encoding_kw_merged

data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs)
data.close()
mode = "a"
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,15 @@ def test_empty_inference_data_object(self):
os.remove(filepath)
assert not os.path.exists(filepath)

def test_to_netcdf_kwargs(self):
"""Tests to verify that passing kwargs to `InferenceData.to_netcdf()`
works as intended"""
True # TODO
# 1) define an InferenceData object (e.g. from file)
# 2) define different sets of `**kwargs` to pass
# 3) use inference_data.to_netcdf(filepath,**kwargs)
# 4) test these make it through to `data.to_netcdf()` as intended - TODO how?


class TestJSON:
def test_json_converters(self, models):
Expand Down
Loading