|
1 | 1 | import glob |
| 2 | +import subprocess |
| 3 | +from datetime import datetime |
| 4 | +from typing import List, TypedDict |
2 | 5 |
|
3 | 6 | import numpy as np |
4 | 7 | import pytest |
|
8 | 11 | from e3sm_diags.logger import custom_logger |
9 | 12 | from tests.complete_run_script import params, runner |
10 | 13 |
|
11 | | - |
12 | 14 | logger = custom_logger(__name__) |
13 | 15 |
|
14 | 16 |
|
15 | | -DEV_DIR = "843-migration-phase3-model-vs-obs" |
16 | | -DEV_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{DEV_DIR}/" |
| 17 | +def _get_git_branch_name() -> str: |
| 18 | + """Get the current git branch name.""" |
| 19 | + try: |
| 20 | + branch_name = ( |
| 21 | + subprocess.check_output( |
| 22 | + ["git", "rev-parse", "--abbrev-ref", "HEAD"], |
| 23 | + stderr=subprocess.DEVNULL, |
| 24 | + ) |
| 25 | + .strip() |
| 26 | + .decode("utf-8") |
| 27 | + ) |
| 28 | + except subprocess.CalledProcessError: |
| 29 | + branch_name = "unknown" |
| 30 | + |
| 31 | + return branch_name |
| 32 | + |
| 33 | + |
| 34 | +BRANCH_NAME = _get_git_branch_name() |
| 35 | +DEV_TIMESTAMP = datetime.now().strftime("%y-%m-%d") |
| 36 | +DEV_DIR = f"{DEV_TIMESTAMP}-{BRANCH_NAME}" |
| 37 | +DEV_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/complete_run/{DEV_DIR}" |
17 | 38 |
|
18 | 39 | DEV_GLOB = sorted(glob.glob(DEV_PATH + "**/**/*.nc")) |
19 | 40 | DEV_NUM_FILES = len(DEV_GLOB) |
20 | 41 |
|
21 | | -MAIN_DIR = "main" |
22 | | -MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/cdat-migration-fy24/{MAIN_DIR}/" |
| 42 | +# TODO: Update `MAIN_DIR` as needed. |
| 43 | +MAIN_DIR = "24-12-09-main" |
| 44 | +MAIN_PATH = f"/global/cfs/cdirs/e3sm/www/e3sm_diags/{MAIN_DIR}/" |
23 | 45 | MAIN_GLOB = sorted(glob.glob(MAIN_PATH + "**/**/*.nc")) |
24 | 46 | MAIN_NUM_FILES = len(MAIN_GLOB) |
25 | 47 |
|
@@ -54,7 +76,6 @@ def run_diags_and_get_results_dir() -> str: |
54 | 76 | class TestRegression: |
55 | 77 | @pytest.fixture(autouse=True) |
56 | 78 | def setup(self, run_diags_and_get_results_dir): |
57 | | - # TODO: We need to store `main` results on a data container |
58 | 79 | self.results_dir = run_diags_and_get_results_dir |
59 | 80 |
|
60 | 81 | def test_check_if_files_found(self): |
@@ -90,8 +111,19 @@ def test_get_relative_diffs(self): |
90 | 111 | assert len(results["key_errors"]) == 0 |
91 | 112 |
|
92 | 113 |
|
93 | | -def _get_relative_diffs(): |
94 | | - results = { |
| 114 | +class DiffResults(TypedDict): |
| 115 | + """Type annotation for the results of the relative differences comparison.""" |
| 116 | + |
| 117 | + missing_files: List[str] |
| 118 | + missing_vars: List[str] |
| 119 | + matching_files: List[str] |
| 120 | + mismatch_errors: List[str] |
| 121 | + not_equal_errors: List[str] |
| 122 | + key_errors: List[str] |
| 123 | + |
| 124 | + |
| 125 | +def _get_relative_diffs() -> DiffResults: |
| 126 | + results: DiffResults = { |
95 | 127 | "missing_files": [], |
96 | 128 | "missing_vars": [], |
97 | 129 | "matching_files": [], |
@@ -192,7 +224,7 @@ def _get_var_data(ds: xr.Dataset, var_key: str) -> np.ndarray | None: |
192 | 224 | except KeyError: |
193 | 225 | var_keys = DERIVED_VARIABLES[var_key.upper()].keys() |
194 | 226 |
|
195 | | - var_keys = [var_key] + list(sum(var_keys, ())) |
| 227 | + var_keys = [var_key] + list(sum(var_keys, ())) # type: ignore |
196 | 228 |
|
197 | 229 | for key in var_keys: |
198 | 230 | if key in ds.data_vars.keys(): |
|
0 commit comments