-
-
Notifications
You must be signed in to change notification settings - Fork 437
WIP: Add kwargs to InferenceData.to_netcdf() #2410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -462,6 +462,7 @@ def to_netcdf( | |||||||||||||||||||||
engine: str = "h5netcdf", | ||||||||||||||||||||||
base_group: str = "/", | ||||||||||||||||||||||
overwrite_existing: bool = True, | ||||||||||||||||||||||
**kwargs | ||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||
"""Write InferenceData to netcdf4 file. | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -481,7 +482,11 @@ def to_netcdf( | |||||||||||||||||||||
By default, will write to the root of the netCDF file | ||||||||||||||||||||||
overwrite_existing : bool, default True | ||||||||||||||||||||||
Whether to overwrite the existing file or append to it. | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
Other keyword arguments will be passed to `xarray.Dataset.to_netcdf()`. If | ||||||||||||||||||||||
provided, these will serve to override dict items that relate to `compress` and | ||||||||||||||||||||||
`engine` parameters described above. | ||||||||||||||||||||||
|
||||||||||||||||||||||
Returns | ||||||||||||||||||||||
------- | ||||||||||||||||||||||
str | ||||||||||||||||||||||
|
@@ -501,6 +506,15 @@ def to_netcdf( | |||||||||||||||||||||
) | ||||||||||||||||||||||
mode = "a" | ||||||||||||||||||||||
|
||||||||||||||||||||||
# add items to kwargs corresponding directly to parameters of this method | ||||||||||||||||||||||
kwargs["engine"] = engine | ||||||||||||||||||||||
|
||||||||||||||||||||||
# get encoding dict that may have been passed in | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
encoding_kw2 = kwargs["encoding"] | ||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||
encoding_kw2 = {} | ||||||||||||||||||||||
Comment on lines
+513
to
+516
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we generally use |
||||||||||||||||||||||
|
||||||||||||||||||||||
if self._groups_all: # checks whether a group is present or not. | ||||||||||||||||||||||
if groups is None: | ||||||||||||||||||||||
groups = self._groups_all | ||||||||||||||||||||||
|
@@ -509,13 +523,30 @@ def to_netcdf( | |||||||||||||||||||||
|
||||||||||||||||||||||
for group in groups: | ||||||||||||||||||||||
data = getattr(self, group) | ||||||||||||||||||||||
kwargs = {"engine": engine} | ||||||||||||||||||||||
|
||||||||||||||||||||||
# define encoding kwargs according to compress | ||||||||||||||||||||||
# but only for compressible dtypes | ||||||||||||||||||||||
if compress: | ||||||||||||||||||||||
kwargs["encoding"] = { | ||||||||||||||||||||||
encoding_kw1 = { | ||||||||||||||||||||||
var_name: {"zlib": True} | ||||||||||||||||||||||
for var_name, values in data.variables.items() | ||||||||||||||||||||||
if _compressible_dtype(values.dtype) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
encoding_kw1 = {} | ||||||||||||||||||||||
|
||||||||||||||||||||||
# merge the two dicts-of-dicts | ||||||||||||||||||||||
encoding_kw_merged = {} | ||||||||||||||||||||||
for var_name, kw1 in encoding_kw1.items(): | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
kw2 = encoding_kw2[var_name] | ||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||
kw2 = {} | ||||||||||||||||||||||
encoding_kw_merged[var_name] = {**kw1,**kw2} | ||||||||||||||||||||||
Comment on lines
+540
to
+545
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the logic here would not work as expected when there are non compressible types. My line of thought/duck debugging:
Potential proposal:
Suggested change
|
||||||||||||||||||||||
# note: entries passed in via kwargs will overwrite | ||||||||||||||||||||||
# those that may have been created due to other parameters | ||||||||||||||||||||||
kwargs["encoding"] = encoding_kw_merged | ||||||||||||||||||||||
|
||||||||||||||||||||||
data.to_netcdf(filename, mode=mode, group=f"{base_group}/{group}", **kwargs) | ||||||||||||||||||||||
data.close() | ||||||||||||||||||||||
mode = "a" | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1483,6 +1483,15 @@ def test_empty_inference_data_object(self): | |
os.remove(filepath) | ||
assert not os.path.exists(filepath) | ||
|
||
def test_to_netcdf_kwargs(self): | ||
"""Tests to verify that passing kwargs to `InferenceData.to_netcdf()` | ||
works as intended""" | ||
True # TODO | ||
# 1) define an InferenceData object (e.g. from file) | ||
# 2) define different sets of `**kwargs` to pass | ||
# 3) use inference_data.to_netcdf(filepath,**kwargs) | ||
# 4) test these make it through to `data.to_netcdf()` as intended - TODO how? | ||
Comment on lines
+1490
to
+1493
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think what you propose is about right. Pseudocode idea: idata = load...
# store with encoding kwargs that mean small but non-negligible loss of precision
# and as previous test, check requested filename exists
idata_encoded = load...
for group in idata.groups:
# use https://docs.xarray.dev/en/stable/generated/xarray.testing.assert_allclose.html#xarray.testing.assert_allclose
# once as
with pytest.raises(AssertionError):
assert_allclose(..., tol=low/default)
# then again as
assert_allclose(..., tol=high)
# clean up files |
||
|
||
|
||
class TestJSON: | ||
def test_json_converters(self, models): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be the parameter description according to numpydoc, so indented under a
**kwargs
parameter with no type: https://numpydoc.readthedocs.io/en/latest/format.html#parameters (the last paragraph of this section). Also, if you use
:meth:`xarray.Dataset.to_netcdf`
or even `xarray.Dataset.to_netcdf`
(given our sphinx configuration) it will be rendered as a link to the respective docs in the xarray website. You can check the rendered docstring preview from your PR at https://arviz--2410.org.readthedocs.build/en/2410/api/generated/arviz.InferenceData.to_netcdf.html