Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def setup(self):
z = grid.create_axis("lev", np.linspace(10000, 2000, 2), generate_bounds=False)
self.output_grid = grid.create_grid(z=z)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_infer_target_data(self, grid):
ds = self.ds.copy(True)

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_infer_target_data(self, grid):
):
regridder.vertical("so", ds)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_infer_target_data_missing_formula_terms(self, _):
ds = self.ds.copy(True)

Expand All @@ -92,7 +92,7 @@ def test_infer_target_data_missing_formula_terms(self, _):
):
regridder.vertical("so", ds)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_infer_target_data_invalid_standard_name(self, _):
ds = self.ds.copy(True)

Expand All @@ -110,7 +110,7 @@ def test_infer_target_data_invalid_standard_name(self, _):
):
regridder.vertical("so", ds)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_infer_target_data_missing_required_variable(self, _):
ds = self.ds.copy(True)

Expand All @@ -128,7 +128,7 @@ def test_infer_target_data_missing_required_variable(self, _):
):
regridder.vertical("so", ds)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_infer_target_data_empty_formula_terms(self, _):
ds = self.ds.copy(True)

Expand Down Expand Up @@ -176,7 +176,7 @@ def test_vertical_regrid(self):

assert output_data.so.shape == (15, 2, 4, 4)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_target_data(self, grid):
regridder = xgcm.XGCMRegridder(self.ds, self.output_grid, method="linear")

Expand All @@ -189,7 +189,7 @@ def test_target_data(self, grid):
assert "method" in call_kwargs and call_kwargs["method"] == "linear"
assert "target_data" in call_kwargs and call_kwargs["target_data"] is None

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_target_data_da(self, grid):
target_data = np.random.normal(size=self.ds["so"].shape)

Expand All @@ -212,7 +212,7 @@ def test_target_data_da(self, grid):

xr.testing.assert_allclose(call_kwargs["target_data"], target_da)

@mock.patch("xcdat.regridder.xgcm.Grid")
@mock.patch("xgcm.Grid")
def test_target_data_ds(self, grid):
target_data = np.random.normal(size=self.ds["so"].shape)

Expand Down
8 changes: 7 additions & 1 deletion xcdat/regridder/xgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Literal, get_args

import xarray as xr
from xgcm import Grid

from xcdat._logger import _setup_custom_logger
from xcdat.axis import get_dim_coords
Expand Down Expand Up @@ -151,6 +150,13 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
"""See documentation in :py:func:`xcdat.regridder.xgcm.XGCMRegridder`"""
# NOTE: Importing xgcm triggers Numba’s @guvectorize JIT compilation
# in the xgcm.transform module, which can be time-consuming during
# initial imports. To avoid impacting the import time of xcdat, we
# import xgcm only when this method is called. Subsequent calls to this
# method will use the cached import.
from xgcm import Grid

try:
output_coord_z = get_dim_coords(self._output_grid, "Z")
except KeyError as e:
Expand Down
Loading