Skip to content

Commit 8ea5bad

Browse files
CDAT Migration Phase 1 - ReplaceCDPParameter (#638)
Co-authored-by: chengzhuzhang <[email protected]>
1 parent c168e9b commit 8ea5bad

36 files changed

+629
-178
lines changed

conda-env/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,4 @@ dependencies:
2929
- scipy
3030
- pytest
3131
- pytest-cov
32-
3332
prefix: /opt/miniconda3/envs/e3sm_diags_ci

e3sm_diags/driver/aerosol_aeronet_driver.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import os
24
from typing import TYPE_CHECKING, Optional
35

@@ -7,21 +9,22 @@
79

810
import e3sm_diags
911
from e3sm_diags.driver import utils
12+
from e3sm_diags.logger import custom_logger
1013
from e3sm_diags.plot.cartopy import aerosol_aeronet_plot
1114

1215
if TYPE_CHECKING:
13-
from e3sm_diags.parameter.core_parameter import CoreParameter
1416
from cdms2.tvariable import TransientVariable
1517

16-
from e3sm_diags.logger import custom_logger
18+
from e3sm_diags.parameter.core_parameter import CoreParameter
19+
1720

1821
logger = custom_logger(__name__)
1922

2023
# This aerosol diagnostics scripts based on AERONET sites data was originally developed by Feng Yan and adapted and integrated in e3sm_diags by Jill Zhang.
2124
# Years include 2006–2015 average climatology for observation according to Feng et al. 2022:doi:10.1002/essoar.10510950.1, and Golaz et al. 2022 E3SMv2 paper.
2225

2326

24-
def run_diag(parameter: "CoreParameter") -> "CoreParameter":
27+
def run_diag(parameter: CoreParameter) -> CoreParameter:
2528
"""Runs the aerosol aeronet diagnostic.
2629
2730
:param parameter: Parameters for the run
@@ -70,7 +73,7 @@ def run_diag(parameter: "CoreParameter") -> "CoreParameter":
7073

7174

7275
def interpolate_model_output_to_obs_sites(
73-
var: Optional["TransientVariable"], var_id: str
76+
var: Optional[TransientVariable], var_id: str
7477
):
7578
"""Interpolate model outputs (on regular lat lon grids) to observational sites
7679

e3sm_diags/driver/aerosol_budget_driver.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import csv
24
import os
35
from typing import TYPE_CHECKING # , Optional
@@ -6,7 +8,6 @@
68

79
if TYPE_CHECKING:
810
from e3sm_diags.parameter.core_parameter import CoreParameter
9-
# from cdms2.tvariable import TransientVariable
1011

1112
import cdutil
1213
import numpy
@@ -89,7 +90,7 @@ def generate_metrics_dic(data, aerosol, season):
8990
MISSING_VALUE = 999.999
9091

9192

92-
def run_diag(parameter: "CoreParameter") -> "CoreParameter":
93+
def run_diag(parameter: CoreParameter) -> CoreParameter:
9394
"""Runs the aerosol aeronet diagnostic.
9495
9596
:param parameter: Parameters for the run

e3sm_diags/driver/annual_cycle_zonal_mean_driver.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from __future__ import print_function
1+
from __future__ import annotations
22

33
from typing import TYPE_CHECKING, Any, Dict, List
44

@@ -19,10 +19,12 @@
1919
from cdms2.axis import TransientAxis
2020
from cdms2.tvariable import TransientVariable
2121

22-
from e3sm_diags.parameter.core_parameter import CoreParameter
22+
from e3sm_diags.parameter.annual_cycle_zonal_mean_parameter import (
23+
ACzonalmeanParameter,
24+
)
2325

2426

25-
def run_diag(parameter: "CoreParameter"):
27+
def run_diag(parameter: ACzonalmeanParameter) -> ACzonalmeanParameter:
2628
"""Runs the annual cycle zonal mean diagnostic.
2729
2830
:param parameter: Parameters for the run
@@ -128,7 +130,7 @@ def run_diag(parameter: "CoreParameter"):
128130
return parameter
129131

130132

131-
def _create_annual_cycle(dataset: Dataset, variable: str) -> "TransientVariable":
133+
def _create_annual_cycle(dataset: Dataset, variable: str) -> TransientVariable:
132134
"""Creates the annual climatology cycle for a dataset variable.
133135
134136
:param dataset: Dataset
@@ -144,7 +146,7 @@ def _create_annual_cycle(dataset: Dataset, variable: str) -> "TransientVariable"
144146
for index, month in enumerate(month_list):
145147
var = dataset.get_climo_variable(variable, month)
146148
if month == "01":
147-
var_ann_cycle: "TransientVariable" = MV2.zeros([12] + list(var.shape))
149+
var_ann_cycle: TransientVariable = MV2.zeros([12] + list(var.shape))
148150
var_ann_cycle.id = var.id
149151
var_ann_cycle.long_name = var.long_name
150152
var_ann_cycle.units = var.units

e3sm_diags/driver/area_mean_time_series_driver.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from __future__ import annotations
2+
13
import collections
24
import json
35
import os
6+
from typing import TYPE_CHECKING
47

58
import cdms2
69
import cdutil
@@ -11,6 +14,12 @@
1114
from e3sm_diags.metrics import mean
1215
from e3sm_diags.plot.cartopy import area_mean_time_series_plot
1316

17+
if TYPE_CHECKING:
18+
from e3sm_diags.parameter.area_mean_time_series_parameter import (
19+
AreaMeanTimeSeriesParameter,
20+
)
21+
22+
1423
logger = custom_logger(__name__)
1524

1625
RefsTestMetrics = collections.namedtuple("RefsTestMetrics", ["refs", "test", "metrics"])
@@ -24,7 +33,7 @@ def create_metrics(ref_domain):
2433
return {"mean": mean(ref_domain)}
2534

2635

27-
def run_diag(parameter):
36+
def run_diag(parameter: AreaMeanTimeSeriesParameter) -> AreaMeanTimeSeriesParameter:
2837
variables = parameter.variables
2938
regions = parameter.regions
3039
ref_names = parameter.ref_names

e3sm_diags/driver/arm_diags_driver.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from __future__ import annotations
2+
13
import collections
24
import json
35
import os
6+
from typing import TYPE_CHECKING
47

58
import cdms2
69
import numpy as np
@@ -11,6 +14,9 @@
1114
from e3sm_diags.logger import custom_logger
1215
from e3sm_diags.plot.cartopy import arm_diags_plot
1316

17+
if TYPE_CHECKING:
18+
from e3sm_diags.parameter.arm_diags_parameter import ARMDiagsParameter
19+
1420
logger = custom_logger(__name__)
1521

1622
RefsTestMetrics = collections.namedtuple(
@@ -50,7 +56,7 @@ def create_metrics(test, ref):
5056
}
5157

5258

53-
def run_diag_diurnal_cycle(parameter):
59+
def run_diag_diurnal_cycle(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
5460
variables = parameter.variables
5561
regions = parameter.regions
5662
ref_name = parameter.ref_name
@@ -149,7 +155,7 @@ def run_diag_diurnal_cycle(parameter):
149155
return parameter
150156

151157

152-
def run_diag_diurnal_cycle_zt(parameter):
158+
def run_diag_diurnal_cycle_zt(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
153159
variables = parameter.variables
154160
regions = parameter.regions
155161
ref_name = parameter.ref_name
@@ -259,7 +265,7 @@ def run_diag_diurnal_cycle_zt(parameter):
259265
return parameter
260266

261267

262-
def run_diag_annual_cycle(parameter):
268+
def run_diag_annual_cycle(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
263269
variables = parameter.variables
264270
regions = parameter.regions
265271
ref_name = parameter.ref_name
@@ -360,7 +366,7 @@ def run_diag_annual_cycle(parameter):
360366
return parameter
361367

362368

363-
def run_diag_convection_onset(parameter):
369+
def run_diag_convection_onset(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
364370
regions = parameter.regions
365371
ref_name = parameter.ref_name
366372
ref_path = parameter.reference_data_path
@@ -407,11 +413,11 @@ def run_diag_convection_onset(parameter):
407413
return parameter
408414

409415

410-
def run_diag_pdf_daily(parameter):
416+
def run_diag_pdf_daily(parameter: ARMDiagsParameter):
411417
logger.info("'run_diag_pdf_daily' is not yet implemented.")
412418

413419

414-
def run_diag(parameter):
420+
def run_diag(parameter: ARMDiagsParameter) -> ARMDiagsParameter:
415421

416422
if parameter.diags_set == "annual_cycle":
417423
return run_diag_annual_cycle(parameter)

e3sm_diags/driver/cosp_histogram_driver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from __future__ import print_function
1+
from __future__ import annotations
22

33
import os
4+
from typing import TYPE_CHECKING
45

56
import cdms2
67

@@ -10,6 +11,9 @@
1011
from e3sm_diags.metrics import corr, max_cdms, mean, min_cdms, rmse
1112
from e3sm_diags.plot import plot
1213

14+
if TYPE_CHECKING:
15+
from e3sm_diags.parameter.core_parameter import CoreParameter
16+
1317
logger = custom_logger(__name__)
1418

1519

@@ -40,7 +44,7 @@ def create_metrics(ref, test, ref_regrid, test_regrid, diff):
4044
return metrics_dict
4145

4246

43-
def run_diag(parameter):
47+
def run_diag(parameter: CoreParameter) -> CoreParameter:
4448
variables = parameter.variables
4549
seasons = parameter.seasons
4650
ref_name = getattr(parameter, "ref_name", "")

e3sm_diags/driver/diurnal_cycle_driver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from __future__ import print_function
1+
from __future__ import annotations
22

33
import os
4+
from typing import TYPE_CHECKING
45

56
import cdms2
67

@@ -11,8 +12,11 @@
1112

1213
logger = custom_logger(__name__)
1314

15+
if TYPE_CHECKING:
16+
from e3sm_diags.parameter.diurnal_cycle_parameter import DiurnalCycleParameter
1417

15-
def run_diag(parameter):
18+
19+
def run_diag(parameter: DiurnalCycleParameter) -> DiurnalCycleParameter:
1620
variables = parameter.variables
1721
seasons = parameter.seasons
1822
ref_name = getattr(parameter, "ref_name", "")

e3sm_diags/driver/enso_diags_driver.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from __future__ import print_function
1+
from __future__ import annotations
22

33
import json
44
import math
55
import os
6+
from typing import TYPE_CHECKING
67

78
import cdms2
89
import cdutil
@@ -16,6 +17,10 @@
1617
from e3sm_diags.metrics import corr, max_cdms, mean, min_cdms, rmse, std
1718
from e3sm_diags.plot.cartopy.enso_diags_plot import plot_map, plot_scatter
1819

20+
if TYPE_CHECKING:
21+
from e3sm_diags.parameter.enso_diags_parameter import EnsoDiagsParameter
22+
23+
1924
logger = custom_logger(__name__)
2025

2126

@@ -185,7 +190,7 @@ def create_metrics(ref, test, ref_regrid, test_regrid, diff):
185190
return metrics_dict
186191

187192

188-
def run_diag_map(parameter):
193+
def run_diag_map(parameter: EnsoDiagsParameter) -> EnsoDiagsParameter:
189194
variables = parameter.variables
190195
seasons = parameter.seasons
191196
regions = parameter.regions
@@ -378,7 +383,7 @@ def run_diag_map(parameter):
378383
return parameter
379384

380385

381-
def run_diag_scatter(parameter):
386+
def run_diag_scatter(parameter: EnsoDiagsParameter) -> EnsoDiagsParameter:
382387
variables = parameter.variables
383388
run_type = parameter.run_type
384389
# We will always use the same regions, so we don't do the following:
@@ -433,7 +438,7 @@ def run_diag_scatter(parameter):
433438
return parameter
434439

435440

436-
def run_diag(parameter):
441+
def run_diag(parameter: EnsoDiagsParameter) -> EnsoDiagsParameter:
437442
if parameter.plot_type == "map":
438443
return run_diag_map(parameter)
439444
elif parameter.plot_type == "scatter":

e3sm_diags/driver/lat_lon_driver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from __future__ import print_function
1+
from __future__ import annotations
22

33
import json
44
import os
5+
from typing import TYPE_CHECKING
56

67
import cdms2
78

@@ -13,6 +14,9 @@
1314

1415
logger = custom_logger(__name__)
1516

17+
if TYPE_CHECKING:
18+
from e3sm_diags.parameter.core_parameter import CoreParameter
19+
1620

1721
def create_and_save_data_and_metrics(parameter, mv1_domain, mv2_domain):
1822
if not parameter.model_only:
@@ -107,7 +111,7 @@ def create_metrics(ref, test, ref_regrid, test_regrid, diff):
107111
return metrics_dict
108112

109113

110-
def run_diag(parameter): # noqa: C901
114+
def run_diag(parameter: CoreParameter) -> CoreParameter: # noqa: C901
111115
variables = parameter.variables
112116
seasons = parameter.seasons
113117
ref_name = getattr(parameter, "ref_name", "")

0 commit comments

Comments
 (0)