Skip to content

Commit da3b9a3

Browse files
committed
fix pre-commit errors
1 parent 4605fac commit da3b9a3

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

e3sm_diags/driver/lat_lon_native_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def process_variable_derivations(dataset, variable_key, dataset_name=""): # noq
259259
):
260260
data_source = parameter.ref_data_file_path
261261
else:
262-
data_source = ds_ref
262+
data_source = ds_ref # type: ignore
263263

264264
# Load the dataset with uxarray
265265
uxds_ref = ux.open_dataset(grid_file, data_source)
@@ -986,7 +986,7 @@ def _create_metrics_dict(
986986
"npe": getattr(uxds, "npe", ""),
987987
"element_count": len(uxds.face) if hasattr(uxds, "face") else 0,
988988
}
989-
metrics_dict["grid_info"] = grid_info
989+
metrics_dict["grid_info"] = grid_info # type: ignore
990990
except Exception as e:
991991
logger.warning(f"Error adding grid info to metrics: {e}")
992992

@@ -1021,7 +1021,7 @@ def _create_metrics_dict(
10211021
# In the first implementation, we'll use global means for simplicity
10221022
metrics_dict["misc"] = {
10231023
"rmse": abs(
1024-
metrics_dict["test_regrid"]["mean"] - metrics_dict["ref_regrid"]["mean"]
1024+
metrics_dict["test_regrid"]["mean"] - metrics_dict["ref_regrid"]["mean"] # type: ignore
10251025
),
10261026
"corr": 0.0, # Placeholder - proper correlation would require regridding
10271027
}

e3sm_diags/parameter/lat_lon_native_parameter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@ def __init__(self):
1818

1919
# Path to the grid files for the native grids
2020
self.test_grid_file = "" # Grid file for test data
21-
self.ref_grid_file = "" # Grid file for reference data
21+
self.ref_grid_file = "" # Grid file for reference data
22+
23+
# File paths for data files (set dynamically during processing)
24+
self.test_data_file_path = "" # Path to test data file
25+
self.ref_data_file_path = "" # Path to reference data file
2226

2327
# Option for handling periodic elements
2428
# If True, split elements that cross the dateline for better visualization
2529
self.split_periodic_elements = True
2630

2731
# Style options for native grid visualization
2832
self.edge_color = None # Set to a color string to show grid edges
29-
self.edge_width = 0.3 # Width of grid edges when displayed
33+
self.edge_width = 0.3 # Width of grid edges when displayed
3034

3135
# Option to disable the grid antialiasing (may improve performance)
3236
self.antialiased = False

e3sm_diags/plot/lat_lon_native_plot.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def plot( # noqa: C901
9191
region_specs = REGION_SPECS.get(region, None)
9292

9393
# Set map bounds based on region
94-
if region_specs:
95-
lat_bounds = region_specs.get("lat", (-90, 90))
96-
lon_bounds = region_specs.get("lon", (0, 360))
94+
if region_specs is not None:
95+
lat_bounds = region_specs.get("lat", (-90, 90)) # type: ignore
96+
lon_bounds = region_specs.get("lon", (0, 360)) # type: ignore
9797
is_global_domain = lat_bounds == (-90, 90) and lon_bounds == (0, 360)
9898
else:
9999
lat_bounds = (-90, 90)
@@ -138,7 +138,7 @@ def plot( # noqa: C901
138138

139139
# Extract metrics for reference data if available
140140
ref_min = ref_max = ref_mean = diff_min = diff_max = diff_mean = None
141-
if has_reference:
141+
if has_reference and uxds_ref is not None:
142142
if is_global_domain:
143143
ref_min = uxds_ref[var_key].min().item()
144144
ref_max = uxds_ref[var_key].max().item()
@@ -274,10 +274,15 @@ def plot( # noqa: C901
274274
)
275275

276276
# Add RMSE and correlation text
277-
rmse = abs(diff_mean) # Simplified RMSE
277+
rmse = (
278+
abs(diff_mean) if diff_mean is not None else None
279+
) # Simplified RMSE
278280
corr = 0.0 # Placeholder - proper correlation would require aligned data
279281

280-
if rmse < 0.01:
282+
if rmse is None:
283+
logger.error("RMSE calculation failed: diff_mean is None")
284+
return
285+
elif rmse < 0.01:
281286
rmse_fmt = "%.2e"
282287
else:
283288
rmse_fmt = "%.4f"

e3sm_diags/viewer/default_viewer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,11 @@ def create_viewer(root_dir, parameters):
176176
] = os.path.join(
177177
"..", "{}".format(set_name), parameter.case_id, fnm
178178
)
179-
print(os.path.join(
180-
"..", "{}".format(set_name), parameter.case_id, fnm))
179+
print(
180+
os.path.join(
181+
"..", "{}".format(set_name), parameter.case_id, fnm
182+
)
183+
)
181184
ROW_INFO[set_name][parameter.case_id][row_name][season][
182185
"metadata"
183186
] = create_metadata(parameter)

0 commit comments

Comments
 (0)