Skip to content

Commit 08a340a

Browse files
Defer xgcm import to speed up xcdat startup time by ~3 seconds (#810)
1 parent 7ae6b72 commit 08a340a

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

tests/test_regrid.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def setup(self):
4040
z = grid.create_axis("lev", np.linspace(10000, 2000, 2), generate_bounds=False)
4141
self.output_grid = grid.create_grid(z=z)
4242

43-
@mock.patch("xcdat.regridder.xgcm.Grid")
43+
@mock.patch("xgcm.Grid")
4444
def test_infer_target_data(self, grid):
4545
ds = self.ds.copy(True)
4646

@@ -75,7 +75,7 @@ def test_infer_target_data(self, grid):
7575
):
7676
regridder.vertical("so", ds)
7777

78-
@mock.patch("xcdat.regridder.xgcm.Grid")
78+
@mock.patch("xgcm.Grid")
7979
def test_infer_target_data_missing_formula_terms(self, _):
8080
ds = self.ds.copy(True)
8181

@@ -92,7 +92,7 @@ def test_infer_target_data_missing_formula_terms(self, _):
9292
):
9393
regridder.vertical("so", ds)
9494

95-
@mock.patch("xcdat.regridder.xgcm.Grid")
95+
@mock.patch("xgcm.Grid")
9696
def test_infer_target_data_invalid_standard_name(self, _):
9797
ds = self.ds.copy(True)
9898

@@ -110,7 +110,7 @@ def test_infer_target_data_invalid_standard_name(self, _):
110110
):
111111
regridder.vertical("so", ds)
112112

113-
@mock.patch("xcdat.regridder.xgcm.Grid")
113+
@mock.patch("xgcm.Grid")
114114
def test_infer_target_data_missing_required_variable(self, _):
115115
ds = self.ds.copy(True)
116116

@@ -128,7 +128,7 @@ def test_infer_target_data_missing_required_variable(self, _):
128128
):
129129
regridder.vertical("so", ds)
130130

131-
@mock.patch("xcdat.regridder.xgcm.Grid")
131+
@mock.patch("xgcm.Grid")
132132
def test_infer_target_data_empty_formula_terms(self, _):
133133
ds = self.ds.copy(True)
134134

@@ -176,7 +176,7 @@ def test_vertical_regrid(self):
176176

177177
assert output_data.so.shape == (15, 2, 4, 4)
178178

179-
@mock.patch("xcdat.regridder.xgcm.Grid")
179+
@mock.patch("xgcm.Grid")
180180
def test_target_data(self, grid):
181181
regridder = xgcm.XGCMRegridder(self.ds, self.output_grid, method="linear")
182182

@@ -189,7 +189,7 @@ def test_target_data(self, grid):
189189
assert "method" in call_kwargs and call_kwargs["method"] == "linear"
190190
assert "target_data" in call_kwargs and call_kwargs["target_data"] is None
191191

192-
@mock.patch("xcdat.regridder.xgcm.Grid")
192+
@mock.patch("xgcm.Grid")
193193
def test_target_data_da(self, grid):
194194
target_data = np.random.normal(size=self.ds["so"].shape)
195195

@@ -212,7 +212,7 @@ def test_target_data_da(self, grid):
212212

213213
xr.testing.assert_allclose(call_kwargs["target_data"], target_da)
214214

215-
@mock.patch("xcdat.regridder.xgcm.Grid")
215+
@mock.patch("xgcm.Grid")
216216
def test_target_data_ds(self, grid):
217217
target_data = np.random.normal(size=self.ds["so"].shape)
218218

xcdat/regridder/xgcm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Literal, get_args
33

44
import xarray as xr
5-
from xgcm import Grid
65

76
from xcdat._logger import _setup_custom_logger
87
from xcdat.axis import get_dim_coords
@@ -151,6 +150,13 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
151150

152151
def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
153152
"""See documentation in :py:func:`xcdat.regridder.xgcm.XGCMRegridder`"""
153+
# NOTE: Importing xgcm triggers Numba’s @guvectorize JIT compilation
154+
# in the xgcm.transform module, which can be time-consuming during
155+
# initial imports. To avoid impacting the import time of xcdat, we
156+
# import xgcm only when this method is called. Subsequent calls to this
157+
# method will use the cached import.
158+
from xgcm import Grid
159+
154160
try:
155161
output_coord_z = get_dim_coords(self._output_grid, "Z")
156162
except KeyError as e:

0 commit comments

Comments
 (0)