diff --git a/README.md b/README.md index be41c774e..7f76761a6 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,15 @@ The modernization improves performance, usability, and maintainability, paving t for future enhancements to E3SM development. The refactored codebase is now more robust and extensively covered by unit tests, setting a solid foundation for ongoing testing and development. +### New Features in v3.1.0 + +v3.1.0 introduces significant enhancements to support advanced grid analysis and temporal snapshot capabilities: + +| Feature name
(set name) | Brief Introduction | Developers Contributors\* | Released version | +| ------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------- | ---------------- | +| Native grid visualization (lat_lon_native) | Support for plotting data on native grids (e.g., cubed-sphere, unstructured grids) using UXarray, enabling visualization without regridding to preserve native grid features | Jill Zhang, Tom Vo | 3.1.0 | +| Snapshot analysis for core sets | Index-based time selection for snapshot analysis on core diagnostic sets (lat_lon, lat_lon_native, polar, zonal_mean_2d, meridional_mean_2d, zonal_mean_2d_stratosphere), allowing analysis of individual time steps instead of climatological means | Jill Zhang, Tom Vo | 3.1.0 | + ### New Feature added during v2 development | Feature name
(set name) | Brief Introduction | Developers Contributors\* | Released version | diff --git a/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/T.cfg b/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/T.cfg new file mode 100644 index 000000000..874e6dde5 --- /dev/null +++ b/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/T.cfg @@ -0,0 +1,50 @@ +[#] +sets = ["lat_lon"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +plevs = [850.0] +contour_levels = [240, 245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 295] +diff_levels = [-5, -4, -3, -2, -1, -0.5, -0.25, 0.25, 0.5, 1, 2, 3, 4, 5] +regrid_method = "bilinear" + +[#] +sets = ["polar"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["DJF", "MAM", "JJA", "SON"] +regions = ["polar_S", "polar_N"] +plevs = [850.0] +contour_levels = [230, 240, 250, 260, 270, 280, 290, 300, 310] +diff_levels = [-15, -10, -7.5, -5, -2.5, -1, 1, 2.5, 5, 7.5, 10, 15] + +[#] +sets = ["zonal_mean_2d"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +contour_levels = [180,185,190,200,210,220,230,240,250,260,270,280,290,295,300] +diff_levels = [-3.0, -2.5, -2, -1.5, -1, -0.5, -0.25, 0.25, 0.5, 1, 1.5, 2, 2.5, 3.0] + +[#] +sets = ["zonal_mean_2d_stratosphere"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +contour_levels = [180,185,190,200,210,220,230,240,250,260,270,280,290,295,300] +diff_levels = [-8, -6,-4,-2, -1, -0.5, 0.5, 1, 2, 4,6, 8] + +[#] +sets = ["zonal_mean_xy"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +plevs = [850.0] + +[#] +sets = ["meridional_mean_2d"] +case_id = "model_vs_model" +variables = ["T"] +seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +contour_levels = [180,185,190,200,210,220,230,240,250,260,270,280,290,295,300] +diff_levels = [-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7] diff --git a/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/run_snapshot_core.py b/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/run_snapshot_core.py new file mode 100644 index 000000000..5b94d19a8 --- /dev/null +++ b/auxiliary_tools/debug/1013-snapshot-analysis-core-sets/run_snapshot_core.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" +This script runs e3sm_diags with the core sets to visualize snap-shot data. +""" + +import os +import sys + +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.run import runner + +# Auto-detect username +username = os.environ.get('USER', 'unknown_user') + +# Create parameter objects for 3 different runs +params = [] + +## (1) First test configuration +#param1 = LatLonNativeParameter() +#param1.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_1" +#param1.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +#param1.test_name = "v3.LR.amip_0101" +#param1.short_test_name = "v3.LR.amip_0101" +#param1.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +#param1.ref_name = "v3.HR.test4" +#param1.short_ref_name = "v3.HR.test4" +#param1.seasons = ["DJF"] +#param1.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc" +#param1.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc" +#param1.case_id = "model_vs_model" +#param1.run_type = "model_vs_model" +#params.append(param1) +# +## (2) Second test configuration +#param2 = LatLonNativeParameter() +#param2.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/lat_lon_native_test_2" +#param2.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +#param2.test_file = "v3.LR.amip_0101_DJF_climo.nc" +#param2.short_test_name = "v3.LR.amip_0101" +#param2.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +#param2.ref_file = "v3.HR.test4_DJF_climo.nc" +#param2.short_ref_name = "v3.HR.test4" +#param2.seasons = ["DJF"] +#param2.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc" +#param2.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne120pg2.nc" +#param2.case_id = "model_vs_model" +#param2.run_type = "model_vs_model" +#params.append(param2) + +# (3) Third test configuration +param3 = CoreParameter() +param3.results_dir = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/tests/1013-snapshot-analysis-core-sets" +param3.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr" +param3.test_file = "T_005101_006012.nc" +param3.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr" +param3.ref_file = "T_005101_006012.nc" +param3.short_test_name = "v2 test" +param3.short_ref_name = "v2 test" +param3.time_slices = ["0","1"] +param3.case_id = "model_vs_model" +param3.run_type = "model_vs_model" +params.append(param3) + +# Run the single diagnostic, comment out for complete diagnostics. +cfg_path = "auxiliary_tools/debug/1013-snapshot-analysis-core-sets/T.cfg" +sys.argv.extend(["--diags", cfg_path]) + +runner.sets_to_run = ["lat_lon", + "zonal_mean_xy", + "zonal_mean_2d", + "zonal_mean_2d_stratosphere", + "polar", + "meridional_mean_2d",] + +# Run each test sequentially +for i, param in enumerate(params, 1): + print(f"\n{'='*60}") + print(f"Running Test {i}: {param.results_dir}") + print(f"{'='*60}") + + # Create results directory + if not os.path.exists(param.results_dir): + os.makedirs(param.results_dir) + + # Run the diagnostic + runner.run_diags([param]) + print(f"Test {i} completed!") + +print(f"\n{'='*60}") +print("All tests completed!") +print(f"{'='*60}") + + diff --git a/docs/source/available-parameters.rst b/docs/source/available-parameters.rst index ce720bbd3..6f8fc9326 100644 --- a/docs/source/available-parameters.rst +++ b/docs/source/available-parameters.rst @@ -70,11 +70,18 @@ functionality of the diagnostics. - **regrid_method**: What regrid method of the regrid tool to use. Possible values are ``'linear'`` or ``'conservative'``. Default is ``'conservative'``. Read the xCDAT documentation on `regridding`_ for more information. -- **regrid_tool**: The regrid tool to use. Default is ``'esmf'``. +- **regrid_tool**: The regrid tool to use. Default is ``'xesmf'``. Read the xCDAT documentation on `regridding`_ for more information. + **Note:** The ``lat_lon_native`` set does not use regridding, so this parameter is ignored for that set. - **seasons**: A list of season to use. Default is annual and all seasons: ``['ANN', 'DJF', 'MAM', 'JJA', 'SON']``. + **Note:** This parameter is mutually exclusive with ``time_slices``. When using ``time_slices``, do not set ``seasons``. +- **time_slices**: *(v3.1.0+)* A list of time indices for snapshot analysis. Examples: ``['0']``, ``['5']``, ``['0', '1', '2']``. + Time slices are zero-based indices into the time dimension of the input dataset. + This enables analysis of individual time steps instead of climatological means. + **Note:** This parameter is mutually exclusive with ``seasons``. When using ``time_slices``, do not set ``seasons``. + Supported sets: ``lat_lon``, ``lat_lon_native``, ``polar``, ``zonal_mean_xy``, ``zonal_mean_2d``, ``meridional_mean_2d``, ``zonal_mean_2d_stratosphere``. - **sets**: A list of the sets to be run. Default is all sets: - ``['zonal_mean_xy', 'zonal_mean_2d', 'meridional_mean_2d', 'lat_lon', 'polar', 'area_mean_time_series', 'cosp_histogram', 'enso_diags', 'qbo', 'streamflow','diurnal_cycle']``. + ``['zonal_mean_xy', 'zonal_mean_2d', 'zonal_mean_2d_stratosphere', 'meridional_mean_2d', 'lat_lon', 'lat_lon_native', 'polar', 'area_mean_time_series', 'cosp_histogram', 'enso_diags', 'qbo', 'streamflow', 'diurnal_cycle', 'arm_diags', 'tc_analysis', 'annual_cycle_zonal_mean', 'lat_lon_land', 'lat_lon_river', 'aerosol_aeronet', 'aerosol_budget']``. - **variables**: What variable(s) to use for this run. Ex: ``variables=["T", "PRECT"]``. .. _regridding: https://xcdat.readthedocs.io/en/latest/getting-started-guide/faqs.html#regridding @@ -262,12 +269,30 @@ You can specify both ``test_end_yr`` and ``ref_end_yr`` or just ``end_yr``. - **plot_log_plevs**: Log-scale the y-axis. Default ``False``. - **plot_plevs**: Plot the pressure levels. Default ``False``. +``'lat_lon_native'`` *(v3.1.0+)*: + +- **test_grid_file**: *(Required)* Path to the grid file for test data in UGRID format (e.g., ``'/path/to/ne30pg2.nc'``). + The grid file defines the native grid structure used by UXarray for visualization. +- **ref_grid_file**: Path to the grid file for reference data in UGRID format (e.g., ``'/path/to/ne30pg2.nc'``). + Required for model vs model comparisons. Can be omitted for model-only runs. +- **antialiased**: Apply antialiasing to the plot. Default ``False``. Setting to ``True`` may improve visual quality but can impact performance. +- **time_slices**: Time indices for snapshot analysis (same as in core parameters). See ``time_slices`` description above. + +**Notes for lat_lon_native:** + +- Native grid visualization requires UXarray (included in E3SM Unified environment) +- Grid files must be in UGRID format +- Regridding parameters (``regrid_tool``, ``regrid_method``) are ignored for this set +- Can use either ``seasons`` for climatology or ``time_slices`` for snapshot analysis (mutually exclusive) Other parameters ~~~~~~~~~~~~~~~~ - **dataset**: Default is ``''``. -- **granulate**: Default is ``['variables', 'seasons', 'plevs', 'regions']``. +- **granulate**: Default is ``['variables', 'seasons', 'plevs', 'regions', 'time_slices']``. + This parameter controls how diagnostics are split into separate runs. - **selectors**: Default is ``['sets', 'seasons']``. See :ref:`Using the selectors parameter `. - **viewer_descr**: Used to specify values in the viewer. Default ``{}``. - **fail_on_incomplete**: Exit status will reflect failure if any parameter fails to complete. Default is ``False`` (e.g., a failing parameter will not create a failing exit code). +- **test_file**: *(v3.1.0+)* Specify the exact file name for test data. Useful for snapshot analysis with ``time_slices`` or when using specific data files. +- **ref_file**: *(v3.1.0+)* Specify the exact file name for reference data. Useful for snapshot analysis with ``time_slices`` or when using specific data files. diff --git a/docs/source/examples.rst b/docs/source/examples.rst index dbe0d87e2..323b7b773 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -93,6 +93,38 @@ on two different sets: ``zonal_mean_2d`` an ``lat_lon``. so you can compare different version of the data, or the same variable from different datasets. We are comparing CERES EBAF TOA version 2.8 and 4.0. +8. Native Grid Visualization (v3.1.0) +-------------------------------------- +`This example `__ demonstrates how to visualize model data on its native grid +(e.g., cubed-sphere, unstructured grids) without regridding to a regular lat-lon grid. +This feature uses UXarray to preserve native grid features and is particularly useful for high-resolution models with complex grid structures. +The example shows model vs model comparison using snapshot analysis on native grids. + +**Key features:** + +- Visualize data on native grids without regridding +- Preserve native grid features and characteristics +- Support for cubed-sphere and unstructured grids +- Uses UXarray for grid-aware operations + +9. Snapshot Analysis for Core Sets (v3.1.0) +-------------------------------------------- +`This example `__ demonstrates time slice analysis on core diagnostic sets. +Instead of computing climatological seasonal means, this analyzes individual time steps from model output using index-based time selection. + +This is useful for analyzing specific events, comparing model states at particular time points, +or understanding temporal evolution without time averaging. The example shows how to use the ``time_slices`` parameter +on multiple diagnostic sets (lat_lon, zonal_mean_2d, polar, meridional_mean_2d, zonal_mean_2d_stratosphere). + +**Key features:** + +- Index-based time selection (e.g., time_slices = ["0", "1", "2"]) +- Analyze individual time steps without temporal averaging +- Event-based or process-oriented diagnostics +- Works across multiple core diagnostic sets + +**Note:** ``time_slices`` and ``seasons`` parameters are mutually exclusive. + Running the Examples ==================== @@ -117,7 +149,7 @@ The parameters file contains information related to the location of the data, what years to run the diagnostics on, what plots to create, and more. The configuration file provides information about the diagnostics you are running. -This is used in Ex.4,5,7. +This is used in Ex.4, 5, 7, 8, 9. Parameters for each example can be found in `this directory `__. @@ -134,7 +166,7 @@ Use the code below to run the diagnostics. salloc --nodes 1 --qos interactive --time 01:00:00 --constraint cpu --account=e3sm # Enter the E3SM Unified environment. For Perlmutter CPU, the command to do this is: source /global/common/software/e3sm/anaconda_envs/load_latest_e3sm_unified_pm-cpu.sh - # Running Ex.1. For examples 4,5,7 append ``-d diags.cfg``. + # Running Ex.1. For examples 4, 5, 7, 8, 9 append ``-d diags.cfg``. python ex1.py --multiprocessing --num_workers=32 # You may need to change permissions on your web directory to see the example output. chmod -R 755 @@ -174,6 +206,8 @@ These were generated with the following script: # emacs ex5-model-vs-obs/ex5.py # emacs ex6-model-vs-obs-custom/ex6.py # emacs ex7-obs-vs-obs/ex7.py + # emacs ex8-native-grid-visualization/ex8.py + # emacs ex9-snapshot-analysis/ex9.py source /global/common/software/e3sm/anaconda_envs/load_latest_e3sm_unified_pm-cpu.sh cd ex1-model_ts-vs-model_ts @@ -187,9 +221,13 @@ These were generated with the following script: cd ../ex5-model-vs-obs python ex5.py --multiprocessing --num_workers=32 -d diags.cfg cd ../ex6-model-vs-obs-custom - python ex6.py --multiprocessing --num_workers=32 + python ex6.py cd ../ex7-obs-vs-obs - python ex7.py --multiprocessing --num_workers=32 -d diags.cfg + python ex7.py -d diags.cfg + cd ../ex8-native-grid-visualization + python ex8.py -d diags.cfg + cd ../ex9-snapshot-analysis + python ex9.py -d diags.cfg cd ../ chmod -R 755 /global/cfs/cdirs/e3sm/www/forsyth/examples diff --git a/e3sm_diags/driver/default_diags/lat_lon_native_model_vs_model.cfg b/e3sm_diags/driver/default_diags/lat_lon_native_model_vs_model.cfg index ab33c3607..37f73e874 100644 --- a/e3sm_diags/driver/default_diags/lat_lon_native_model_vs_model.cfg +++ b/e3sm_diags/driver/default_diags/lat_lon_native_model_vs_model.cfg @@ -3,7 +3,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["PRECT"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] -regions = ["global", "60S60N", "30S30N-150E90W"] +regions = ["global"] test_colormap = "WhiteBlueGreenYellowRed.rgb" reference_colormap = "WhiteBlueGreenYellowRed.rgb" diff_colormap = "BrBG" @@ -15,7 +15,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["PRECC"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] -regions = ["global", "60S60N", "30S30N-150E90W"] +regions = ["global"] test_colormap = "WhiteBlueGreenYellowRed.rgb" reference_colormap = "WhiteBlueGreenYellowRed.rgb" diff_colormap = "BrBG" @@ -51,6 +51,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["SWCFSRF"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [-170, -150, -135, -120, -105, -90, -75, -60, -45, -30, -15, 0, 15, 30, 45] diff_levels = [-30, -25, -20, -15, -10, -5, -2, 2, 5, 10, 15, 20, 25, 30] @@ -60,6 +61,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["LWCFSRF"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [0, 10, 20, 30, 40, 50, 60, 70, 80] diff_levels = [-30, -25, -20, -15, -10, -5, -2, 2, 5, 10, 15, 20, 25, 30] @@ -69,6 +71,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["LHFLX"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [0,5, 15, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300] diff_levels = [-75, -50, -25, -10, -5, -2, 2, 5, 10, 25, 50, 75] @@ -78,6 +81,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["SHFLX"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [-100, -75, -50, -25, -10, 0, 10, 25, 50, 75, 100, 125, 150] diff_levels = [-75, -50, -25, -10, -5, -2, 2, 5, 10, 25, 50, 75] @@ -87,6 +91,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["NET_FLUX_SRF"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [-200, -160, -120, -80, -40, 0, 40, 80, 120, 160, 200] diff_levels = [-75, -50, -25, -10, -5, -2, 2, 5, 10, 25, 50, 75] @@ -96,6 +101,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["TMQ"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60] diff_levels = [-12, -9, -6, -4, -3, -2, -1, 1, 2, 3, 4, 6, 9, 12] @@ -105,6 +111,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["QREFHT"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] contour_levels = [0.2, 0.5, 1, 2.5, 5, 7.5, 10, 12.5, 15, 17.5] diff_levels = [-5, -4, -3, -2, -1, -0.25, 0.25, 1, 2, 3, 4, 5] @@ -113,6 +120,7 @@ sets = ["lat_lon_native"] case_id = "model_vs_model" variables = ["U10"] seasons = ["ANN", "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", "DJF", "MAM", "JJA", "SON"] +regions = ["global"] test_colormap = "PiYG_r" reference_colormap = "PiYG_r" contour_levels = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] diff --git a/e3sm_diags/driver/lat_lon_driver.py b/e3sm_diags/driver/lat_lon_driver.py index 7c45f16df..5baa64027 100755 --- a/e3sm_diags/driver/lat_lon_driver.py +++ b/e3sm_diags/driver/lat_lon_driver.py @@ -5,7 +5,6 @@ import xarray as xr from e3sm_diags.driver import METRICS_DEFAULT_VALUE -from e3sm_diags.driver.utils.climo_xr import ClimoFreq from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots from e3sm_diags.driver.utils.regrid import ( @@ -16,7 +15,6 @@ regrid_z_axis_to_plevs, subset_and_align_datasets, ) -from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import _setup_child_logger from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg, std from e3sm_diags.plot.lat_lon_plot import plot as plot_func @@ -24,6 +22,7 @@ logger = _setup_child_logger(__name__) if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection from e3sm_diags.parameter.core_parameter import CoreParameter @@ -50,7 +49,6 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: (e.g., one is 2-D and the other is 3-D). """ variables = parameter.variables - seasons = parameter.seasons ref_name = getattr(parameter, "ref_name", "") regions = parameter.regions @@ -60,16 +58,35 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: test_ds = Dataset(parameter, data_type="test") ref_ds = Dataset(parameter, data_type="ref") + time_selection_type, time_selections = parameter._get_time_selection_to_use() + for var_key in variables: logger.info("Variable: {}".format(var_key)) parameter.var_id = var_key - for season in seasons: - parameter._set_name_yrs_attrs(test_ds, ref_ds, season) + for time_selection in time_selections: + is_time_slice = time_selection_type == "time_slices" + + # Get test and reference datasets. + # NOTE: lat_lon diagnostics get reference datasets differently than + # other sets using its own helper function `_get_ref_dataset`. + if is_time_slice: + ds_test = test_ds.get_time_sliced_dataset(var_key, time_selection) + + # For time slices, always use the annual land-sea mask. + ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask("ANN") + else: + # time_selection will be ClimoFreq, so ignore type checking here. + ds_test = test_ds.get_climo_dataset(var_key, time_selection) # type: ignore[arg-type] + ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask( # type: ignore[no-redef] + time_selection # type: ignore[arg-type] + ) + + ds_ref = _get_ref_dataset(ref_ds, var_key, time_selection, is_time_slice) - ds_test = test_ds.get_climo_dataset(var_key, season) - ds_ref = _get_ref_climo_dataset(ref_ds, var_key, season) - ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season) + # Set name_yrs after loading data because time sliced datasets + # have the required attributes only after loading the data. + parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection) if ds_ref is None: is_vars_3d = has_z_axis(ds_test[var_key]) @@ -79,7 +96,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_land_sea_mask, - season, + time_selection, regions, var_key, ref_name, @@ -89,7 +106,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_land_sea_mask, - season, + time_selection, regions, var_key, ref_name, @@ -107,7 +124,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: ds_test, ds_ref, ds_land_sea_mask, - season, + time_selection, regions, var_key, ref_name, @@ -118,7 +135,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: ds_test, ds_ref, ds_land_sea_mask, - season, + time_selection, regions, var_key, ref_name, @@ -285,7 +302,7 @@ def _run_diags_2d( ds_test: xr.Dataset, ds_ref: xr.Dataset, ds_land_sea_mask: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -306,8 +323,8 @@ def _run_diags_2d( ds_land_sea_mask : xr.Dataset The land sea mask dataset, which is only used for masking if the region is "land" or "ocean". - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -340,7 +357,9 @@ def _run_diags_2d( ds_diff_region, ) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev=None + ) _save_data_metrics_and_plots( parameter, plot_func, @@ -359,7 +378,7 @@ def _run_diags_3d( ds_test: xr.Dataset, ds_ref: xr.Dataset, ds_land_sea_mask: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -380,8 +399,8 @@ def _run_diags_3d( ds_land_sea_mask : xr.Dataset The land sea mask dataset, which is only used for masking if the region is "land" or "ocean". - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -425,7 +444,9 @@ def _run_diags_3d( ds_diff_region, ) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev + ) _save_data_metrics_and_plots( parameter, plot_func, @@ -439,12 +460,15 @@ def _run_diags_3d( ) -def _get_ref_climo_dataset( - dataset: Dataset, var_key: str, season: ClimoFreq +def _get_ref_dataset( + dataset: Dataset, + var_key: str, + time_selection: TimeSelection, + is_time_slice: bool = False, ) -> xr.Dataset | None: - """Get the reference climatology dataset for the variable and season. + """Get the reference dataset for the variable and time selection. - If the reference climatatology does not exist or could not be found, it + If the reference data does not exist or could not be found, it will be considered a model-only run and return `None`. Parameters @@ -453,13 +477,15 @@ def _get_ref_climo_dataset( The dataset object. var_key : str The key of the variable. - season : CLIMO_FREQ - The climatology frequency. + time_selection : ClimoFreq | str + The climatology frequency or time slice string. + is_time_slice : bool, optional + If True, treat time_selection as a time slice string. Default is False. Returns ------- xr.Dataset | None - The reference climatology if it exists or None if it does not. + The reference dataset if it exists or None if it does not. None indicates a model-only run. Raises @@ -469,7 +495,11 @@ def _get_ref_climo_dataset( """ if dataset.data_type == "ref": try: - ds_ref = dataset.get_climo_dataset(var_key, season) + if is_time_slice: + ds_ref = dataset.get_time_sliced_dataset(var_key, time_selection) + else: + # time_selection will be ClimoFreq, so ignore type checking here. + ds_ref = dataset.get_climo_dataset(var_key, time_selection) # type: ignore[arg-type] except (RuntimeError, IOError): ds_ref = None diff --git a/e3sm_diags/driver/lat_lon_native_driver.py b/e3sm_diags/driver/lat_lon_native_driver.py index d2029532b..8912b1b8a 100644 --- a/e3sm_diags/driver/lat_lon_native_driver.py +++ b/e3sm_diags/driver/lat_lon_native_driver.py @@ -66,7 +66,7 @@ def run_diag(parameter: LatLonNativeParameter) -> LatLonNativeParameter: # noqa for time_period in time_periods: if use_time_slices: logger.info(f"Processing time slice: {time_period}") - parameter._set_time_slice_attrs( + parameter._set_time_slice_name_yrs_attrs( test_ds.dataset, ref_ds.dataset, time_period ) else: diff --git a/e3sm_diags/driver/meridional_mean_2d_driver.py b/e3sm_diags/driver/meridional_mean_2d_driver.py index b075aefcc..d59f05746 100644 --- a/e3sm_diags/driver/meridional_mean_2d_driver.py +++ b/e3sm_diags/driver/meridional_mean_2d_driver.py @@ -7,13 +7,15 @@ import xcdat as xc from e3sm_diags.driver.utils.dataset_xr import Dataset -from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots +from e3sm_diags.driver.utils.io import ( + _get_xarray_datasets, + _save_data_metrics_and_plots, +) from e3sm_diags.driver.utils.regrid import ( align_grids_to_lower_res, has_z_axis, regrid_z_axis_to_plevs, ) -from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import _setup_child_logger from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg from e3sm_diags.parameter.zonal_mean_2d_parameter import DEFAULT_PLEVS @@ -22,6 +24,7 @@ logger = _setup_child_logger(__name__) if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection from e3sm_diags.parameter.meridional_mean_2d_parameter import ( MeridionalMean2dParameter, ) @@ -50,21 +53,25 @@ def run_diag(parameter: MeridionalMean2dParameter) -> MeridionalMean2dParameter: If the test or ref variables do are not 3-D (no Z-axis). """ variables = parameter.variables - seasons = parameter.seasons ref_name = getattr(parameter, "ref_name", "") test_ds = Dataset(parameter, data_type="test") ref_ds = Dataset(parameter, data_type="ref") + time_selection_type, time_selections = parameter._get_time_selection_to_use() + for var_key in variables: logger.info("Variable: {}".format(var_key)) parameter.var_id = var_key - for season in seasons: - parameter._set_name_yrs_attrs(test_ds, ref_ds, season) + for time_selection in time_selections: + ds_test, ds_ref, _ = _get_xarray_datasets( + test_ds, ref_ds, var_key, time_selection_type, time_selection + ) - ds_test = test_ds.get_climo_dataset(var_key, season) - ds_ref = ref_ds.get_climo_dataset(var_key, season) + # Set name_yrs after loading data because time sliced datasets + # have the required attributes only after loading the data. + parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection) dv_test = ds_test[var_key] dv_ref = ds_ref[var_key] @@ -84,7 +91,9 @@ def run_diag(parameter: MeridionalMean2dParameter) -> MeridionalMean2dParameter: if not parameter._is_plevs_set(): parameter.plevs = DEFAULT_PLEVS - _run_diags_3d(parameter, ds_test, ds_ref, season, var_key, ref_name) + _run_diags_3d( + parameter, ds_test, ds_ref, time_selection, var_key, ref_name + ) return parameter @@ -93,7 +102,7 @@ def _run_diags_3d( parameter: MeridionalMean2dParameter, ds_test: xr.Dataset, ds_ref: xr.Dataset, - season: str, + time_selection: TimeSelection, var_key: str, ref_name: str, ): @@ -142,7 +151,7 @@ def _run_diags_3d( ) parameter._set_param_output_attrs( - var_key, season, parameter.regions[0], ref_name, ilev=None + var_key, time_selection, parameter.regions[0], ref_name, ilev=None ) _save_data_metrics_and_plots( parameter, diff --git a/e3sm_diags/driver/polar_driver.py b/e3sm_diags/driver/polar_driver.py index dbb79b446..931105b46 100755 --- a/e3sm_diags/driver/polar_driver.py +++ b/e3sm_diags/driver/polar_driver.py @@ -5,14 +5,16 @@ import xarray as xr from e3sm_diags.driver.utils.dataset_xr import Dataset -from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots +from e3sm_diags.driver.utils.io import ( + _get_xarray_datasets, + _save_data_metrics_and_plots, +) from e3sm_diags.driver.utils.regrid import ( get_z_axis, has_z_axis, regrid_z_axis_to_plevs, subset_and_align_datasets, ) -from e3sm_diags.driver.utils.type_annotations import MetricsDict from e3sm_diags.logger import _setup_child_logger from e3sm_diags.metrics.metrics import correlation, rmse, spatial_avg from e3sm_diags.plot.polar_plot import plot as plot_func @@ -20,15 +22,18 @@ logger = _setup_child_logger(__name__) if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection from e3sm_diags.parameter.core_parameter import CoreParameter def run_diag(parameter: CoreParameter) -> CoreParameter: variables = parameter.variables - seasons = parameter.seasons ref_name = getattr(parameter, "ref_name", "") regions = parameter.regions + # Check that either seasons or time_slices is specified, but not both + time_selection_type, time_selections = parameter._get_time_selection_to_use() + test_ds = Dataset(parameter, data_type="test") ref_ds = Dataset(parameter, data_type="ref") @@ -36,14 +41,19 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: logger.info("Variable: {}".format(var_key)) parameter.var_id = var_key - for season in seasons: - parameter._set_name_yrs_attrs(test_ds, ref_ds, season) - - # Get land/ocean fraction for masking. - ds_land_sea_mask: xr.Dataset = test_ds._get_land_sea_mask(season) + for time_selection in time_selections: + ds_test, ds_ref, ds_land_sea_mask = _get_xarray_datasets( + test_ds, + ref_ds, + var_key, + time_selection_type, + time_selection, + get_land_sea_mask=True, + ) - ds_test = test_ds.get_climo_dataset(var_key, season) - ds_ref = ref_ds.get_climo_dataset(var_key, season) + # Set name_yrs after loading data because time sliced datasets + # have the required attributes only after loading the data. + parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection) # Store the variable's DataArray objects for reuse. dv_test = ds_test[var_key] @@ -61,8 +71,8 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_ref, - ds_land_sea_mask, - season, + ds_land_sea_mask, # type: ignore[arg-type] + time_selection, regions, var_key, ref_name, @@ -72,8 +82,8 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_ref, - ds_land_sea_mask, - season, + ds_land_sea_mask, # type: ignore[arg-type] + time_selection, regions, var_key, ref_name, @@ -87,7 +97,7 @@ def _run_diags_2d( ds_test: xr.Dataset, ds_ref: xr.Dataset, ds_land_sea_mask: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -109,8 +119,8 @@ def _run_diags_2d( ds_land_sea_mask : xr.Dataset The land sea mask dataset, which is only used for masking if the region is "land" or "ocean". - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -143,7 +153,9 @@ def _run_diags_2d( ds_diff_region, ) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev=None + ) _save_data_metrics_and_plots( parameter, plot_func, @@ -162,7 +174,7 @@ def _run_diags_3d( ds_test: xr.Dataset, ds_ref: xr.Dataset, ds_land_sea_mask: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -184,8 +196,8 @@ def _run_diags_3d( ds_land_sea_mask : xr.Dataset The land sea mask dataset, which is only used for masking if the region is "land" or "ocean". - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -229,7 +241,9 @@ def _run_diags_3d( ds_diff_region, ) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev + ) _save_data_metrics_and_plots( parameter, plot_func, diff --git a/e3sm_diags/driver/utils/dataset_native.py b/e3sm_diags/driver/utils/dataset_native.py index 8778220f6..6beba0771 100644 --- a/e3sm_diags/driver/utils/dataset_native.py +++ b/e3sm_diags/driver/utils/dataset_native.py @@ -54,7 +54,7 @@ def dataset_name(self) -> str: def get_native_dataset( self, var_key: str, - season: TimeSelection, + time_selection: TimeSelection, is_time_slice: bool = False, allow_missing: bool = False, ) -> xr.Dataset | None: @@ -96,19 +96,22 @@ def get_native_dataset( try: if is_time_slice: ds = self._get_full_native_dataset() - ds = self._apply_time_slice(ds, season) + ds = self._apply_time_slice(ds, time_selection) else: - if season in get_args(ClimoFreq): - ds = self.dataset.get_climo_dataset(var_key, season) # type: ignore + if time_selection in get_args(ClimoFreq): + # time_selection is a valid ClimoFreq here, so ignore type checking. + ds = self.dataset.get_climo_dataset(var_key, time_selection) # type: ignore[arg-type] else: - raise ValueError(f"Invalid season for climatology: {season}") + raise ValueError( + f"Invalid season for climatology: {time_selection}" + ) # Store file path in parameter for native grid processing. # Note: For climatology case, get_climo_dataset() already handles file # path storage. if is_time_slice: # For time slices, we know the exact file path we used. - filepath = self.dataset._get_climo_filepath_with_params() + filepath = self.dataset._get_filepath_with_params() if filepath: if self.dataset.data_type == "test": @@ -149,7 +152,7 @@ def _get_full_native_dataset(self) -> xr.Dataset: RuntimeError If unable to get the full dataset. """ - filepath = self.dataset._get_climo_filepath_with_params() + filepath = self.dataset._get_filepath_with_params() if filepath is None: raise RuntimeError( @@ -182,7 +185,7 @@ def _apply_time_slice(self, ds: xr.Dataset, time_slice: str) -> xr.Dataset: ds : xr.Dataset The input dataset with time dimension. time_slice : str - The time slice specification (e.g., "0:10:2", "5:15", "7"). + The time slice specification as a single index (e.g., "0", "5", "42"). Returns ------- @@ -203,26 +206,14 @@ def _apply_time_slice(self, ds: xr.Dataset, time_slice: str) -> xr.Dataset: ) return ds - # Parse slice notation - if ":" in time_slice: - # Handle slice notation like "0:10:2" or "5:15" or ":10" or "5:" or "::2" - parts = time_slice.split(":") - - start = int(parts[0]) if parts[0] else None - end = int(parts[1]) if len(parts) > 1 and parts[1] else None - step = int(parts[2]) if len(parts) > 2 and parts[2] else None - - # Apply the slice - ds_sliced = ds.isel({time_dim: slice(start, end, step)}) - else: - # Single index - index = int(time_slice) - ds_sliced = ds.isel({time_dim: index}) + # Single index selection + index = int(time_slice) + ds_sliced = ds.isel({time_dim: index}) logger.info( f"Applied time slice '{time_slice}' to dataset. " f"Original time length: {ds.sizes[time_dim]}, " - f"Sliced time length: {ds_sliced.sizes.get(time_dim, 1)}" + f"Selected index: {index}" ) return ds_sliced diff --git a/e3sm_diags/driver/utils/dataset_xr.py b/e3sm_diags/driver/utils/dataset_xr.py index 38f24d9fa..ca1fc5121 100644 --- a/e3sm_diags/driver/utils/dataset_xr.py +++ b/e3sm_diags/driver/utils/dataset_xr.py @@ -37,7 +37,7 @@ from e3sm_diags.logger import _setup_child_logger if TYPE_CHECKING: - from e3sm_diags.driver.utils.type_annotations import TimeSelection + from e3sm_diags.driver.utils.type_annotations import TimeSelection, TimeSlice from e3sm_diags.parameter.core_parameter import CoreParameter logger = _setup_child_logger(__name__) @@ -352,6 +352,151 @@ def _get_global_attr_from_climo_dataset( return attr_val + # -------------------------------------------------------------------------- + # Time-slice related methods + # -------------------------------------------------------------------------- + def get_time_sliced_dataset(self, var: str, time_slice: TimeSlice) -> xr.Dataset: + """Get the dataset containing the time slice. + + These variables can either be from the test data or reference data. + + Parameters + ---------- + var : str + The key of the climatology or time series variable to get the + dataset for. + time_slice : TimeSlice + The time slice string, expressed as a single index (e.g., "0", "5", + "42"). + + Returns + ------- + xr.Dataset + The dataset containing the time-sliced data. + """ + self.var = var + + if not isinstance(self.var, str) or self.var == "": + raise ValueError("The `var` argument is not a valid string.") + + filepath = self._get_filepath_with_params() + + if filepath is None: + raise RuntimeError( + f"Unable to get file path for {self.data_type} dataset. " + f"For time slicing, please ensure that " + f"{'ref_file' if self.data_type == 'ref' else 'test_file'} parameter is set." + ) + + if not os.path.exists(filepath): + raise RuntimeError(f"File not found: {filepath}") + + self.parameter._add_filepath_attr(self.data_type, filepath) + + ds = self._get_full_dataset() + ds = self._apply_time_slice_to_dataset(ds, time_slice) + + return ds + + def _get_full_dataset(self) -> xr.Dataset: + """Get the full dataset without any time averaging for time slicing. + + This function uses the dataset's file path parameters to directly open + the raw data file for time slicing operations. + + Returns + ------- + xr.Dataset + The full dataset with all time steps. + + Raises + ------ + RuntimeError + If unable to get the full dataset or file not found. + """ + filepath = getattr(self.parameter, f"{self.data_type}_data_file_path") + + logger.info(f"Opening full dataset from: {filepath}") + + try: + ds = xc.open_dataset(filepath, add_bounds=["X", "Y", "T"]) + + logger.info( + f"Successfully opened dataset with time dimension size: {ds.sizes.get('time', 'N/A')}" + ) + except (FileNotFoundError, OSError, ValueError) as e: + raise RuntimeError(f"Failed to open dataset {filepath}: {e}") from e + else: + return ds + + def _get_filepath_with_params(self) -> str | None: + """Get the filepath using parameters. + + Returns + ------- + str | None + The filepath using the `ref_file` or `test_file` parameter if they + are set. + """ + filepath = None + + if self.data_type == "ref": + if self.parameter.ref_file != "": + filepath = os.path.join(self.root_path, self.parameter.ref_file) + + elif self.data_type == "test": + if hasattr(self.parameter, "test_file"): + filepath = os.path.join(self.root_path, self.parameter.test_file) + + return filepath + + def _apply_time_slice_to_dataset( + self, ds: xr.Dataset, time_slice: TimeSlice + ) -> xr.Dataset: + """Apply time slice selection to a dataset. + + Parameters + ---------- + ds : xr.Dataset + The input dataset with time dimension. + time_slice : TimeSlice + The time slice specification as a single index (e.g., "0", "5", "42"). + + Returns + ------- + xr.Dataset + The dataset with time slice applied. + """ + try: + time_dim = xc.get_dim_keys(ds, axis="T") + except (ValueError, KeyError): + time_dim = None + + if time_dim is None: + logger.warning( + "No time dimension found in dataset. Returning original dataset." + ) + + return ds + + index = int(time_slice) + + try: + ds_sliced = ds.isel({time_dim: index}) + except IndexError as e: + raise IndexError( + f"Time slice index {index} is out of bounds for time dimension " + f"of size {ds.sizes[time_dim]}." + ) from e + + logger.info( + f"Applied time slice '{time_slice}' to dataset. " + f"Original time length: {ds.sizes[time_dim]}, " + f"Selected index: {index}" + ) + + return ds_sliced + # -------------------------------------------------------------------------- # Climatology related methods # -------------------------------------------------------------------------- @@ -369,7 +514,7 @@ def get_climo_dataset(self, var: str, season: ClimoFreq) -> xr.Dataset: var : str The key of the climatology or time series variable to get the dataset for. - season : CLIMO_FREQ, optional + season : ClimoFreq The season for the climatology. Returns @@ -388,19 +533,22 @@ def get_climo_dataset(self, var: str, season: ClimoFreq) -> xr.Dataset: if not isinstance(self.var, str) or self.var == "": raise ValueError("The `var` argument is not a valid string.") + if not isinstance(season, str) or season not in CLIMO_FREQS: raise ValueError( - "The `season` argument is not a valid string. Options include: " + f"The `season` argument, {season}, is not a valid string. Options include: " f"{CLIMO_FREQS}" ) if self.is_time_series: ds = self.get_time_series_dataset(var) + # At this point, season is validated to be in CLIMO_FREQS, so it's a + # ClimoFreq ds_climo = climo(ds, self.var, season).to_dataset() ds_climo = ds_climo.bounds.add_missing_bounds(axes=["X", "Y"]) - self.parameter._add_time_series_file_path_attr(self.data_type, ds) + self.parameter._add_time_series_filepath_attr(self.data_type, ds) return ds_climo @@ -410,7 +558,7 @@ def get_climo_dataset(self, var: str, season: ClimoFreq) -> xr.Dataset: try: filepath = self._get_climo_filepath(season) - self.parameter._add_climatology_file_path_attr(self.data_type, filepath) + self.parameter._add_filepath_attr(self.data_type, filepath) except Exception as e: logger.warning(f"Failed to store absolute file path: {e}") @@ -557,7 +705,7 @@ def _get_climo_filepath(self, season: str) -> str: The path to the climatology file. """ # First pattern attempt. - filepath = self._get_climo_filepath_with_params() + filepath = self._get_filepath_with_params() # Second and third pattern attempts. if filepath is None: @@ -589,27 +737,6 @@ def _get_climo_filepath(self, season: str) -> str: return filepath - def _get_climo_filepath_with_params(self) -> str | None: - """Get the climatology filepath using parameters. - - Returns - ------- - str | None - The filepath using the `ref_file` or `test_file` parameter if they - are set. - """ - filepath = None - - if self.data_type == "ref": - if self.parameter.ref_file != "": - filepath = os.path.join(self.root_path, self.parameter.ref_file) - - elif self.data_type == "test": - if hasattr(self.parameter, "test_file"): - filepath = os.path.join(self.root_path, self.parameter.test_file) - - return filepath - def _find_climo_filepath(self, filename: str, season: str) -> str | None: """Find the climatology filepath for the variable. @@ -1457,7 +1584,7 @@ def _center_time_for_non_submonthly_data(self, ds: xr.Dataset) -> xr.Dataset: return ds - def _get_land_sea_mask(self, season: str) -> xr.Dataset: + def _get_land_sea_mask(self, season: ClimoFreq) -> xr.Dataset: """Get the land sea mask from the dataset or use the default file. Land sea mask variables are time invariant which means the time @@ -1466,7 +1593,7 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset: Parameters ---------- - season : str + season : ClimoFreq The season to subset on. Returns @@ -1483,12 +1610,12 @@ def _get_land_sea_mask(self, season: str) -> xr.Dataset: return ds_mask - def _get_land_sea_mask_dataset(self, season: str) -> xr.Dataset | None: + def _get_land_sea_mask_dataset(self, season: ClimoFreq) -> xr.Dataset | None: """Get the land sea mask dataset for the given season. Parameters ---------- - season : str + season : ClimoFreq The season to subset on. Returns @@ -1504,8 +1631,8 @@ def _get_land_sea_mask_dataset(self, season: str) -> xr.Dataset | None: # FIXME: B905: zip() without an explicit strict= parameter for land_key, ocn_key in zip(land_keys, ocn_keys, strict=False): try: - ds_land = self.get_climo_dataset(land_key, season) # type: ignore - ds_ocn = self.get_climo_dataset(ocn_key, season) # type: ignore + ds_land = self.get_climo_dataset(land_key, season) + ds_ocn = self.get_climo_dataset(ocn_key, season) except IOError: pass else: diff --git a/e3sm_diags/driver/utils/io.py b/e3sm_diags/driver/utils/io.py index 8b6293465..e67875851 100644 --- a/e3sm_diags/driver/utils/io.py +++ b/e3sm_diags/driver/utils/io.py @@ -1,17 +1,100 @@ +from __future__ import annotations + import errno import json import os from collections.abc import Callable -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, NamedTuple import xarray as xr -from e3sm_diags.driver.utils.type_annotations import MetricsDict +from e3sm_diags.driver.utils.dataset_xr import Dataset from e3sm_diags.logger import _setup_child_logger from e3sm_diags.parameter.core_parameter import CoreParameter logger = _setup_child_logger(__name__) +if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import MetricsDict, TimeSelection + + +class DatasetResult(NamedTuple): + ds_test: xr.Dataset + ds_ref: xr.Dataset + ds_land_sea_mask: xr.Dataset | None + + +def _get_xarray_datasets( + test_ds: Dataset, + ref_ds: Dataset, + var_key: str, + time_selection_type: Literal["time_slices", "seasons"], + time_selection: TimeSelection, + get_land_sea_mask: bool = False, +) -> DatasetResult: + """Utility function to fetch datasets based on time selection type. + + Parameters + ---------- + test_ds : Dataset + The test dataset object. + ref_ds : Dataset + The reference dataset object. + var_key : str + The key of the variable to fetch. + time_selection_type : Literal["time_slices", "seasons"] + The type of time selection, e.g., "time_slices" or "seasons". + time_selection : TimeSelection + The time slice or season. + get_land_sea_mask : bool, optional + Whether to fetch the land-sea mask, by default False. + + Returns + ------- + DatasetResult + A named tuple containing (ds_test, ds_ref, ds_land_sea_mask). + """ + fetch_ds_test = _select_dataset_fetch_method(test_ds, time_selection_type) + fetch_ds_ref = _select_dataset_fetch_method(ref_ds, time_selection_type) + + ds_test = fetch_ds_test(var_key, time_selection) + ds_ref = fetch_ds_ref(var_key, time_selection) + + ds_land_sea_mask = None + + if get_land_sea_mask: + # For time slices, always use the annual land-sea mask. + if time_selection_type == "time_slices": + ds_land_sea_mask = test_ds._get_land_sea_mask("ANN") + else: + # time_selection will be ClimoFreq, so ignore type checking here. + ds_land_sea_mask = test_ds._get_land_sea_mask(time_selection) # type: ignore[arg-type] + + return DatasetResult(ds_test, ds_ref, ds_land_sea_mask) + + +def _select_dataset_fetch_method( + dataset: Dataset, time_selection_type: Literal["time_slices", "seasons"] +) -> Callable: + """Select the appropriate dataset fetching method based on time selection type. + + Parameters + ---------- + dataset : Dataset + The dataset object. + time_selection_type : Literal["time_slices", "seasons"] + The type of time selection, e.g., "time_slices" or "seasons. + + Returns + ------- + Callable + The dataset fetching method. + """ + if time_selection_type == "time_slices": + return dataset.get_time_sliced_dataset + + return dataset.get_climo_dataset + def _save_data_metrics_and_plots( parameter: CoreParameter, diff --git a/e3sm_diags/driver/utils/type_annotations.py b/e3sm_diags/driver/utils/type_annotations.py index b0c6bc300..29cde2c72 100644 --- a/e3sm_diags/driver/utils/type_annotations.py +++ b/e3sm_diags/driver/utils/type_annotations.py @@ -8,8 +8,8 @@ MetricsSubDict = dict[str, float | None | list[float]] MetricsDict = dict[str, UnitAttr | MetricsSubDict] -# Type for time slice specification: index-based with optional stride -# Examples: "0:10:2" (start:end:stride), "5:15" (start:end), "7" (single index) +# Type for time slice specification: individual time index for snapshot analysis +# Examples: "0", "5", "42" TimeSlice = str # Union type for time selection - can be either climatology season or time slice diff --git a/e3sm_diags/driver/zonal_mean_2d_driver.py b/e3sm_diags/driver/zonal_mean_2d_driver.py index fcdd4a131..6f2f2918b 100755 --- a/e3sm_diags/driver/zonal_mean_2d_driver.py +++ b/e3sm_diags/driver/zonal_mean_2d_driver.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import copy +from typing import TYPE_CHECKING import xarray as xr import xcdat as xc # noqa: F401 from e3sm_diags.driver.utils.dataset_xr import Dataset -from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots +from e3sm_diags.driver.utils.io import ( + _get_xarray_datasets, + _save_data_metrics_and_plots, +) from e3sm_diags.driver.utils.regrid import ( align_grids_to_lower_res, has_z_axis, @@ -23,17 +29,21 @@ DEFAULT_PLEVS = copy.deepcopy(DEFAULT_PLEVS) +if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import TimeSelection + def run_diag( parameter: ZonalMean2dParameter, default_plevs=DEFAULT_PLEVS ) -> ZonalMean2dParameter: variables = parameter.variables - seasons = parameter.seasons ref_name = getattr(parameter, "ref_name", "") if not parameter._is_plevs_set(): parameter.plevs = default_plevs + time_selection_type, time_selections = parameter._get_time_selection_to_use() + test_ds = Dataset(parameter, data_type="test") ref_ds = Dataset(parameter, data_type="ref") @@ -41,11 +51,14 @@ def run_diag( logger.info("Variable: {}".format(var_key)) parameter.var_id = var_key - for season in seasons: - parameter._set_name_yrs_attrs(test_ds, ref_ds, season) + for time_selection in time_selections: + ds_test, ds_ref, _ = _get_xarray_datasets( + test_ds, ref_ds, var_key, time_selection_type, time_selection + ) - ds_test = test_ds.get_climo_dataset(var_key, season) - ds_ref = ref_ds.get_climo_dataset(var_key, season) + # Set name_yrs after loading data because time sliced datasets + # have the required attributes only after loading the data. + parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection) # Store the variable's DataArray objects for reuse. dv_test = ds_test[var_key] @@ -64,7 +77,7 @@ def run_diag( parameter, ds_test, ds_ref, - season, + time_selection, var_key, ref_name, ) @@ -78,7 +91,7 @@ def _run_diags_3d( parameter: ZonalMean2dParameter, ds_test: xr.Dataset, ds_ref: xr.Dataset, - season: str, + time_selection: TimeSelection, var_key: str, ref_name: str, ): @@ -117,7 +130,7 @@ def _run_diags_3d( # Set parameter attributes for output files. parameter._set_param_output_attrs( - var_key, season, parameter.regions[0], ref_name, ilev=None + var_key, time_selection, parameter.regions[0], ref_name, ilev=None ) _save_data_metrics_and_plots( parameter, diff --git a/e3sm_diags/driver/zonal_mean_xy_driver.py b/e3sm_diags/driver/zonal_mean_xy_driver.py index f47a82e6d..4ed20e481 100755 --- a/e3sm_diags/driver/zonal_mean_xy_driver.py +++ b/e3sm_diags/driver/zonal_mean_xy_driver.py @@ -7,7 +7,10 @@ from scipy import interpolate from e3sm_diags.driver.utils.dataset_xr import Dataset -from e3sm_diags.driver.utils.io import _save_data_metrics_and_plots +from e3sm_diags.driver.utils.io import ( + _get_xarray_datasets, + _save_data_metrics_and_plots, +) from e3sm_diags.driver.utils.regrid import ( get_z_axis, has_z_axis, @@ -20,6 +23,7 @@ logger = _setup_child_logger(__name__) if TYPE_CHECKING: + from e3sm_diags.driver.utils.type_annotations import TimeSelection from e3sm_diags.parameter.core_parameter import CoreParameter @@ -47,7 +51,6 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: (e.g., one is 2-D and the other is 3-D). """ variables = parameter.variables - seasons = parameter.seasons ref_name = getattr(parameter, "ref_name", "") regions = parameter.regions @@ -58,6 +61,8 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: "supported for the zonal_mean_xy set." ) + time_selection_type, time_selections = parameter._get_time_selection_to_use() + # Variables storing xarray `Dataset` objects start with `ds_` and # variables storing e3sm_diags `Dataset` objects end with `_ds`. This # is to help distinguish both objects from each other. @@ -68,11 +73,14 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: logger.info("Variable: {}".format(var_key)) parameter.var_id = var_key - for season in seasons: - parameter._set_name_yrs_attrs(test_ds, ref_ds, season) + for time_selection in time_selections: + ds_test, ds_ref, _ = _get_xarray_datasets( + test_ds, ref_ds, var_key, time_selection_type, time_selection + ) - ds_test = test_ds.get_climo_dataset(var_key, season) - ds_ref = ref_ds.get_climo_dataset(var_key, season) + # Set name_yrs after loading data because time sliced datasets + # have the required attributes only after loading the data. + parameter._set_name_yrs_attrs(test_ds, ref_ds, time_selection) # Store the variable's DataArray objects for reuse. dv_test = ds_test[var_key] @@ -86,7 +94,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_ref, - season, + time_selection, regions, var_key, ref_name, @@ -96,7 +104,7 @@ def run_diag(parameter: CoreParameter) -> CoreParameter: parameter, ds_test, ds_ref, - season, + time_selection, regions, var_key, ref_name, @@ -114,7 +122,7 @@ def _run_diags_2d( parameter: CoreParameter, ds_test: xr.Dataset, ds_ref: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -133,8 +141,8 @@ def _run_diags_2d( ds_ref : xr.Dataset The dataset containing the ref variable. If this is a model-only run then it will be the same dataset as ``ds_test``. - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -148,7 +156,9 @@ def _run_diags_2d( da_test_1d, da_ref_1d = _calc_zonal_mean(ds_test, ds_ref, var_key) da_diff_1d = _get_diff_of_zonal_means(da_test_1d, da_ref_1d) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev=None) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev=None + ) _save_data_metrics_and_plots( parameter, plot_func, @@ -164,7 +174,7 @@ def _run_diags_3d( parameter: CoreParameter, ds_test: xr.Dataset, ds_ref: xr.Dataset, - season: str, + time_selection: TimeSelection, regions: list[str], var_key: str, ref_name: str, @@ -183,8 +193,8 @@ def _run_diags_3d( ds_ref : xr.Dataset The dataset containing the ref variable. If this is a model-only run then it will be the same dataset as ``ds_test``. - season : str - The season. + time_selection : TimeSelection + The time slice or season. regions : list[str] The list of regions. var_key : str @@ -209,7 +219,9 @@ def _run_diags_3d( da_test_1d, da_ref_1d = _calc_zonal_mean(ds_test_ilev, ds_ref_ilev, var_key) da_diff_1d = _get_diff_of_zonal_means(da_test_1d, da_ref_1d) - parameter._set_param_output_attrs(var_key, season, region, ref_name, ilev) + parameter._set_param_output_attrs( + var_key, time_selection, region, ref_name, ilev + ) _save_data_metrics_and_plots( parameter, plot_func, diff --git a/e3sm_diags/parameter/core_parameter.py b/e3sm_diags/parameter/core_parameter.py index ce4139393..5ae0ee3f1 100644 --- a/e3sm_diags/parameter/core_parameter.py +++ b/e3sm_diags/parameter/core_parameter.py @@ -3,6 +3,7 @@ import copy import importlib import os +import re import sys from typing import TYPE_CHECKING, Any, Literal @@ -17,7 +18,7 @@ if TYPE_CHECKING: import xarray as xr - from e3sm_diags.driver.utils.type_annotations import TimeSelection + from e3sm_diags.driver.utils.type_annotations import TimeSelection, TimeSlice logger = _setup_child_logger(__name__) @@ -127,6 +128,12 @@ def __init__(self): self.variables: list[str] = [] self.seasons: list[ClimoFreq] = ["ANN", "DJF", "MAM", "JJA", "SON"] + + # Time slice parameters (mutually exclusive with seasons) + # Index-based time selection for snapshot analysis using individual time indices + # Examples: ["0"], ["5"], ["0", "1", "2"] + self.time_slices: list[TimeSlice] = [] + self.regions: list[str] = ["global"] self.regrid_tool: REGRID_TOOLS = "xesmf" @@ -201,7 +208,13 @@ def __init__(self): # TODO: Need documentation on these attributes here and # here: https://e3sm-project.github.io/e3sm_diags/_build/html/main/available-parameters.html self.dataset: str = "" - self.granulate: list[str] = ["variables", "seasons", "plevs", "regions"] + self.granulate: list[str] = [ + "variables", + "seasons", + "plevs", + "regions", + "time_slices", + ] self.selectors: list[str] = ["sets", "seasons"] self.viewer_descr: dict[str, str] = {} self.fail_on_incomplete: bool = False @@ -264,10 +277,40 @@ def check_values(self): msg = "You need to define both the 'test_start_yr' and 'test_end_yr' parameter." raise RuntimeError(msg) + if self.time_slices: + self._validate_time_slice_format() + + def _validate_time_slice_format(self) -> None: + """Validate that time_slice follows the expected format. + + Time slices must be non-negative integer indices representing + individual time steps in the dataset. + + Parameters + ---------- + time_slice : str + The time slice string to validate. Must be a non-negative integer. + + Raises + ------ + ValueError + If the time slice format is invalid (not a non-negative integer). + """ + # Define the regex pattern for a non-negative integer, including no + # leading zeros except for zero itself. + pattern = r"^(0|[1-9]\d*)$" + + for time_slice in self.time_slices: + if not re.match(pattern, time_slice.strip()): + raise ValueError( + f"Invalid time_slice format: '{time_slice}'. " + f"Expected a non-negative integer index. Examples: '0', '5', '42'" + ) + def _set_param_output_attrs( self, var_key: str, - season: str, + time_selection: TimeSelection, region: str, ref_name: str, ilev: float | None, @@ -278,8 +321,8 @@ def _set_param_output_attrs( ---------- var_key : str The variable key. - season : str - The season. + time_selection : TimeSelection + The time slice or season. region : str The region. ref_name : str @@ -289,18 +332,18 @@ def _set_param_output_attrs( variable is 3D. """ if ilev is None: - output_file = f"{ref_name}-{var_key}-{season}-{region}" - main_title = f"{var_key} {season} {region}" + output_file = f"{ref_name}-{var_key}-{time_selection}-{region}" + main_title = f"{var_key} {time_selection} {region}" else: ilev_str = str(int(ilev)) - output_file = f"{ref_name}-{var_key}-{ilev_str}-{season}-{region}" - main_title = f"{var_key} {ilev_str} mb {season} {region}" + output_file = f"{ref_name}-{var_key}-{ilev_str}-{time_selection}-{region}" + main_title = f"{var_key} {ilev_str} mb {time_selection} {region}" self.output_file = output_file self.main_title = main_title def _set_name_yrs_attrs( - self, ds_test: Dataset, ds_ref: Dataset, season: TimeSelection | None + self, ds_test: Dataset, ds_ref: Dataset, time_selection: TimeSelection | None ): """Set the test_name_yrs and ref_name_yrs attributes. @@ -310,11 +353,11 @@ def _set_name_yrs_attrs( The test dataset object used for setting ``self.test_name_yrs``. ds_ref : Dataset The ref dataset object used for setting ``self.ref_name_yrs``. - season : TimeSelection | None - The optional frequency for climatology or time slice. + time_selection : TimeSelection | None + The optional time slice or season. """ - self.test_name_yrs = ds_test.get_name_yrs_attr(season) - self.ref_name_yrs = ds_ref.get_name_yrs_attr(season) + self.test_name_yrs = ds_test.get_name_yrs_attr(time_selection) + self.ref_name_yrs = ds_ref.get_name_yrs_attr(time_selection) def _is_plevs_set(self): if (isinstance(self.plevs, np.ndarray) and not self.plevs.all()) or ( @@ -369,7 +412,7 @@ def _run_diag(self) -> list[Any]: return results - def _add_time_series_file_path_attr( + def _add_time_series_filepath_attr( self, data_type: Literal["test", "ref"], ds: xr.Dataset, @@ -391,11 +434,11 @@ def _add_time_series_file_path_attr( if data_type not in {"test", "ref"}: raise ValueError("data_type must be either 'test' or 'ref'.") - file_path_attr = f"{data_type}_data_file_path" + filepath_attr = f"{data_type}_data_file_path" - setattr(self, file_path_attr, getattr(ds, "file_path", "Unknown")) + setattr(self, filepath_attr, getattr(ds, "file_path", "Unknown")) - def _add_climatology_file_path_attr( + def _add_filepath_attr( self, data_type: Literal["test", "ref"], filepath: str | None = None, @@ -407,24 +450,65 @@ def _add_climatology_file_path_attr( data_type : Literal["test", "ref"] The type of data, either "test" or "ref". filepath : str | None, optional - The file path for climatology data. + The file path for climatology or time-slice data. Raises ------ ValueError If `data_type` is not "test" or "ref". ValueError - If `filepath` is not provided for climatology data. + If `filepath` is not provided for climatology or time-slice data. """ if data_type not in {"test", "ref"}: raise ValueError("data_type must be either 'test' or 'ref'.") - file_path_attr = f"{data_type}_data_file_path" + filepath_attr = f"{data_type}_data_file_path" if not filepath: - raise ValueError("Filepath must be provided for climatology data.") + raise ValueError( + "Filepath must be provided for the climatology or time-slice data." + ) + + setattr(self, filepath_attr, os.path.abspath(filepath)) + + def _get_time_selection_to_use( + self, require_one: bool = True + ) -> tuple[Literal["time_slices", "seasons"], list[TimeSlice] | list[ClimoFreq]]: + """ + Determine the time selection type and corresponding values. + + If ``time_slices`` are specified, they take precedence over ``seasons``. + + Parameters + ---------- + require_one : bool, optional + If True, ensures that at least one of `seasons` or `time_slices` is + specified. Default is True. + + Returns + ------- + tuple[Literal["time_slices", "seasons"], list[TimeSlice] | list[ClimoFreq]] + A tuple containing the time selection type ("time_slices" or "seasons") + and the corresponding list of values. + + Raises + ------ + RuntimeError + If neither `seasons` nor `time_slices` are specified when `require_one` + is True. + """ + if require_one and not (self.seasons or self.time_slices): + raise RuntimeError( + "Must specify either 'seasons' or 'time_slices'.\n" + "- Use 'seasons' for climatological analysis (e.g., ['ANN', 'DJF']).\n" + "- Use 'time_slices' for snapshot-based selection (e.g., ['0'], ['5'])." + ) + + # Time slices take precedence over seasons if both are specified. + if self.time_slices: + return "time_slices", self.time_slices - setattr(self, file_path_attr, os.path.abspath(filepath)) + return "seasons", self.seasons def __setattr__(self, name: str, value: Any) -> None: """Override setattr to ensure year attributes are padded when set.""" diff --git a/e3sm_diags/parameter/lat_lon_native_parameter.py b/e3sm_diags/parameter/lat_lon_native_parameter.py index 78a7da395..b4e6de0b4 100644 --- a/e3sm_diags/parameter/lat_lon_native_parameter.py +++ b/e3sm_diags/parameter/lat_lon_native_parameter.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING from e3sm_diags.parameter.core_parameter import CoreParameter @@ -33,83 +32,12 @@ def __init__(self): # Option to disable the grid antialiasing (may improve performance) self.antialiased = False - # Time selection parameters (mutually exclusive with seasons) - # Either use seasons (inherited from CoreParameter) OR time_slices - # Index-based time selection with stride support - # Examples: ["0:10:2", "5:15", "7"] for start:end:stride, start:end, or single index - self.time_slices: list[TimeSlice] = [] - def check_values(self): - """Verifies that required values are properly set. - - Raises - ------ - RuntimeError - If no grid files are provided or set. - RuntimeError - If neither seasons nor time_slices are specified. - """ - has_seasons = len(self.seasons) > 0 - has_time_slices = len(self.time_slices) > 0 - - if not has_seasons and not has_time_slices: - raise RuntimeError( - "Must specify either 'seasons' or 'time_slices'. " - "Use 'seasons' for climatological analysis (e.g., ['ANN', 'DJF']) " - "or 'time_slices' for index-based selection (e.g., ['0:10:2', '5:15'])." - ) - - # Validate time_slice format if provided - if has_time_slices: - for time_slice in self.time_slices: - self._validate_time_slice_format(time_slice) - + """Verifies that required values are properly set.""" # TODO: For now, we'll make grid file check a soft check. In the future, # we may want to require at least test_grid_file pass - def _validate_time_slice_format(self, time_slice: str): - r"""Validate that time_slice follows the expected format. - - This regex pattern for slice notation is designed to match a - latitude/longitude-like format with optional degrees, minutes, and - seconds. - - ^: Matches the start of the string. - - (-?\d+|): Matches an optional integer (can be negative) for degrees. - - (?::(-?\d+|): Matches an optional colon followed by an optional - integer (can be negative) for minutes. - - (?::(-?\d+|)): Matches an optional colon followed by an optional - integer (can be negative) for seconds. - - )?: Makes the minutes and seconds groups optional. - - $: Matches the end of the string. - - Valid formats: - - "index" (single index): "5" - - "start:end" (range): "0:10" - - "start:end:stride" (range with stride): "0:10:2" - - ":end" (from beginning): ":10" - - "start:" (to end): "5:" - - "::stride" (full range with stride): "::2" - - Parameters - ---------- - time_slice : str - The time slice string to validate - - Raises - ------ - ValueError - If the time slice format is invalid - """ - pattern = r"^(-?\d+|)(?::(-?\d+|)(?::(-?\d+|))?)?$" - - if not re.match(pattern, time_slice.strip()): - raise ValueError( - f"Invalid time_slice format: '{time_slice}'. " - f"Expected formats: 'index', 'start:end', 'start:end:stride', " - f"':end', 'start:', or '::stride'. Examples: '5', '0:10', '0:10:2'" - ) - def _set_name_yrs_attrs( self, test_ds: Dataset, ref_ds: Dataset, season: TimeSelection | None ): @@ -127,21 +55,22 @@ def _set_name_yrs_attrs( from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQS if season is None or season in CLIMO_FREQS: - # Standard climatology season, use parent implementation. super()._set_name_yrs_attrs(test_ds, ref_ds, season) else: - # This is a time slice string, handle it specially. - self._set_time_slice_attrs(test_ds, ref_ds, season) + self._set_time_slice_name_yrs_attrs(test_ds, ref_ds, season) - def _set_time_slice_attrs(self, test_ds: Dataset, ref_ds: Dataset, time_slice: str): - """Set attributes for time slice-based processing. + def _set_time_slice_name_yrs_attrs( + self, test_ds: Dataset, ref_ds: Dataset, time_slice: TimeSlice + ) -> None: + """Set name_yrs attributes for time slice-based processing. - This method sets up the necessary attributes for file naming and + This function sets up the necessary attributes for file naming and processing when using time_slices instead of seasons. - Store the time slice info but keep current_set as the diagnostic set name - current_set should remain as "lat_lon_native" for proper directory structure - The time slice will be used in filename generation via other attributes + NOTE: This method is only used by lat_lon_native. Other diagnostic sets + call self._set_name_yrs_attrs() from the parent class directly, even + with time slices. We may want to refactor this in the future for + consistency. Parameters ---------- @@ -149,16 +78,19 @@ def _set_time_slice_attrs(self, test_ds: Dataset, ref_ds: Dataset, time_slice: s The test dataset object. ref_ds : Dataset The reference dataset object. - time_slice : str + time_slice : TimeSlice The time slice specification. + + Notes + ----- + This function modifies the parameter object in-place by setting: + - parameter.current_time_slice + - parameter.test_name_yrs + - parameter.ref_name_yrs """ # Set the time slice info for potential use in plotting/output self.current_time_slice = time_slice - # For time slices, we manually set the name_yrs attributes instead of - # calling parent method to avoid issues with the dataset's get_name_yrs_attr - # expecting a valid season - # Set test_name_yrs - use test dataset years if available, otherwise use # time slice info try: diff --git a/e3sm_diags/parser/core_parser.py b/e3sm_diags/parser/core_parser.py index 07a2fe6ce..3666bade5 100644 --- a/e3sm_diags/parser/core_parser.py +++ b/e3sm_diags/parser/core_parser.py @@ -321,6 +321,16 @@ def add_arguments(self): required=False, ) + self.parser.add_argument( + "--time_slices", + nargs="+", + dest="time_slices", + help="Time slices to use (mutually exclusive with seasons). " + + "Individual time indices for snapshot-based analysis. " + + "Examples: '0' (single index), '5' (single index), or multiple like '0' '1' '2'.", + required=False, + ) + self.parser.add_argument( "-r", "--regions", diff --git a/e3sm_diags/parser/lat_lon_native_parser.py b/e3sm_diags/parser/lat_lon_native_parser.py index be5ec8bbc..1dba400fa 100644 --- a/e3sm_diags/parser/lat_lon_native_parser.py +++ b/e3sm_diags/parser/lat_lon_native_parser.py @@ -53,6 +53,7 @@ def add_arguments(self): "--time_slices", dest="time_slices", nargs="+", - help="Time slices for snapshot-based analysis (e.g., '0', '0:10:2', '5:15'). Mutually exclusive with seasons.", + help="Individual time indices for snapshot-based analysis (e.g., '0', '5', or multiple like '0' '1' '2'). " + + "Mutually exclusive with seasons.", required=False, ) diff --git a/e3sm_diags/viewer/default_viewer.py b/e3sm_diags/viewer/default_viewer.py index 7b03b90bb..343a417e7 100644 --- a/e3sm_diags/viewer/default_viewer.py +++ b/e3sm_diags/viewer/default_viewer.py @@ -279,8 +279,12 @@ def seasons_used(parameters: list[CoreParameter]) -> list[str]: } # Return sorted time slices if they are used + # Sort numerically by converting to int if time_slices_used: - return sorted(time_slices_used) + logger.info(f"Time slices found: {time_slices_used}") + sorted_slices = sorted(time_slices_used, key=lambda x: int(x)) + logger.info(f"Time slices sorted: {sorted_slices}") + return sorted_slices # Otherwise, collect and return seasons used, ordered by SEASONS seasons_used = {season for p in parameters for season in p.seasons} diff --git a/e3sm_diags/viewer/mean_2d_viewer.py b/e3sm_diags/viewer/mean_2d_viewer.py index e5e717834..de15c3f1c 100644 --- a/e3sm_diags/viewer/mean_2d_viewer.py +++ b/e3sm_diags/viewer/mean_2d_viewer.py @@ -30,14 +30,32 @@ def create_viewer(root_dir, parameters): # Sort the parameters so that the viewer is created in the correct order. # Using SEASONS.index(), we make sure we get the parameters in # ['ANN', 'DJF', ..., 'SON'] order instead of alphabetical. - parameters.sort( - key=lambda x: (x.case_id, x.variables[0], SEASONS.index(x.seasons[0])) - ) + # For time_slices, sort numerically by the starting index. + def _get_sort_key(x): + # Get the first time period (either time_slice or season) + if hasattr(x, "time_slices") and len(x.time_slices) > 0: + time_slice = x.time_slices[0] + # Time slices are individual indices (e.g., "0", "5", "42") + # Sort numerically by the index value + return (x.case_id, x.variables[0], int(time_slice)) + else: + season = x.seasons[0] + # For seasons, use SEASONS order + return (x.case_id, x.variables[0], SEASONS.index(season)) + + parameters.sort(key=_get_sort_key) for param in parameters: ref_name = getattr(param, "ref_name", "") for var in param.variables: - for season in param.seasons: + # Handle either seasons or time_slices + time_periods = ( + param.time_slices + if (hasattr(param, "time_slices") and len(param.time_slices) > 0) + else param.seasons + ) + + for season in time_periods: for region in param.regions: try: viewer.set_group(param.case_id) diff --git a/examples/ex8-native-grid-visualization/README.md b/examples/ex8-native-grid-visualization/README.md new file mode 100644 index 000000000..25ba38d26 --- /dev/null +++ b/examples/ex8-native-grid-visualization/README.md @@ -0,0 +1,121 @@ +# Example 8: Native Grid Visualization + +This example demonstrates the **native grid visualization** feature introduced in **E3SM Diags v3.1.0**. + +## What This Example Does + +- Visualizes model data on native grids (e.g., cubed-sphere, unstructured grids) +- Uses UXarray for grid-aware operations +- Compares two models without regridding to a regular lat-lon grid +- Preserves native grid features and characteristics + +## Key Features + +The native grid visualization capability: + +- Supports various native grid formats (cubed-sphere, unstructured, etc.) +- Eliminates artifacts introduced by regridding +- Enables comparison of models with different native grids +- Particularly useful for high-resolution models + +## Key Parameters + +- `LatLonNativeParameter` - Required parameter class for native grid visualization +- `test_grid_file` - Path to test model's grid file (UGRID format) +- `ref_grid_file` - Path to reference model's grid file (optional for model-only runs) +- `time_slices` - Use snapshot analysis instead of climatology (e.g., ["0"]) +- `antialiased` - Whether to apply antialiasing to the plot + +## Data Requirements + +This example uses test data located at LCRC: + +- Data path: `/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid` +- Grid files: `/lcrc/group/e3sm/diagnostics/grids/` + +For your own data, ensure you have: + +1. Model output files on native grid +2. Corresponding grid files in UGRID format + +## Running This Example + +### Using the Python Script + +```bash +# Run with default settings (automatically uses your username for output directory) +python ex8.py + +# Run with custom configuration file +python ex8.py -d diags.cfg +``` + +### Using Command-Line Interface + +```bash +e3sm_diags lat_lon_native \ + --no_viewer \ + --reference_data_path '/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid' \ + --test_data_path '/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid' \ + --results_dir '/lcrc/group/e3sm/public_html/diagnostic_output/$USER/e3sm_diags_examples/ex8_native_grid' \ + --case_id 'model_vs_model' \ + --run_type 'model_vs_model' \ + --sets 'lat_lon_native' \ + --variables 'TGCLDLWP' \ + --time_slices 0 \ + --main_title 'TGCLDLWP 0 global' \ + --contour_levels '10' '25' '50' '75' '100' '125' '150' '175' '200' '225' '250' \ + --short_test_name 'v3.LR.amip_0101' \ + --ref_file 'v3.LR.amip_0101.eam.h0.1989-12.nc' \ + --diff_colormap 'RdBu' \ + --diff_levels '-35' '-30' '-25' '-20' '-15' '-10' '-5' '5' '10' '15' '20' '25' '30' '35' \ + --test_grid_file '/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc' \ + --ref_grid_file '/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc' \ + --test_file 'v3.LR.amip_0101.eam.h0.1989-12.nc' +``` + +**Note:** Use `--no_viewer` for command-line usage to avoid directory creation issues. For HTML viewer output, use the Python script approach instead. + +## Configuration File + +The `diags.cfg` file allows you to customize: + +- Variables to plot (e.g., TGCLDLWP) +- Regions of interest +- Colormap settings +- Contour levels for test, reference, and difference plots + +## Expected Output + +The diagnostic will generate: + +- Native grid visualizations for specified variables +- Test model plot +- Reference model plot +- Difference plot (Test - Reference) +- HTML viewer for browsing results + +Results will be saved in: `/lcrc/group/e3sm/public_html/diagnostic_output/$USER/e3sm_diags_examples/ex8_native_grid/viewer/` + +## Notes + +- Native grid visualization requires the UXarray library, which is included as a dependency of E3SM diagnostics and the E3SM Unified environment. +- Grid files must be in UGRID format +- This example uses `time_slices` for snapshot analysis; you can also use `seasons` for climatology +- For model-only runs (no reference data), set `model_only = True` and omit ref_grid_file + +## Differences from Regular lat_lon Set + +Unlike the standard `lat_lon` set which regrids data to a regular lat-lon grid: + +- `lat_lon_native` preserves the original grid structure +- No interpolation artifacts +- Better representation of native grid features +- Requires grid files in UGRID format + +## More Information + +For more details, see: + +- [E3SM Diags Documentation](https://e3sm-project.github.io/e3sm_diags) +- [UXarray Documentation](https://uxarray.readthedocs.io/) diff --git a/examples/ex8-native-grid-visualization/diags.cfg b/examples/ex8-native-grid-visualization/diags.cfg new file mode 100644 index 000000000..28593574a --- /dev/null +++ b/examples/ex8-native-grid-visualization/diags.cfg @@ -0,0 +1,8 @@ +[#] +sets = ["lat_lon_native"] +case_id = "model_vs_model" +variables = ["TGCLDLWP"] +regions = ["global"] +diff_colormap = "RdBu" +contour_levels = [10, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250] +diff_levels = [-35, -30, -25, -20, -15, -10, -5, 5, 10, 15, 20, 25, 30, 35] diff --git a/examples/ex8-native-grid-visualization/ex8.py b/examples/ex8-native-grid-visualization/ex8.py new file mode 100644 index 000000000..ea9153fa0 --- /dev/null +++ b/examples/ex8-native-grid-visualization/ex8.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +Example 8: Native Grid Visualization + +This example demonstrates how to visualize model data on its native grid +(e.g., cubed-sphere, unstructured grids) using UXarray, without regridding +to a regular lat-lon grid. + +This preserves native grid features and is particularly useful for: +- High-resolution models with complex grid structures +- Comparing models with different native grids +- Analyzing grid-dependent features + +This feature was introduced in E3SM Diags v3.1.0. +""" + +import os + +from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter +from e3sm_diags.run import runner + +# Auto-detect username +username = os.environ.get('USER', 'unknown_user') + +# Create parameter object +param = LatLonNativeParameter() + +# Location of the data +param.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +param.test_file = "v3.LR.amip_0101.eam.h0.1989-12.nc" + +param.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/native_grid" +param.ref_file = "v3.LR.amip_0101.eam.h0.1989-12.nc" + +# Short names for display +param.short_test_name = "v3.LR.amip_0101" +param.short_ref_name = "v3.HR.test4" + +# Native grid files +# These specify the grid structure for native visualization +param.test_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc" +param.ref_grid_file = "/lcrc/group/e3sm/diagnostics/grids/ne30pg2.nc" + +# Time selection: use snapshot instead of climatology +# Use time_slices to analyze a specific time index +param.time_slices = ["0"] # First time step + +# Comparison settings +param.case_id = "model_vs_model" +param.run_type = "model_vs_model" + +# Antialiasing setting +param.antialiased = False + +# Name of the folder where the results are stored. +prefix = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/e3sm_diags_examples" +param.results_dir = os.path.join(prefix, "ex8_native_grid") + +# Below are more optional arguments. + +# For running with multiprocessing. +# param.multiprocessing = True +# param.num_workers = 32 + +# Run the diagnostic +runner.sets_to_run = ["lat_lon_native"] +runner.run_diags([param]) diff --git a/examples/ex9-snapshot-analysis/README.md b/examples/ex9-snapshot-analysis/README.md new file mode 100644 index 000000000..54045e701 --- /dev/null +++ b/examples/ex9-snapshot-analysis/README.md @@ -0,0 +1,164 @@ +# Example 9: Snapshot Analysis for Core Sets + +This example demonstrates the **snapshot analysis** feature introduced in **E3SM Diags v3.1.0**. + +## What This Example Does + +- Analyzes individual time steps instead of seasonal climatological means +- Uses index-based time selection to examine specific time points +- Demonstrates time_slices parameter on multiple core diagnostic sets +- Compares model states at specific indices without temporal averaging + +## Key Features + +The snapshot analysis capability: +- Enables event-based or process-oriented diagnostics +- Analyzes specific time points without climatological averaging +- Supports multiple time indices analyzed separately +- Works across multiple core diagnostic sets + +## Key Parameters + +- `time_slices` - List of time indices to analyze (e.g., ["0", "1", "2"]) + - Time slices are zero-based indices into the time dimension + - ["0"] = first time step + - ["5"] = 6th time step + - ["0", "1", "2"] = first 3 time steps (each analyzed separately) +- **IMPORTANT**: `time_slices` and `seasons` are mutually exclusive + - When using `time_slices`, do NOT set `seasons` + +## Supported Diagnostic Sets + +The following core diagnostic sets support snapshot analysis: +- `lat_lon` - Latitude-Longitude contour maps +- `lat_lon_native` - Native grid visualization +- `polar` - Polar contour maps +- `zonal_mean_2d` - Pressure-Latitude zonal mean contour plots +- `meridional_mean_2d` - Pressure-Longitude meridional mean contour plots +- `zonal_mean_2d_stratosphere` - Stratospheric zonal mean plots + +## Data Requirements + +This example uses test data located at LCRC: +- Data path: `/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/` + +For your own data, ensure: +1. Model output files contain time dimension +2. Files have sufficient time steps for requested indices +3. Data files are accessible from the diagnostic runs + +## Running This Example + +### Using the Python Script + +```bash +# Edit ex9.py to set your output directory +# Update the `prefix` variable to point to your web directory + +# Run with default settings +python ex9.py + +# Run with custom configuration file +python ex9.py -d diags.cfg +``` + +### Using Command-Line Interface (example) + +```bash +e3sm_diags zonal_mean_2d \ + --no_viewer \ + --reference_data_path '/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr' \ + --test_data_path '/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr' \ + --results_dir '/lcrc/group/e3sm/public_html/diagnostic_output/$USER/e3sm_diags_examples/ex9_snapshot_analysis' \ + --case_id 'model_vs_model' \ + --run_type 'model_vs_model' \ + --sets 'zonal_mean_2d' \ + --variables 'T' \ + --time_slices 0 \ + --multiprocessing \ + --main_title 'T 0 global' \ + --contour_levels '180' '185' '190' '200' '210' '220' '230' '240' '250' '260' '270' '280' '290' '295' '300' \ + --short_test_name 'v2 test' \ + --ref_file 'T_005101_006012.nc' \ + --diff_levels '-3.0' '-2.5' '-2' '-1.5' '-1' '-0.5' '-0.25' '0.25' '0.5' '1' '1.5' '2' '2.5' '3.0' \ + --test_file 'T_005101_006012.nc' +``` + +**Note:** Use `--no_viewer` for command-line usage to avoid directory creation issues. For HTML viewer output, use the Python script approach instead. + +**Important**: Do not use both `--time_slices` and `--seasons` in the same command! + +## Configuration File + +The `diags.cfg` file allows you to customize settings for each diagnostic set: +- Variables to plot (e.g., T for temperature) +- Pressure levels for 3D variables (e.g., 850.0 mb) +- Regions of interest (e.g., polar_S, polar_N) +- Colormap settings +- Contour levels for test, reference, and difference plots +- Regridding method (for lat_lon set) + +## Expected Output + +The diagnostic will generate: +- Plots for each time slice specified +- Test model plot at each time index +- Reference model plot at each time index +- Difference plots (Test - Reference) at each time index +- HTML viewer with columns for each time slice + +Results will be saved in: `/ex9_snapshot_analysis/viewer/` + +In the viewer: +- Rows represent different variables/regions/levels +- Columns represent different time slices (0, 1, 2, etc.) +- Click on any cell to see detailed plots + +## Notes + +- Time slices are **zero-based indices** (0 = first time step, 1 = second, etc.) +- Each time slice is analyzed **separately** (not averaged together) +- The viewer displays time slices as column headers instead of seasons +- Time slices are sorted **numerically** (0, 1, 2, ..., not alphabetically) +- Make sure your data files have enough time steps for the requested indices + +## Differences from Seasonal Climatology + +Unlike seasonal climatology analysis which uses `seasons = ["ANN", "DJF", "JJA", "SON"]`: +- Snapshot analysis uses `time_slices = ["0", "1", "2", ...]` +- No temporal averaging - analyzes exact time points +- Useful for event-based studies and temporal evolution +- Can analyze any arbitrary time index in your dataset + +## Use Cases + +Snapshot analysis is particularly useful for: +1. **Event Studies**: Analyzing specific weather events or phenomena +2. **Model Spin-up**: Examining early time steps in model initialization +3. **Temporal Evolution**: Tracking how fields change over successive time steps +4. **Intercomparison**: Comparing models at synchronized time points +5. **Debugging**: Investigating specific time steps with unusual behavior + +## Combining with Native Grid Visualization + +You can combine snapshot analysis with native grid visualization: + +```python +from e3sm_diags.parameter.lat_lon_native_parameter import LatLonNativeParameter + +param = LatLonNativeParameter() +param.time_slices = ["0", "5", "10"] # Snapshot analysis +param.test_grid_file = "/path/to/grid.nc" # Native grid +# ... other parameters ... + +runner.sets_to_run = ["lat_lon_native"] +runner.run_diags([param]) +``` + +This combines both v3.1.0 features for snapshot analysis on native grids! + +## More Information + +For more details, see: +- [E3SM Diags Documentation](https://e3sm-project.github.io/e3sm_diags) +- [E3SM Diags README - v3.1.0 Features](https://github.com/E3SM-Project/e3sm_diags#new-features-in-v310) diff --git a/examples/ex9-snapshot-analysis/diags.cfg b/examples/ex9-snapshot-analysis/diags.cfg new file mode 100644 index 000000000..d899960e4 --- /dev/null +++ b/examples/ex9-snapshot-analysis/diags.cfg @@ -0,0 +1,44 @@ +[#] +sets = ["lat_lon"] +case_id = "model_vs_model" +variables = ["T"] +plevs = [850.0] +contour_levels = [240, 245, 250, 255, 260, 265, 270, 275, 280, 285, 290, 295] +diff_levels = [-5, -4, -3, -2, -1, -0.5, -0.25, 0.25, 0.5, 1, 2, 3, 4, 5] +regrid_method = "bilinear" + +[#] +sets = ["polar"] +case_id = "model_vs_model" +variables = ["T"] +regions = ["polar_S", "polar_N"] +plevs = [850.0] +contour_levels = [230, 240, 250, 260, 270, 280, 290, 300, 310] +diff_levels = [-15, -10, -7.5, -5, -2.5, -1, 1, 2.5, 5, 7.5, 10, 15] + +[#] +sets = ["zonal_mean_2d"] +case_id = "model_vs_model" +variables = ["T"] +contour_levels = [180, 185, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 295, 300] +diff_levels = [-3.0, -2.5, -2, -1.5, -1, -0.5, -0.25, 0.25, 0.5, 1, 1.5, 2, 2.5, 3.0] + +[#] +sets = ["zonal_mean_2d_stratosphere"] +case_id = "model_vs_model" +variables = ["T"] +contour_levels = [180, 185, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 295, 300] +diff_levels = [-8, -6, -4, -2, -1, -0.5, 0.5, 1, 2, 4, 6, 8] + +[#] +sets = ["zonal_mean_xy"] +case_id = "model_vs_model" +variables = ["T"] +plevs = [850.0] + +[#] +sets = ["meridional_mean_2d"] +case_id = "model_vs_model" +variables = ["T"] +contour_levels = [180, 185, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 295, 300] +diff_levels = [-7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6, 7] diff --git a/examples/ex9-snapshot-analysis/ex9.py b/examples/ex9-snapshot-analysis/ex9.py new file mode 100644 index 000000000..a5c91afdc --- /dev/null +++ b/examples/ex9-snapshot-analysis/ex9.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" +Example 9: Snapshot Analysis for Core Sets + +This example demonstrates time slice analysis on core diagnostic sets. +Instead of computing climatological seasonal means, this analyzes individual +time steps from model output using index-based time selection. + +This is useful for: +- Analyzing specific events or time periods +- Comparing model states at specific time points +- Understanding temporal evolution without time averaging +- Event-based or process-oriented diagnostics + +This feature was introduced in E3SM Diags v3.1.0. + +Supported diagnostic sets: +- lat_lon +- lat_lon_native +- polar +- zonal_mean_2d +- meridional_mean_2d +- zonal_mean_2d_stratosphere +""" + +import os + +from e3sm_diags.parameter.core_parameter import CoreParameter +from e3sm_diags.run import runner + +# Auto-detect username +username = os.environ.get('USER', 'unknown_user') + +# Create parameter object +param = CoreParameter() + +# Location of the data +param.test_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr" +param.test_file = "T_005101_006012.nc" + +param.reference_data_path = "/lcrc/group/e3sm/public_html/e3sm_diags_test_data/postprocessed_e3sm_v2_data_for_e3sm_diags/20210528.v2rc3e.piControl.ne30pg2_EC30to60E2r2.chrysalis/time-series/rgr" +param.ref_file = "T_005101_006012.nc" + +# Short names for display +param.short_test_name = "v2 test" +param.short_ref_name = "v2 test" + +# Key difference: Use time_slices instead of seasons +# Time slices are zero-based indices into the time dimension +# Examples: +# ["0"] = first time step +# ["5"] = 6th time step +# ["0", "1", "2"] = first 3 time steps (analyzed separately) +param.time_slices = ["0", "1"] + +# IMPORTANT: time_slices and seasons are mutually exclusive +# When using time_slices, do NOT set param.seasons + +# Comparison settings +param.case_id = "model_vs_model" +param.run_type = "model_vs_model" + +# Name of the folder where the results are stored. +# Change `prefix` to use your directory. +prefix = f"/lcrc/group/e3sm/public_html/diagnostic_output/{username}/e3sm_diags_examples" +param.results_dir = os.path.join(prefix, "ex9_snapshot_analysis") + +# Below are more optional arguments. + +# For running with multiprocessing. +# param.multiprocessing = True +# param.num_workers = 32 + +# Run the diagnostics on multiple sets +runner.sets_to_run = [ + "lat_lon", + "zonal_mean_xy", + "zonal_mean_2d", + "zonal_mean_2d_stratosphere", + "polar", + "meridional_mean_2d", +] + +runner.run_diags([param]) diff --git a/examples/run_all_sets_E3SM_machines.py b/examples/run_all_sets_E3SM_machines.py index 6b0f74b1c..6054e85e2 100644 --- a/examples/run_all_sets_E3SM_machines.py +++ b/examples/run_all_sets_E3SM_machines.py @@ -66,7 +66,7 @@ def run_all_sets(): "JJA", ] # Default setting: seasons = ["ANN", "DJF", "MAM", "JJA", "SON"] - param.results_dir = f"{machine_paths['html_path']}/v2_9_0_all_sets" + param.results_dir = f"{machine_paths['html_path']}/v3_1_0_all_sets" param.multiprocessing = True param.num_workers = 24 diff --git a/tests/e3sm_diags/driver/test_lat_lon_driver.py b/tests/e3sm_diags/driver/test_lat_lon_driver.py index fca03cfe9..619bcd253 100644 --- a/tests/e3sm_diags/driver/test_lat_lon_driver.py +++ b/tests/e3sm_diags/driver/test_lat_lon_driver.py @@ -3,7 +3,7 @@ import pytest import xarray as xr -from e3sm_diags.driver.lat_lon_driver import _get_ref_climo_dataset +from e3sm_diags.driver.lat_lon_driver import _get_ref_dataset from e3sm_diags.driver.utils.dataset_xr import Dataset from tests.e3sm_diags.driver.utils.test_dataset_xr import ( _create_parameter_object, @@ -155,7 +155,7 @@ def test_raises_error_if_dataset_data_type_is_not_ref(self): ds = Dataset(parameter, data_type="test") with pytest.raises(RuntimeError): - _get_ref_climo_dataset(ds, "ts", "ANN") + _get_ref_dataset(ds, "ts", "ANN", is_time_slice=False) def test_returns_reference_climo_dataset_from_file(self): parameter = _create_parameter_object( @@ -166,7 +166,7 @@ def test_returns_reference_climo_dataset_from_file(self): self.ds_climo.to_netcdf(f"{self.data_path}/{parameter.ref_file}") ds = Dataset(parameter, data_type="ref") - result = _get_ref_climo_dataset(ds, "ts", "ANN") + result = _get_ref_dataset(ds, "ts", "ANN", is_time_slice=False) expected = self.ds_climo.squeeze(dim="time").drop_vars("time") xr.testing.assert_identical(result, expected) @@ -178,6 +178,6 @@ def test_returns_None_if_climo_dataset_not_found(self): parameter.ref_file = "ref_file.nc" ds = Dataset(parameter, data_type="ref") - result = _get_ref_climo_dataset(ds, "ts", "ANN") + result = _get_ref_dataset(ds, "ts", "ANN", is_time_slice=False) assert result is None diff --git a/tests/e3sm_diags/driver/utils/test_dataset_xr.py b/tests/e3sm_diags/driver/utils/test_dataset_xr.py index 6637a11a4..f9d6dc3b6 100644 --- a/tests/e3sm_diags/driver/utils/test_dataset_xr.py +++ b/tests/e3sm_diags/driver/utils/test_dataset_xr.py @@ -318,6 +318,193 @@ def test_property_is_timeseries_returns_false_and_is_climo_returns_true_for_ref( assert ds.is_climo +class TestGetTimeSlicedDataset: + @pytest.fixture(autouse=True) + def setup(self, tmp_path): + self.data_path = tmp_path / "input_data" + self.data_path.mkdir() + + # Set up time series dataset and save to a temp file. + self.ts_path = f"{self.data_path}/ts_200001_200112.nc" + self.ds_ts = xr.Dataset( + coords={ + **spatial_coords, + "time": xr.DataArray( + dims="time", + data=np.array( + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + dtype=object, + ), + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + data_vars={ + **spatial_bounds, + "time_bnds": xr.DataArray( + name="time_bnds", + data=np.array( + [ + [ + cftime.DatetimeGregorian( + 2000, 1, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 2, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + [ + cftime.DatetimeGregorian( + 2000, 3, 1, 0, 0, 0, 0, has_year_zero=False + ), + cftime.DatetimeGregorian( + 2000, 4, 1, 0, 0, 0, 0, has_year_zero=False + ), + ], + ], + dtype=object, + ), + dims=["time", "bnds"], + ), + "ts": xr.DataArray( + xr.DataArray( + data=np.array( + [ + [[1.0, 1.0], [1.0, 1.0]], + [[2.0, 2.0], [2.0, 2.0]], + [[3.0, 3.0], [3.0, 3.0]], + ] + ), + dims=["time", "lat", "lon"], + ) + ), + }, + ) + self.ds_ts.to_netcdf(self.ts_path) + + def test_raises_error_if_var_arg_is_not_valid(self): + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + ds = Dataset(parameter, data_type="ref") + + with pytest.raises(ValueError): + ds.get_time_sliced_dataset(var=1, time_slice="0") # type: ignore + + with pytest.raises(ValueError): + ds.get_time_sliced_dataset(var="", time_slice="0") + + def test_raises_error_if_file_not_found(self): + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + parameter.ref_file = "non_existent_file.nc" + ds = Dataset(parameter, data_type="ref") + + with pytest.raises(RuntimeError, match="File not found:"): + ds.get_time_sliced_dataset(var="ts", time_slice="0") + + def test_raises_error_if_filepath_is_none(self): + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + # Set ref_file to empty string to simulate None filepath. + parameter.ref_file = "" + ds = Dataset(parameter, data_type="ref") + + with pytest.raises( + RuntimeError, match="Unable to get file path for ref dataset" + ): + ds.get_time_sliced_dataset(var="ts", time_slice="0") + + def test_raises_error_if_xarray_fails_to_open_file(self, monkeypatch): + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ts_200001_200112.nc" + ds = Dataset(parameter, data_type="ref") + + def mock_open_dataset(*args, **kwargs): + raise OSError("Simulated xarray open_dataset failure.") + + monkeypatch.setattr(xr, "open_dataset", mock_open_dataset) + + with pytest.raises( + RuntimeError, + match="Failed to open dataset", + ): + ds.get_time_sliced_dataset(var="ts", time_slice="0") + + def test_returns_time_sliced_dataset(self, caplog): + # Silence logger warning to not pollute test suite. + caplog.set_level(logging.CRITICAL) + + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ts_200001_200112.nc" + ds = Dataset(parameter, data_type="ref") + + result = ds.get_time_sliced_dataset(var="ts", time_slice="1") + expected = self.ds_ts.isel(time=1) + + xr.testing.assert_identical(result, expected) + + def test_raises_error_if_time_slice_is_invalid(self, caplog): + # Silence logger warning to not pollute test suite. + caplog.set_level(logging.CRITICAL) + + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ts_200001_200112.nc" + ds = Dataset(parameter, data_type="ref") + + with pytest.raises( + IndexError, + match="Time slice index 5 is out of bounds for time dimension of size 3.", + ): + ds.get_time_sliced_dataset(var="ts", time_slice="5") + + def test_returns_original_dataset_if_no_time_dim_is_found(self): + parameter = _create_parameter_object( + "ref", "time_series", self.data_path, "2000", "2001" + ) + parameter.ref_file = "ts_200001_200112.nc" + ds = Dataset(parameter, data_type="ref") + + # Remove time dimension to simulate no time dim found. + ds_ts_no_time = self.ds_ts.drop_vars("time").drop_dims("time") + ds_ts_no_time.to_netcdf(f"{self.data_path}/{parameter.ref_file}") + + result = ds.get_time_sliced_dataset(var="ts", time_slice="0") + expected = ds_ts_no_time + + xr.testing.assert_identical(result, expected) + + class TestGetClimoDataset: @pytest.fixture(autouse=True) def setup(self, tmp_path): diff --git a/tests/e3sm_diags/driver/utils/test_io.py b/tests/e3sm_diags/driver/utils/test_io.py index 1fc281db1..f13140c35 100644 --- a/tests/e3sm_diags/driver/utils/test_io.py +++ b/tests/e3sm_diags/driver/utils/test_io.py @@ -2,18 +2,111 @@ import os from copy import deepcopy from pathlib import Path +from unittest.mock import MagicMock import pytest import xarray as xr from e3sm_diags.driver.utils.io import ( + DatasetResult, _get_output_dir, + _get_xarray_datasets, _write_vars_to_netcdf, _write_vars_to_single_netcdf, ) from e3sm_diags.parameter.core_parameter import CoreParameter +class TestGetXarrayDatasets: + @pytest.fixture(autouse=True) + def setup(self): + self.test_ds = MagicMock() + self.ref_ds = MagicMock() + self.var_key = "ts" + self.time_selection = "DJF" + + def test_fetches_datasets_for_time_slices(self): + self.test_ds.get_time_sliced_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[1, 2, 3])} + ) + self.ref_ds.get_time_sliced_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[4, 5, 6])} + ) + self.test_ds._get_land_sea_mask.return_value = xr.Dataset( + data_vars={"mask": xr.DataArray(name="mask", data=[1, 0, 1])} + ) + + result = _get_xarray_datasets( + self.test_ds, + self.ref_ds, + self.var_key, + "time_slices", + self.time_selection, + get_land_sea_mask=True, + ) + + assert isinstance(result, DatasetResult) + xr.testing.assert_identical( + result.ds_test, self.test_ds.get_time_sliced_dataset() + ) + xr.testing.assert_identical( + result.ds_ref, self.ref_ds.get_time_sliced_dataset() + ) + xr.testing.assert_identical( + result.ds_land_sea_mask, self.test_ds._get_land_sea_mask("ANN") + ) + + def test_fetches_datasets_for_seasons(self): + self.test_ds.get_climo_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[7, 8, 9])} + ) + self.ref_ds.get_climo_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[10, 11, 12])} + ) + self.test_ds._get_land_sea_mask.return_value = xr.Dataset( + data_vars={"mask": xr.DataArray(name="mask", data=[0, 1, 0])} + ) + + result = _get_xarray_datasets( + self.test_ds, + self.ref_ds, + self.var_key, + "seasons", + self.time_selection, + get_land_sea_mask=True, + ) + + assert isinstance(result, DatasetResult) + xr.testing.assert_identical(result.ds_test, self.test_ds.get_climo_dataset()) + xr.testing.assert_identical(result.ds_ref, self.ref_ds.get_climo_dataset()) + xr.testing.assert_identical( + result.ds_land_sea_mask, + self.test_ds._get_land_sea_mask(self.time_selection), + ) + + def test_does_not_fetch_land_sea_mask_when_disabled(self): + self.test_ds.get_climo_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[13, 14, 15])} + ) + self.ref_ds.get_climo_dataset.return_value = xr.Dataset( + data_vars={"ts": xr.DataArray(name="ts", data=[16, 17, 18])} + ) + + result = _get_xarray_datasets( + self.test_ds, + self.ref_ds, + self.var_key, + "seasons", + self.time_selection, + get_land_sea_mask=False, + ) + + assert isinstance(result, DatasetResult) + xr.testing.assert_identical(result.ds_test, self.test_ds.get_climo_dataset()) + xr.testing.assert_identical(result.ds_ref, self.ref_ds.get_climo_dataset()) + assert result.ds_land_sea_mask is None + + class TestWriteVarsToNetcdf: @pytest.fixture(autouse=True) def setup(self, tmp_path: Path): diff --git a/tests/e3sm_diags/test_parameters.py b/tests/e3sm_diags/test_parameters.py index 6261abf11..8918b2546 100644 --- a/tests/e3sm_diags/test_parameters.py +++ b/tests/e3sm_diags/test_parameters.py @@ -82,6 +82,41 @@ def test_check_values_raises_error_if_test_timeseries_input_and_no_test_start_an with pytest.raises(RuntimeError): param.check_values() + @pytest.mark.parametrize( + "time_slices, expected_error", + [ + ( + ["not_an_integer"], + r"Invalid time_slice format: 'not_an_integer'. Expected a non-negative integer index. Examples: '0', '5', '42'", + ), + ( + ["-2000"], + r"Invalid time_slice format: '-2000'. Expected a non-negative integer index. Examples: '0', '5', '42'", + ), + ( + ["0010"], + r"Invalid time_slice format: '0010'. Expected a non-negative integer index. Examples: '0', '5', '42'", + ), + ( + ["1"], + None, # No error expected for valid input + ), + ], + ) + def test_check_values_time_slices(self, time_slices, expected_error): + param = CoreParameter() + param.reference_data_path = "path" + param.test_data_path = "path" + param.results_dir = "path" + param.time_slices = time_slices + + if expected_error: + with pytest.raises(ValueError, match=expected_error): + param.check_values() + else: + # Should not raise any error for valid input + param.check_values() + @pytest.mark.xfail def test_returns_parameter_with_results(self): # FIXME: This test will while we refactor sets and utilities. It should @@ -145,55 +180,102 @@ def test_logs_exception_if_driver_run_diag_function_fails(self, caplog): class TestCoreParameterAdditionalMethods: - def test_add_time_series_file_path_attr_valid(self): + def test_add_time_series_filepath_attr_valid(self): param = CoreParameter() ds = xr.Dataset(attrs={"file_path": "/path/to/test/file.nc"}) - param._add_time_series_file_path_attr("test", ds) + param._add_time_series_filepath_attr("test", ds) assert param.test_data_file_path == "/path/to/test/file.nc" - def test_add_time_series_file_path_attr_invalid_data_type(self): + def test_add_time_series_filepath_attr_invalid_data_type(self): param = CoreParameter() ds = xr.Dataset(attrs={"file_path": "/path/to/test/file.nc"}) with pytest.raises( ValueError, match="data_type must be either 'test' or 'ref'." ): - param._add_time_series_file_path_attr("invalid", ds) # type: ignore + param._add_time_series_filepath_attr("invalid", ds) # type: ignore - def test_add_time_series_file_path_attr_missing_file_path(self): + def test_add_time_series_filepath_attr_missing_file_path(self): param = CoreParameter() ds = xr.Dataset() - param._add_time_series_file_path_attr("test", ds) + param._add_time_series_filepath_attr("test", ds) assert param.test_data_file_path == "Unknown" - def test_add_climatology_file_path_attr_valid(self): + def test_add_filepath_attr_valid(self): param = CoreParameter() filepath = "/path/to/climatology/file.nc" - param._add_climatology_file_path_attr("ref", filepath) + param._add_filepath_attr("ref", filepath) assert param.ref_data_file_path == os.path.abspath(filepath) - def test_add_climatology_file_path_attr_invalid_data_type(self): + def test_add_filepath_attr_invalid_data_type(self): param = CoreParameter() filepath = "/path/to/climatology/file.nc" with pytest.raises( ValueError, match="data_type must be either 'test' or 'ref'." ): - param._add_climatology_file_path_attr("invalid", filepath) # type: ignore + param._add_filepath_attr("invalid", filepath) # type: ignore - def test_add_climatology_file_path_attr_missing_filepath(self): + def test_add_filepath_attr_missing_filepath(self): param = CoreParameter() with pytest.raises( - ValueError, match="Filepath must be provided for climatology data." + ValueError, + match="Filepath must be provided for the climatology or time-slice data.", ): - param._add_climatology_file_path_attr("test", None) + param._add_filepath_attr("test", None) + + def test_get_time_selection_to_use_raises_error_if_neither_seasons_nor_time_slices_specified( + self, + ): + param = CoreParameter() + param.seasons = [] + param.time_slices = [] + + with pytest.raises( + RuntimeError, match="Must specify either 'seasons' or 'time_slices'." + ): + param._get_time_selection_to_use() + + def test_get_time_selection_to_use_returns_time_slices_if_specified(self): + param = CoreParameter() + param.seasons = ["ANN", "DJF"] + param.time_slices = ["0", "5"] + + selection_type, values = param._get_time_selection_to_use() + + assert selection_type == "time_slices" + assert values == ["0", "5"] + + def test_get_time_selection_to_use_returns_seasons_if_time_slices_not_specified( + self, + ): + param = CoreParameter() + param.seasons = ["ANN", "DJF"] + param.time_slices = [] + + selection_type, values = param._get_time_selection_to_use() + + assert selection_type == "seasons" + assert values == ["ANN", "DJF"] + + def test_get_time_selection_to_use_does_not_raise_error_if_require_one_is_false( + self, + ): + param = CoreParameter() + param.seasons = [] + param.time_slices = [] + + selection_type, values = param._get_time_selection_to_use(require_one=False) + + assert selection_type == "seasons" + assert values == [] def test_ac_zonal_mean_parameter(): diff --git a/tests/e3sm_diags/test_parsers.py b/tests/e3sm_diags/test_parsers.py index 9e6f87a32..5c43305a8 100644 --- a/tests/e3sm_diags/test_parsers.py +++ b/tests/e3sm_diags/test_parsers.py @@ -58,6 +58,7 @@ def setup(self): "plot_plevs", "plot_log_plevs", "seasons", + "time_slices", "regions", "regrid_tool", "regrid_method",