Skip to content

Commit dcd1980

Browse files
committed
Move navy_land.nc to resources dir and move _get_resource_path to utils.py
- Update unit tests for `_get_resource_path()` to cover all cases - Refactor `_get_resource_path` for maintainability and readability -- remove unnecessary args, simplify logic, remove ModuleNotFoundError - Update docstring of `pcmdi_land_sea_mask()` with source to `navy_land.nc` - Update `pcmdi_land_sea_mask()` to not decode time for navy_land.nc to prevent logger warning since it has no time axis - Update `.gitignore` to noy ignore `xcdat/resources` dir - Update `pyproject.toml` `tools.setuptools.package-data`
1 parent a5407e3 commit dcd1980

File tree

7 files changed

+136
-100
lines changed

7 files changed

+136
-100
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ ENV/
113113
# Input data
114114
*.nc
115115

116+
# But keep any netCDF files in the resources directory
117+
!xcdat/resources/*.nc
118+
116119
# Anaconda
117120
conda-build/
118121

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ dev = ["types-python-dateutil", "pre-commit", "ruff", "mypy"]
6161
include = ["xcdat", "xcdat.*"]
6262

6363
[tool.setuptools.package-data]
64-
"xcdat" = ["*.nc"]
64+
"xcdat" = ["resources/navy_land.nc"]
6565

6666
[tool.setuptools.dynamic]
6767
version = { attr = "xcdat.__version__" }

tests/test_mask.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
import sys
32
from unittest import mock
43

@@ -344,48 +343,6 @@ def test_pcmdi_land_sea_mask_multiple_iterations(_improve_mask, ds):
344343
xr.testing.assert_equal(output, expected)
345344

346345

347-
def test_get_resource_path(monkeypatch, tmp_path):
348-
mock_file = tmp_path / "navy_land.nc"
349-
mock_file.touch()
350-
351-
mock_as_file = mock.MagicMock()
352-
mock_as_file.return_value.__enter__.return_value = mock_file
353-
354-
monkeypatch.setattr(mask.resources, "as_file", mock_as_file)
355-
356-
path = mask._get_resource_path("navy_land.nc")
357-
358-
assert path == mock_file
359-
360-
361-
def test_get_resource_path_fallback_from_exception(monkeypatch, tmp_path):
362-
mock_file = tmp_path / "xcdat" / "navy_land.nc"
363-
mock_file.parent.mkdir(parents=True, exist_ok=True)
364-
mock_file.touch()
365-
366-
mock_as_file = mock.MagicMock()
367-
mock_as_file.side_effect = FileNotFoundError("Resource not found")
368-
369-
monkeypatch.setattr(mask.resources, "as_file", mock_as_file)
370-
371-
path = mask._get_resource_path("navy_land.nc", tmp_path)
372-
373-
assert re.match(r".*xcdat/navy_land.nc", str(path))
374-
375-
376-
def test_get_resource_path_fallback_missing(monkeypatch, tmp_path):
377-
mock_as_file = mock.MagicMock()
378-
mock_as_file.side_effect = FileNotFoundError("Resource not found")
379-
380-
monkeypatch.setattr(mask.resources, "as_file", mock_as_file)
381-
382-
with pytest.raises(
383-
RuntimeError,
384-
match=r"Resource file 'navy_land.nc' not found in package or at .*",
385-
):
386-
mask._get_resource_path("navy_land.nc", tmp_path)
387-
388-
389346
def test_is_circular():
390347
# Circular
391348
lon = xr.DataArray(data=np.array([0, 90, 180, 270]), dims=["lon"])

tests/test_utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
from importlib import resources
2+
from pathlib import Path
3+
from unittest import mock
4+
15
import pytest
26
import xarray as xr
37

4-
from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool
8+
from xcdat.utils import (
9+
_get_resource_path,
10+
_validate_min_weight,
11+
compare_datasets,
12+
str_to_bool,
13+
)
514

615

716
class TestCompareDatasets:
@@ -123,3 +132,48 @@ def test_returns_valid_min_weight(self):
123132
result = _validate_min_weight(1)
124133

125134
assert result == 1
135+
136+
137+
class TestGetResourcePath:
138+
def test_get_resource_path_local(self):
139+
# Mock the resources.files to simulate the local file structure
140+
with mock.patch("xcdat.utils.resources.files") as mock_files:
141+
mock_files.return_value = mock.Mock()
142+
mock_files.return_value.joinpath.return_value = (
143+
Path.cwd() / "xcdat" / "resources" / "navy_land.nc"
144+
)
145+
expected_path = mock_files.return_value.joinpath(
146+
"resources", "navy_land.nc"
147+
)
148+
149+
path = _get_resource_path("navy_land.nc")
150+
151+
assert path == expected_path
152+
153+
@pytest.mark.skipif(
154+
not pytest.importorskip("xcdat", reason="xcdat is not installed"),
155+
reason="xcdat is not installed",
156+
)
157+
def test_get_resource_path(self):
158+
# Locate the actual resources directory in the xcdat package
159+
resource_files = resources.files("xcdat")
160+
expected_path = resource_files / "resources" / "navy_land.nc"
161+
162+
path = _get_resource_path("navy_land.nc")
163+
164+
assert path == expected_path
165+
166+
def test_raises_runtime_error_if_file_not_found(self):
167+
with mock.patch("xcdat.utils.resources.files") as mock_files:
168+
mock_files.return_value = mock.Mock()
169+
mock_files.return_value.joinpath.return_value = Path(
170+
"/nonexistent/path/navy_land.nc"
171+
)
172+
with mock.patch(
173+
"xcdat.utils.resources.as_file", side_effect=FileNotFoundError
174+
):
175+
with pytest.raises(
176+
RuntimeError,
177+
match="Resource file 'navy_land.nc' not found in package",
178+
):
179+
_get_resource_path("navy_land.nc")

xcdat/mask.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from importlib import resources
2-
from pathlib import Path
31
from typing import Any, Callable
42

53
import numpy as np
@@ -11,6 +9,7 @@
119
from xcdat.axis import get_dim_coords
1210
from xcdat.regridder.accessor import obj_to_grid_ds
1311
from xcdat.regridder.grid import create_grid
12+
from xcdat.utils import _get_resource_path
1413

1514
logger = _setup_custom_logger(__name__)
1615

@@ -342,10 +341,12 @@ def pcmdi_land_sea_mask(
342341
source: xr.Dataset | None = None,
343342
source_data_var: str | None = None,
344343
) -> xr.DataArray:
345-
"""Generate a land-sea mask using the PCMDI method.
344+
"""
345+
Generate a land-sea mask using the PCMDI method.
346346
347-
This method uses a high-resolution land-sea mask and regrids it to the
348-
resolution of the input DataArray. It then iteratively improves the mask.
347+
This method uses a high-resolution land-sea mask and regrids it to the resolution
348+
of the input DataArray. It then iteratively improves the mask based on specified
349+
thresholds.
349350
350351
Parameters
351352
----------
@@ -356,18 +357,31 @@ def pcmdi_land_sea_mask(
356357
threshold2 : float, optional
357358
The second threshold for improving the mask, by default 0.3.
358359
source : xr.Dataset | None, optional
359-
The Dataset containing the variable to use as the high resolution source.
360+
A custom Dataset containing the variable to use as the high-resolution source.
361+
If not provided, a default high-resolution land-sea mask is used.
360362
source_data_var : str | None, optional
361-
Name of the variable in `source` to use as the high resolution source.
363+
The name of the variable in `source` to use as the high-resolution source.
364+
If `source` is not provided, this defaults to "sftlf".
362365
363366
Returns
364367
-------
365368
xr.DataArray
366-
The land-sea mask.
369+
The generated land-sea mask.
370+
371+
Raises
372+
------
373+
ValueError
374+
If `source` is provided but `source_data_var` is None.
375+
376+
Notes
377+
-----
378+
By default, the `navy_land.nc` file is used as the high-resolution land-sea mask.
379+
This file is sourced from the PCMDI (Program for Climate Model Diagnosis and
380+
Intercomparison) Metrics Package. It is available at:
381+
https://github.com/PCMDI/pcmdi_metrics/blob/main/share/data/navy_land.nc
367382
368383
Examples
369384
--------
370-
371385
Generate a land-sea mask using the PCMDI method:
372386
373387
>>> import xcdat
@@ -376,12 +390,16 @@ def pcmdi_land_sea_mask(
376390
377391
Generate a land-sea mask using the PCMDI method with custom thresholds:
378392
379-
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], threshold1=0.3, threshold2=0.4)
393+
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(
394+
... ds["tas"], threshold1=0.3, threshold2=0.4
395+
... )
380396
381-
Generate a land-sea mask using the PCMDI method with a custom high resolution source:
397+
Generate a land-sea mask using the PCMDI method with a custom high-res source:
382398
383-
>>> highres_ds = xcdata.open_dataset("/path/to/file")
384-
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(ds["tas"], source=highres_ds, source_data_var="highres")
399+
>>> highres_ds = xcdat.open_dataset("/path/to/file")
400+
>>> land_sea_mask = xcdat.mask.pcmdi_land_sea_mask(
401+
... ds["tas"], source=highres_ds, source_data_var="highres"
402+
... )
385403
"""
386404
if source is not None and source_data_var is None:
387405
raise ValueError(
@@ -391,9 +409,11 @@ def pcmdi_land_sea_mask(
391409
if source is None:
392410
source_data_var = "sftlf"
393411

394-
resource_path = str(_get_resource_path("navy_land.nc", Path.cwd()))
412+
resource_path = str(_get_resource_path("navy_land.nc"))
395413

396-
source = open_dataset(resource_path)
414+
# Turn off time decoding to prevent logger warning since this dataset
415+
# does not have a time axis.
416+
source = open_dataset(resource_path, decode_times=False)
397417

398418
source_regrid = source.regridder.horizontal(
399419
source_data_var, obj_to_grid_ds(da), tool="regrid2"
@@ -435,44 +455,6 @@ def pcmdi_land_sea_mask(
435455
return mask[source_data_var]
436456

437457

438-
def _get_resource_path(filename: str, default_path: Path | None = None) -> Path:
439-
"""Get the path to a resource file.
440-
441-
Parameters
442-
----------
443-
filename : str
444-
The name of the resource file.
445-
446-
Returns
447-
-------
448-
Path
449-
The path to the resource file.
450-
"""
451-
if default_path is None:
452-
default_path = Path.cwd()
453-
454-
resource_path: Path | None = None
455-
456-
try:
457-
with resources.as_file(resources.files("xcdat").joinpath(filename)) as x:
458-
resource_path = x
459-
except (ModuleNotFoundError, FileNotFoundError) as e:
460-
logger.warning(e)
461-
resource_path = None
462-
463-
if resource_path and resource_path.exists():
464-
return resource_path
465-
466-
resource_path = default_path / "xcdat" / filename
467-
468-
if not resource_path.exists():
469-
raise RuntimeError(
470-
f"Resource file {filename!r} not found in package or at {resource_path!s}."
471-
)
472-
473-
return resource_path
474-
475-
476458
def _is_circular(lon: xr.DataArray, lon_bnds: xr.DataArray) -> bool:
477459
"""Check if a longitude axis is circular.
478460
File renamed without changes.

xcdat/utils.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
import importlib
21
import json
2+
from importlib import import_module, resources
3+
from pathlib import Path
34

45
import xarray as xr
56
from dask.array.core import Array
67

8+
from xcdat._logger import _setup_custom_logger
9+
10+
logger = _setup_custom_logger(__name__)
11+
712

813
def compare_datasets(ds1: xr.Dataset, ds2: xr.Dataset) -> dict[str, list[str]]:
914
"""Compares the keys and values of two datasets.
@@ -101,7 +106,7 @@ def _has_module(modname: str) -> bool: # pragma: no cover
101106
bool
102107
"""
103108
try:
104-
importlib.import_module(modname)
109+
import_module(modname)
105110
has = True
106111
except ImportError:
107112
has = False
@@ -188,3 +193,38 @@ def _validate_min_weight(min_weight: float | None) -> float:
188193
)
189194

190195
return min_weight
196+
197+
198+
def _get_resource_path(filename: str) -> Path:
199+
"""Get the path to a resource file from within the package.
200+
201+
Parameters
202+
----------
203+
filename : str
204+
The name of the resource file.
205+
206+
Returns
207+
-------
208+
Path
209+
The path to the resource file.
210+
211+
Raises
212+
------
213+
RuntimeError
214+
If the resource file is not found in the package.
215+
"""
216+
try:
217+
resource_files = resources.files("xcdat")
218+
resource_path = resource_files.joinpath("resources", filename)
219+
220+
with resources.as_file(resource_path) as resolved_path:
221+
if resolved_path.exists():
222+
return resolved_path
223+
except FileNotFoundError as e:
224+
logger.warning(
225+
f"File not found while locating resource {filename!r}. Error: {e}"
226+
)
227+
228+
raise RuntimeError(
229+
f"Resource file {filename!r} not found in package: {resource_path!s}."
230+
)

0 commit comments

Comments
 (0)