diff --git a/docs/img/floodforecaster_source_domain.gif b/docs/img/floodforecaster_source_domain.gif new file mode 100644 index 0000000000..eefb3198db Binary files /dev/null and b/docs/img/floodforecaster_source_domain.gif differ diff --git a/docs/img/floodforecaster_target_domain.gif b/docs/img/floodforecaster_target_domain.gif new file mode 100644 index 0000000000..ecc5cba2f4 Binary files /dev/null and b/docs/img/floodforecaster_target_domain.gif differ diff --git a/examples/weather/flood_modeling/flood_forecaster/README.md b/examples/weather/flood_modeling/flood_forecaster/README.md new file mode 100644 index 0000000000..8a092fcc18 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/README.md @@ -0,0 +1,392 @@ +# FloodForecaster: A Domain-Adaptive Geometry-Informed Neural Operator Framework for Rapid Flood Forecasting + +FloodForecaster is a deep learning framework for rapid, high-resolution flood forecasting that leverages a time-dependent Geometry-Informed Neural Operator (GINO) with domain adaptation capabilities. The framework enables accurate, real-time flood predictions by learning from source domain data and adapting to target domains through adversarial training, delivering predictions of water depth and velocity across unstructured spatial meshes. + +## Problem Overview + +Flooding is one of the most destructive and widespread natural hazards, causing significant socioeconomic damage worldwide. Rapid urbanization and climate change are intensifying flood risks, making vulnerable population centers even more susceptible. To effectively mitigate these escalating risks, rapid and high-resolution flood forecasts are essential for enabling timely public warnings, efficient emergency response, and strategic resource deployment. + +Traditionally, flood forecasting relies on physically based numerical models that solve the shallow water equations (SWEs). While accurate, these models demand immense computational resources, especially when simulating large geographical areas at fine-scale resolutions (e.g., 1–10 m grid cells) needed to capture complex topographies and flow paths. This computational burden is prohibitive for many real-world applications, such as generating rapid inundation forecasts from meteorological predictions or running large ensembles needed to quantify forecast uncertainty within tight operational deadlines. + +FloodForecaster addresses these challenges by offering a computationally efficient surrogate model that: +- **Processes complex geometries**: Leverages unstructured spatial meshes to handle irregular terrain +- **Enables domain transfer**: Incorporates domain adaptation to transfer knowledge from data-rich source domains to data-scarce target domains +- **Achieves superior performance**: Demonstrates superior accuracy and stability over state-of-the-art GNN baselines +- **Requires minimal data**: Successfully adapts with as few as 10 training simulations from a new domain, reducing prediction error by approximately 75% compared to standard fine-tuning + +## Model Overview and Architecture + +### Core Architecture + +FloodForecaster uses a three-stage training pipeline to achieve domain-adaptive flood prediction: + +1. **Pretraining**: Train a GINO model on source domain data to learn fundamental flood dynamics +2. **Domain Adaptation**: Fine-tune the model using adversarial training with a domain classifier to learn domain-invariant features +3. **Rollout Evaluation**: Perform autoregressive multi-step predictions and compute comprehensive evaluation metrics + +### Key Components + +#### GINO (Geometry-Informed Neural Operator) + +The core predictive model synergizes: +- **Graph Neural Operators (GNO)**: Process irregular terrain and extract geometric features from unstructured meshes +- **Fourier Neural Operators (FNO)**: Efficiently capture global flood dynamics through spectral processing + +GINO processes: +- **Static features**: Elevation, slope, curvature, roughness (Manning's n), cell area +- **Dynamic features**: Water depth, velocity components (Vx, Vy) over time +- **Boundary conditions**: Inflow hydrographs at upstream boundaries + +#### Domain Adaptation with Gradient Reversal Layer + +The framework integrates a domain adaptation technique using a gradient reversal layer (GRL), which encourages the model to learn domain-invariant physical features. This approach: +- **Prevents catastrophic forgetting**: Unlike standard fine-tuning, preserves the model's original expertise while learning hydraulics of new river segments +- **Enables data-efficient adaptation**: Achieves strong performance with minimal target domain data +- **Uses adversarial training**: A CNN-based domain classifier distinguishes between source and target domains, with gradients reversed during backpropagation to encourage domain-invariant representations + +### Key Features + +- **Autoregressive forecasting**: Multi-step predictions using a sliding window of historical states +- **Physics-informed evaluation**: Metrics including volume conservation, arrival time, and inundation duration +- **Comprehensive metrics**: RMSE, CSI (Critical Success Index), MAE for temporal characteristics, and FHCA (Flood Hazard Classification Accuracy) + +## Data Generation + +Data generation utilities for creating synthetic hydrographs and automating HEC-RAS simulations are available in a separate repository: + +**Data Generation Repository**: [https://github.com/MehdiTaghizadehUVa/FloodForecaster](https://github.com/MehdiTaghizadehUVa/FloodForecaster) + +This repository includes: +- Synthetic hydrograph generation utilities +- HEC-RAS simulation automation scripts +- Data preprocessing and formatting tools + +Please refer to the [FloodForecaster repository](https://github.com/MehdiTaghizadehUVa/FloodForecaster) for data generation documentation and usage examples. + +## Dataset + +FloodForecaster expects data organized in a structured format with the following components: + +### Spatial Mesh Files +- `M40_XY.txt`: Cell center coordinates (N × 2) +- `M40_CA.txt`: Cell area (N × 1) + +### Static Attribute Files +- `M40_CE.txt`: Elevation +- `M40_CS.txt`: Slope +- `M40_CU.txt`: Curvature +- `M40_FA.txt`: Roughness (Manning's n) +- `M40_A.txt`: Additional static features + +### Dynamic Time-Series Files (per timestep `t`) +- `M40_WD_{t}.txt`: Water depth (N × 1) +- `M40_VX_{t}.txt`: X-velocity component (N × 1) +- `M40_VY_{t}.txt`: Y-velocity component (N × 1) + +### Boundary Condition Files (per timestep `t`) +- `M40_US_InF_{t}.txt`: Upstream inflow hydrograph + +Data should be organized in folders with consistent naming patterns as specified in the configuration. The model supports both source and target domain datasets for domain adaptation training. + +**Data Generation Scripts**: Scripts for generating the dataset used in the paper are available at: https://github.com/MehdiTaghizadehUVa/FloodForecaster. This repository provides utilities for creating synthetic hydrographs and automating HEC-RAS simulations to generate training data, but does not include the pre-generated dataset itself. + +## Quick Start + +### Installation + +1. Install required dependencies: + +```bash +pip install -r requirements.txt +``` + +2. Prepare your dataset following the structure described above. Organize source and target domain data in separate directories. + +3. Configure training parameters in `conf/config.yaml`: + - Set `source_data.root` and `target_data.root` to your data paths + - You can use environment variables: `DATA_ROOT`, `TARGET_DATA_ROOT`, `ROLLOUT_DATA_ROOT` + - Or directly edit the paths in the config file + - Adjust model parameters (channels, FNO modes, hidden dimensions) + - Configure training epochs, learning rates, and batch sizes + - Set domain adaptation hyperparameters (lambda_max, classifier architecture) + +### Training + +Run the training script: + +```bash +python train.py +``` + +The training pipeline will: +- Pretrain on source domain data +- Perform domain adaptation on combined source and target data +- Save checkpoints after each stage + +Training logs, model checkpoints, and metrics will be saved in the directory specified in `config.yaml`. + +**Resuming Training:** + +To resume training from a checkpoint, set the appropriate checkpoint path in `conf/config.yaml`: + +- **Resume pretraining**: Set `checkpoint.resume_from_source` to the pretraining checkpoint directory (e.g., `"./checkpoints_flood_forecaster/pretrain"`). This will resume the source domain pretraining stage from the saved checkpoint. + +- **Resume domain adaptation**: Set `checkpoint.resume_from_adapt` to the domain adaptation checkpoint directory (e.g., `"./checkpoints_flood_forecaster/adapt"`). This will resume the domain adaptation stage from the saved checkpoint. + +- **For inference**: Set `checkpoint.resume_from_adapt` (preferred) or `checkpoint.resume_from_source` to load a trained model. The inference script will use `resume_from_adapt` if available, otherwise falls back to `resume_from_source`. + +Checkpoints are saved in subdirectories under `checkpoint.save_dir`: +- `{save_dir}/pretrain/` - Contains pretraining checkpoints +- `{save_dir}/adapt/` - Contains domain adaptation checkpoints + +### Inference + +To perform autoregressive rollout and generate evaluation visualizations: + +1. Configure your inference settings in `conf/config.yaml`: + - Set `rollout_data.root` to your test dataset path + - Configure `rollout.out_dir` for output directory + - Adjust `rollout_length` and `skip_before_timestep` as needed + +2. Run the inference script: + +```bash +python inference.py +``` + +**Note**: The inference script currently does not support multi-GPU or multi-node distributed inference. It runs on a single GPU/device. For distributed inference, you would need to modify the rollout logic to split samples across ranks. + +3. The script will output comprehensive visualizations and metrics: + - **Publication maps**: Water depth and velocity comparisons at selected time steps (12, 24, 36, 48, 60, 72 hours) + - **Maximum value maps**: Peak water depth and velocity across the entire event + - **Combined analysis plots**: Temporal characteristics (arrival time, duration), hazard metrics (max momentum flux), and classification accuracy + - **Volume conservation plots**: Total water volume over time for both predictions and ground truth + - **Conditional error analysis**: Error distributions conditioned on water depth and velocity magnitude + - **Rollout animations**: GIF animations showing temporal evolution of water depth and velocity components (3×2 grid: GT vs. Predicted) + - **Aggregated metrics**: Time-series metrics (RMSE, CSI) and scalar metrics (MAE, FHCA) aggregated across all test events + - **Event magnitude analysis**: RMSE vs. peak inflow and total volume relationships + +All outputs are saved to the configured output directory with organized subdirectories for figures and animations. + +### Example Results + +The following animations demonstrate FloodForecaster's performance on source and target domain data: + +**Source Domain Rollout:** +

+ Source domain rollout animation +

+ +**Target Domain Rollout (T1):** +

+ Target domain rollout animation +

+ +These animations show the temporal evolution of water depth and velocity components, comparing ground truth (left panels) with model predictions (right panels) across multiple time steps. The model demonstrates accurate flood forecasting capabilities in both source and target domains, with successful domain adaptation enabling transfer to new river segments. + +## Dataset Loading + +The dataset is handled via custom dataset classes defined in the `datasets/` module: + +- **`FloodDatasetWithQueryPoints`**: Loads raw flood simulation data and generates query points for GINO's latent and output representations +- **`NormalizedDataset`**: Wraps the raw dataset with normalization using `UnitGaussianNormalizer` for static, dynamic, boundary, and target fields +- **`NormalizedRolloutTestDataset`**: Specialized dataset for rollout evaluation that preserves run IDs and cell area information + +The datasets automatically: +- Load and concatenate static, dynamic, and boundary features +- Generate query point grids for GINO's coordinate-based processing +- Normalize features using statistics computed from training data +- Handle variable-length time series and multiple simulation runs + +To use the datasets, they are instantiated through the training and inference pipelines, which handle data splitting, normalization fitting, and DataLoader creation automatically. + +## Evaluation Metrics + +FloodForecaster computes comprehensive evaluation metrics: + +### Time-Series Metrics + +- **RMSE (Root Mean Square Error)**: For water depth (WD) and velocity components (Vx, Vy) at each time step +- **CSI (Critical Success Index)**: Binary classification accuracy at thresholds of 0.05m and 0.3m water depth + +### Scalar Hydrological Metrics + +- **Arrival Time MAE**: Mean absolute error in flood arrival time (hours) +- **Inundation Duration MAE**: Mean absolute error in flood duration (hours) +- **Max Momentum Flux RMSE**: RMSE of maximum h·V² (m³/s²) across the event +- **FHCA (Flood Hazard Classification Accuracy)**: Classification accuracy for flood hazard categories + +### Physics-Informed Metrics + +- **Volume Conservation**: Total water volume over time, comparing predictions to ground truth +- **Conditional Error Analysis**: Error distributions conditioned on water depth and velocity magnitude + +All metrics are aggregated across multiple test events and saved as both visualizations and numerical data (`.npz` files) for further analysis. + +## Configuration + +Key configuration sections in `conf/config.yaml`: + +### Data Paths + +```yaml +source_data: + root: "${DATA_ROOT:/path/to/source/data}" + xy_file: "M40_XY.txt" + static_files: ["M40_XY.txt", "M40_CA.txt", ...] + dynamic_patterns: + WD: "M40_WD_{}.txt" + VX: "M40_VX_{}.txt" + VY: "M40_VY_{}.txt" + boundary_patterns: + inflow: "M40_US_InF_{}.txt" +``` + +### Model Settings + +```yaml +model: + model_arch: 'gino' # Note: This codebase is hardcoded for GINO architecture + data_channels: 20 + out_channels: 3 + fno_n_modes: [16, 16] + fno_hidden_channels: 64 + gno_embed_channels: 32 +``` + +**Note**: While `model_arch` is a parameter for neuralop's `get_model` function, the FloodForecaster codebase is specifically designed for the GINO (Geometry-Informed Neural Operator) architecture. The code includes GINO-specific wrappers (`GINOWrapper`), data processors (`FloodGINODataProcessor`), and domain adaptation logic that assumes GINO's architecture. Changing `model_arch` to a different model type would require significant code modifications to support other architectures. + +### Training Settings + +```yaml +training: + n_epochs_source: 100 + n_epochs_adapt: 50 + learning_rate: 1e-4 + batch_size: 8 + da_lambda_max: 1.0 + da_class_loss_weight: 0.0 +``` + +### Distributed Computing + +FloodForecaster uses `physicsnemo`'s `DistributedManager` for distributed training, which automatically detects and configures the distributed environment. The framework supports: + +- **torchrun**: For PyTorch-native distributed training +- **mpirun**: For OpenMPI-based distributed training +- **SLURM**: For cluster-based distributed training + +**Configuration:** + +The `distributed` section in `config.yaml` contains minimal settings: + +```yaml +distributed: + seed: 123 + device: 'cuda:0' # Fallback device for non-distributed execution +``` + +**Key Points:** + +- **Device Assignment**: When running in distributed mode (via `torchrun` or `mpirun`), the device is automatically set to `cuda:{local_rank}` for each process. The `device` field in the config is only used as a fallback for single-GPU/CPU execution. + +- **No Manual Configuration Needed**: `DistributedManager.initialize()` automatically detects: + - Number of processes (`world_size`) + - Process rank (`rank`) + - Local rank (`local_rank`) + - Appropriate device assignment + +- **Example Distributed Training:** + + ```bash + # Single node, multiple GPUs + torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py + + # Multi-node (example) + torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr= train.py + ``` + +The framework handles all distributed setup automatically - you don't need to specify device lists or wireup configurations. + +### Domain Adaptation + +```yaml +da_classifier: + conv_layers: + - out_channels: 64 + kernel_size: 3 + pool_size: 2 + fc_dim: 1 +``` + +## Logging + +FloodForecaster supports logging via [Weights & Biases (W&B)](https://wandb.ai/): + +- Training and validation losses for both pretraining and domain adaptation +- Domain classification loss during adversarial training +- Learning rate schedule +- Model checkpoints and training state + +Set up W&B by modifying `wandb.log`, `wandb.project`, and `wandb.entity` in `config.yaml`. The framework also uses `physicsnemo`'s `PythonLogger` for distributed training and standard logging. + +## Project Structure + +``` +FloodForecaster/ +├── conf/ +│ └── config.yaml # Hydra configuration file +├── datasets/ # Dataset classes +│ ├── flood_dataset.py # Raw dataset loader +│ ├── normalized_dataset.py # Normalized training dataset +│ └── rollout_dataset.py # Rollout evaluation dataset +├── data_processing/ # Data preprocessing +│ └── data_processor.py # GINO data processor and wrappers +├── training/ # Training modules +│ ├── pretraining.py # Source domain pretraining +│ └── domain_adaptation.py # Domain adaptation fine-tuning +├── inference/ # Inference modules +│ └── rollout.py # Autoregressive rollout and evaluation +├── utils/ # Utility functions +│ ├── normalization.py # Data normalization utilities +│ └── plotting.py # Visualization and plotting functions +├── train.py # Main training script +├── inference.py # Main inference script +└── README.md +``` + +## Notes + +- **Batching**: GINO supports batching when geometry is shared across samples. For variable geometries, use `batch_size: 1`. +- **GPU Requirements**: GPU with 24GB+ VRAM recommended for larger meshes (>10,000 cells) and longer rollouts. +- **Domain Adaptation**: The model uses adversarial domain adaptation to improve generalization. The gradient reversal lambda can be scheduled during training for improved stability. This approach prevents catastrophic forgetting and enables data-efficient adaptation with as few as 10 training simulations from a new domain. +- **Autoregressive Error**: Long rollouts may accumulate prediction errors. The model uses a sliding window of historical states to mitigate this. +- **Framework Integration**: This example uses Hydra for configuration management and `physicsnemo` utilities for distributed training and logging, following NVIDIA Modulus (PhysicsNeMo) framework patterns. + +## Citation + +If you use FloodForecaster in your research, please cite: + +```bibtex +@article{taghizadeh2025floodforecaster, + title = {FloodForecaster: A domain-adaptive geometry-informed neural operator framework for rapid flood forecasting}, + author = {Taghizadeh, Mehdi and Zandsalimi, Zanko and Nabian, Mohammad Amin and Goodall, Jonathan L. and Alemazkoor, Negin}, + journal = {Journal of Hydrology}, + volume = {664}, + pages = {134512}, + year = {2026}, + publisher = {Elsevier}, + doi = {10.1016/j.jhydrol.2025.134512}, + url = {https://doi.org/10.1016/j.jhydrol.2025.134512} +} +``` + +## Contact + +For questions, feedback, or collaborations: + +- **Mehdi Taghizadeh** (Code Contributor and Maintainer) – +- **Zanko Zandsalimi** – +- **Mohammad Amin Nabian** – +- **Jonathan L. Goodall** – +- **Negin Alemazkoor** (Corresponding Author) – diff --git a/examples/weather/flood_modeling/flood_forecaster/conf/config.yaml b/examples/weather/flood_modeling/flood_forecaster/conf/config.yaml new file mode 100644 index 0000000000..b6af6209f3 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/conf/config.yaml @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration file for FloodForecaster training. +# This file is used by Hydra to configure the training run. + +hydra: + job: + chdir: True # Change directory to the job's working directory. + run: + dir: ./outputs/ # Directory to save outputs. + +# Distributed computing +# Note: FloodForecaster uses physicsnemo's DistributedManager which automatically +# detects distributed environments (torchrun, mpirun, SLURM). When running in +# distributed mode, the device is automatically set to cuda:{local_rank} for each process. +# The device field below is only used as a fallback for single-GPU/CPU execution. +distributed: + seed: 123 # Random seed for reproducibility (integer) + device: 'cuda:0' # Fallback device for non-distributed execution. Ignored when using torchrun/mpirun. + +# Dataset related for training and one-step evaluation +source_data: + root: "C:/Users/jrj6wm/Box/Flood_Modeling/Simulations/New_Results" # Source domain training data path + resolution: 48 # Spatial resolution of the dataset (integer) + n_history: 3 # Number of historical time steps to use as input (integer) + batch_size: 64 # Batch size for training (integer) + query_res: [48, 48] # Query resolution for GINO model [height, width] (list of 2 integers) + xy_file: "M40_XY.txt" # Filename for XY coordinates (geometry file) + static_files: + # Note: M40_XY.txt is included here as a static feature (XY coordinates as features) + # The xy_file parameter loads it separately for geometry, while static_files includes it as a feature + # This is intentional for flood modeling where XY coordinates are used both for geometry and as features + - "M40_XY.txt" + - "M40_CA.txt" + - "M40_CE.txt" + - "M40_CS.txt" + - "M40_FA.txt" + - "M40_A.txt" + - "M40_CU.txt" + dynamic_patterns: # Filename patterns for dynamic variables ({} will be replaced with timestep) + WD: "M40_WD_{}.txt" # Water depth pattern + VX: "M40_VX_{}.txt" # X-velocity pattern + VY: "M40_VY_{}.txt" # Y-velocity pattern + boundary_patterns: # Filename patterns for boundary conditions ({} will be replaced with timestep) + inflow: "M40_US_InF_{}.txt" # Inflow boundary condition pattern + noise_type: "none" # Type of noise to add to data. Options: "none", "gaussian" (string) + noise_std: [0.01, 0.001, 0.001] # Standard deviation for noise per channel [WD, VX, VY] (list of floats) + rollout_length: 78 # Number of timesteps for autoregressive rollout (integer) + skip_before_timestep: 12 # Number of initial timesteps to skip before starting rollout (integer) + dt: 1200 # Time step size in seconds (float) + +# Target domain dataset +target_data: + root: "C:/Users/jrj6wm/Box/Flood_Modeling/Simulations/Case_4/Results_Target/Train" # Target domain training data path (can be same as source or different) + resolution: 48 # Spatial resolution of the dataset (integer) + n_history: 3 # Number of historical time steps to use as input (integer) + batch_size: 64 # Batch size for domain adaptation training (integer) + query_res: [48, 48] # Query resolution for GINO model [height, width] (list of 2 integers) + xy_file: "M40_XY.txt" # Filename for XY coordinates (geometry file) + static_files: + # Note: M40_XY.txt is included here as a static feature (XY coordinates as features) + # The xy_file parameter loads it separately for geometry, while static_files includes it as a feature + # This is intentional for flood modeling where XY coordinates are used both for geometry and as features + - "M40_XY.txt" + - "M40_CA.txt" + - "M40_CE.txt" + - "M40_CS.txt" + - "M40_FA.txt" + - "M40_A.txt" + - "M40_CU.txt" + dynamic_patterns: # Filename patterns for dynamic variables ({} will be replaced with timestep) + WD: "M40_WD_{}.txt" # Water depth pattern + VX: "M40_VX_{}.txt" # X-velocity pattern + VY: "M40_VY_{}.txt" # Y-velocity pattern + boundary_patterns: # Filename patterns for boundary conditions ({} will be replaced with timestep) + inflow: "M40_US_InF_{}.txt" # Inflow boundary condition pattern + noise_type: "none" # Type of noise to add to data. Options: "none", "gaussian" (string) + noise_std: [0.01, 0.001, 0.001] # Standard deviation for noise per channel [WD, VX, VY] (list of floats) + rollout_length: 78 # Number of timesteps for autoregressive rollout (integer) + skip_before_timestep: 12 # Number of initial timesteps to skip before starting rollout (integer) + dt: 1200 # Time step size in seconds (float) + +# Rollout evaluation dataset +rollout_data: + root: "C:/Users/jrj6wm/Box/Flood_Modeling/Simulations/New_Results/Test_20_Paper" # Test/evaluation data path + xy_file: "M40_XY.txt" # Filename for XY coordinates (geometry file) + static_files: + # Note: M40_XY.txt is included here as a static feature (XY coordinates as features) + # The xy_file parameter loads it separately for geometry, while static_files includes it as a feature + # This is intentional for flood modeling where XY coordinates are used both for geometry and as features + - "M40_XY.txt" + - "M40_CA.txt" + - "M40_CE.txt" + - "M40_CS.txt" + - "M40_FA.txt" + - "M40_A.txt" + - "M40_CU.txt" + dynamic_patterns: # Filename patterns for dynamic variables ({} will be replaced with timestep) + WD: "M40_WD_{}.txt" # Water depth pattern + VX: "M40_VX_{}.txt" # X-velocity pattern + VY: "M40_VY_{}.txt" # Y-velocity pattern + boundary_patterns: # Filename patterns for boundary conditions ({} will be replaced with timestep) + inflow: "M40_US_InF_{}.txt" # Inflow boundary condition pattern + +# Model configuration (for neuralop get_model compatibility) +# Note: While model_arch is a parameter for neuralop's get_model, the FloodForecaster codebase +# is specifically designed for the GINO architecture. Changing model_arch would require significant code modifications. +model: + model_arch: 'gino' # Model architecture (string, currently only 'gino' is supported) + autoregressive: true # Enable autoregressive residual connection for time-stepping (boolean: true/false) + data_channels: 20 # Number of input data channels (integer) + out_channels: 3 # Number of output channels (integer, typically 3 for WD, VX, VY) + latent_feature_channels: null # Number of latent feature channels (integer or null) + projection_channel_ratio: 4 # Channel expansion ratio for projection layers (float) + gno_coord_dim: 2 # Coordinate dimension for Graph Neural Operator (integer, 2 for 2D) + in_gno_radius: 0.1 # Input GNO radius for neighbor search (float) + out_gno_radius: 0.1 # Output GNO radius for neighbor search (float) + in_gno_transform_type: 'linear' # Input GNO transform type (string: 'linear', 'nonlinear', etc.) + out_gno_transform_type: 'linear' # Output GNO transform type (string: 'linear', 'nonlinear', etc.) + gno_weighting_function: null # GNO weighting function (string or null) + gno_weight_function_scale: 1.0 # Scale factor for GNO weighting function (float) + in_gno_pos_embed_type: 'transformer' # Input positional embedding type (string) + out_gno_pos_embed_type: 'transformer' # Output positional embedding type (string) + fno_in_channels: 20 # FNO input channels (integer) + fno_n_modes: [16, 16] # FNO number of Fourier modes per dimension [modes_x, modes_y] (list of integers) + fno_hidden_channels: 64 # FNO hidden channel dimension (integer) + fno_lifting_channel_ratio: 2 # FNO channel expansion ratio for lifting (float) + fno_n_layers: 4 # Number of FNO layers (integer) + gno_embed_channels: 32 # GNO embedding channel dimension (integer) + gno_embed_max_positions: 10000 # Maximum positions for positional embedding (integer) + in_gno_channel_mlp_hidden_layers: [80, 80, 80] # Input GNO MLP hidden layer sizes (list of integers) + out_gno_channel_mlp_hidden_layers: [512, 256] # Output GNO MLP hidden layer sizes (list of integers) + gno_use_open3d: false # Use Open3D for neighbor search (boolean) + gno_use_torch_scatter: false # Use torch_scatter for operations (boolean) + out_gno_tanh: null # Apply tanh activation to output GNO (boolean or null) + fno_resolution_scaling_factor: null # FNO resolution scaling factor (float or null) + fno_block_precision: 'full' # FNO block precision (string: 'full', 'half', etc.) + fno_use_channel_mlp: true # Use channel MLP in FNO blocks (boolean) + fno_channel_mlp_dropout: 0 # Dropout rate for FNO channel MLP (float, 0.0 to 1.0) + fno_channel_mlp_expansion: 0.5 # Channel expansion ratio for FNO MLP (float) + fno_norm: 'instance_norm' # Normalization type for FNO (string: 'instance_norm', 'layer_norm', etc.) + fno_ada_in_features: 16 # FNO adaptive input features (integer) + fno_ada_in_dim: 1 # FNO adaptive input dimension (integer) + fno_preactivation: false # Use preactivation in FNO blocks (boolean) + fno_skip: 'linear' # FNO skip connection type (string: 'linear', 'identity', etc.) + fno_channel_mlp_skip: 'soft-gating' # FNO channel MLP skip type (string) + fno_separable: false # Use separable FNO (boolean) + fno_factorization: 'tucker' # FNO tensor factorization type (string: 'tucker', 'cp', etc.) + fno_rank: 0.4 # FNO factorization rank (float, typically 0.0 to 1.0) + fno_fixed_rank_modes: false # Use fixed rank modes in FNO (boolean) + fno_implementation: 'factorized' # FNO implementation type (string) + +# Checkpoint configuration +checkpoint: + save_dir: "./checkpoints_flood_forecaster" + # Resume training from checkpoint: + # - resume_from_source: Path to pretraining checkpoint directory (e.g., "./checkpoints_flood_forecaster/pretrain") + # Used to resume pretraining stage. Set to null to start from scratch. + # - resume_from_adapt: Path to domain adaptation checkpoint directory (e.g., "./checkpoints_flood_forecaster/adapt") + # Used to resume domain adaptation stage or load model for inference. + # For inference, this takes precedence over resume_from_source. + resume_from_source: null # Path to pretraining checkpoint directory, or null to start from scratch + resume_from_adapt: null # Path to domain adaptation checkpoint directory, or null + # Pretraining checkpoint saving options + save_best: "source_val_l2" # Metric to monitor for best model saving (e.g., "source_val_l2") + # Set to null to disable best model saving + save_every: null # Save checkpoint every N epochs (e.g., 10). Set to null to disable interval saving + # Note: save_best takes precedence over save_every if both are set + +# Rollout output directory +rollout: + out_dir: "./rollout_outputs" # Directory to save rollout evaluation outputs (string) + +# Optimization settings +training: + n_epochs: 50 # Total epochs (fallback if n_epochs_source not specified) + n_epochs_source: 30 # Number of epochs for source domain pretraining (integer) + n_epochs_adapt: 20 # Number of epochs for domain adaptation (integer) + learning_rate: 1e-4 # Learning rate for pretraining (float, typically 1e-5 to 1e-3) + adapt_learning_rate: 1e-4 # Learning rate for domain adaptation (float, typically 1e-5 to 1e-3) + training_loss: 'l2' # Training loss function. Available options: 'l1' (L1/LpLoss with p=1), 'l2' (L2/LpLoss with p=2) + testing_loss: 'l2' # Testing/evaluation loss function. Available options: 'l1' (L1/LpLoss with p=1), 'l2' (L2/LpLoss with p=2) + weight_decay: 1e-4 # Weight decay for optimizer (float, typically 1e-5 to 1e-3) + amp_autocast: false # Enable automatic mixed precision training (boolean: true/false) + scheduler: 'StepLR' # Learning rate scheduler. Available options: 'StepLR', 'ReduceLROnPlateau', 'CosineAnnealingLR' + scheduler_T_max: 200 # Maximum number of iterations for CosineAnnealingLR (integer) + scheduler_patience: 5 # Patience for ReduceLROnPlateau (number of epochs with no improvement, integer) + step_size: 50 # Period of learning rate decay for StepLR (integer, epochs) + gamma: 0.5 # Multiplicative factor for learning rate decay (float, typically 0.1 to 0.9) + da_class_loss_weight: 0.0 # Weight for domain classification adversarial loss (float, 0.0 disables adversarial training) + da_lambda_max: 1.0 # Maximum lambda value for domain adaptation gradient reversal (float) + da_classifier: # Domain classifier architecture for adversarial domain adaptation + conv_layers: # Convolutional layers for domain classifier + - out_channels: 64 # Number of output channels for this conv layer (integer) + kernel_size: 3 # Convolution kernel size (integer) + pool_size: 2 # Pooling size after convolution (integer) + fc_dim: 1 # Fully connected layer dimension (integer, output dimension for binary classification) + +# Weights & Biases logging +wandb: + log: false # Enable Weights & Biases logging (boolean: true/false) + name: null # Run name for W&B (string or null, null uses auto-generated name) + group: 'flood-experiments' # Experiment group name for organizing runs (string) + project: 'Flood_GINO_NoPhysics' # W&B project name (string) + entity: 'uva_mehdi' # W&B entity/username (string) + sweep: false # Enable W&B hyperparameter sweep mode (boolean: true/false) + log_output: true # Log model outputs to W&B (boolean: true/false) + eval_interval: 1 # Evaluation logging interval in epochs (integer) + diff --git a/examples/weather/flood_modeling/flood_forecaster/data_processing/__init__.py b/examples/weather/flood_modeling/flood_forecaster/data_processing/__init__.py new file mode 100644 index 0000000000..d3c5be8944 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/data_processing/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Data processing modules for flood prediction.""" + +from .data_processor import FloodGINODataProcessor, GINOWrapper, LpLossWrapper + +__all__ = ["FloodGINODataProcessor", "GINOWrapper", "LpLossWrapper"] + diff --git a/examples/weather/flood_modeling/flood_forecaster/data_processing/data_processor.py b/examples/weather/flood_modeling/flood_forecaster/data_processing/data_processor.py new file mode 100644 index 0000000000..d7054ee29c --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/data_processing/data_processor.py @@ -0,0 +1,1130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Data processor for GINO flood prediction model. + +Compatible with neuralop 2.0.0 API. +""" + +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch +import torch.nn as nn + +import physicsnemo +from physicsnemo.models.meta import ModelMetaData + +try: + from jaxtyping import Float + HAS_JAXTYPING = True +except ImportError: + HAS_JAXTYPING = False + # Fallback type alias for when jaxtyping is not available + Float = None + + +def _create_custom_physicsnemo_wrapper(torch_model_instance: nn.Module) -> physicsnemo.models.Module: + r""" + Create a custom PhysicsNeMo wrapper that preserves the original model's forward signature. + + This wrapper is used to convert PyTorch modules to PhysicsNeMo modules while maintaining + full compatibility with the original model's interface, including complex forward signatures + and attribute access. + + Parameters + ---------- + torch_model_instance : nn.Module + The PyTorch module instance to wrap. + + Returns + ------- + physicsnemo.models.Module + A PhysicsNeMo module that wraps the original PyTorch module. + """ + class CustomPhysicsNeMoWrapper(physicsnemo.models.Module): + r""" + Custom PhysicsNeMo wrapper that preserves original model interface. + + This wrapper stores the original PyTorch model as inner_model and delegates + all method calls and attribute access to it, ensuring full compatibility. + """ + + def __init__(self, inner_model: nn.Module): + r""" + Initialize the wrapper with the original PyTorch model. + + Parameters + ---------- + inner_model : nn.Module + The original PyTorch model instance to wrap. + """ + model_name = inner_model.__class__.__name__ + super().__init__(meta=ModelMetaData(name=f"{model_name}PhysicsNeMo")) + # Set inner_model as direct attribute first (for immediate access) + # Then register it as a submodule (for PyTorch parameter/buffer tracking) + object.__setattr__(self, 'inner_model', inner_model) + if isinstance(inner_model, nn.Module): + # Register in _modules for proper PyTorch tracking + # This will also set self.inner_model, but we've already set it above + # so this ensures it's in _modules for proper module tracking + self._modules['inner_model'] = inner_model + + # CRITICAL: Remove inner_model from _args so PhysicsNeMo's save process + # doesn't try to serialize it. inner_model is a PyTorch module and would + # cause an error. The model state is stored in state_dict, not in _args. + # This is safe because inner_model is just an implementation detail for + # wrapping - the actual model parameters are in the module's state_dict. + if hasattr(self, '_args') and 'inner_model' in self._args.get('__args__', {}): + # Remove inner_model from serializable args + # We'll reconstruct it from state_dict during loading if needed + del self._args['__args__']['inner_model'] + + def forward(self, *args, **kwargs): + r""" + Forward pass that preserves the original model's signature. + + All arguments are passed through to the inner model, preserving + the original forward signature completely. + + Parameters + ---------- + *args + Positional arguments for the original model's forward method. + **kwargs + Keyword arguments for the original model's forward method. + + Returns + ------- + Any + Output from the inner model's forward method. + """ + return self.inner_model(*args, **kwargs) + + def __getattr__(self, name: str): + r""" + Delegate attribute access to the inner model. + + This ensures that all attributes, properties, and methods of the + original model are accessible through the wrapper. + + Parameters + ---------- + name : str + Name of the attribute to access. + + Returns + ------- + Any + The requested attribute from the inner model. + + Raises + ------ + AttributeError + If the attribute is not found in either the wrapper or inner model. + """ + # If we're asking for inner_model itself, try multiple ways to get it + if name == 'inner_model': + # Try direct attribute access first + try: + return object.__getattribute__(self, 'inner_model') + except AttributeError: + # Try _modules dict (where add_module stores it) + if 'inner_model' in self._modules: + return self._modules['inner_model'] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute 'inner_model'. " + f"This should not happen - inner_model should be set during __init__." + ) + + # For wrapper attributes, use parent class + if name in ['meta', '_args']: + return super().__getattribute__(name) + + # Get inner_model - try multiple ways + inner_model = None + try: + inner_model = object.__getattribute__(self, 'inner_model') + except AttributeError: + # Try _modules dict + if 'inner_model' in self._modules: + inner_model = self._modules['inner_model'] + + if inner_model is None: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute 'inner_model'. " + f"Cannot delegate attribute '{name}'." + ) + + # Try to get from inner model + try: + return getattr(inner_model, name) + except AttributeError: + # Fall back to parent class attributes + try: + return super().__getattribute__(name) + except AttributeError: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}' " + f"and inner model '{type(inner_model).__name__}' also has no attribute '{name}'" + ) + + return CustomPhysicsNeMoWrapper(torch_model_instance) + + +class LpLossWrapper: + r""" + Wrapper around LpLoss that filters out unexpected kwargs. + + The neuralop Trainer calls loss(out, **sample) where sample contains + all keys including model inputs. This wrapper filters to only pass 'y'. + """ + + def __init__(self, loss_fn): + r""" + Initialize LpLoss wrapper. + + Parameters + ---------- + loss_fn : callable + The underlying loss function to wrap. + + Raises + ------ + ValueError + If ``loss_fn`` is None. + """ + if loss_fn is None: + raise ValueError("loss_fn cannot be None") + self.loss_fn = loss_fn + + def __call__(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + r""" + Compute loss, filtering out unexpected kwargs. + + Parameters + ---------- + y_pred : torch.Tensor + Predicted values of shape :math:`(B, D)` where :math:`B` is batch size + and :math:`D` is the number of output dimensions. + y : torch.Tensor, optional + Ground truth values of shape :math:`(B, D)`, optional. + **kwargs + Additional arguments (filtered out, not used). + + Returns + ------- + torch.Tensor + Loss value as a scalar tensor. + """ + # Ignore all kwargs except y - silently filter out model input keys + return self.loss_fn(y_pred, y) + + +class GINOWrapper(physicsnemo.Module): + r""" + Enhanced wrapper around GINO model that adds enhanced functionality. + + This wrapper adds: + 1. Autoregressive residual connection support + 2. Feature extraction support (return_features) + 3. Filters out unexpected kwargs + + The neuralop Trainer calls model(**sample) where sample contains + both model inputs AND 'y' for loss computation. This wrapper + filters out 'y' before passing to GINO to avoid warnings. + + This wrapper restores functionality from the older GINO version: + - autoregressive: If True, adds residual connection: ``out = x[..., :out_channels] + predicted_delta`` + - return_features: If True, returns (out, latent_embed) tuple for domain adaptation + + Parameters + ---------- + model : nn.Module + The GINO model to wrap. + autoregressive : bool, optional, default=False + If True, enable residual connection for autoregressive time-stepping. + + Forward + ------- + input_geom : torch.Tensor + Input geometry tensor of shape :math:`(n_{in}, D)` or :math:`(1, n_{in}, D)` + where :math:`n_{in}` is the number of input points and :math:`D` is the + coordinate dimension (typically 2). + latent_queries : torch.Tensor + Latent query points of shape :math:`(H, W, D)` or :math:`(1, H, W, D)` + where :math:`H, W` are spatial dimensions. + output_queries : torch.Tensor or Dict[str, torch.Tensor] + Output query points of shape :math:`(n_{out}, D)` or :math:`(1, n_{out}, D)`, + or a dictionary of such tensors. + x : torch.Tensor, optional + Input features of shape :math:`(B, n_{in}, C_{in})` where :math:`B` is batch size + and :math:`C_{in}` is the number of input channels. + latent_features : torch.Tensor, optional + Latent features of shape :math:`(B, H, W, C_{feat})` where :math:`C_{feat}` is + the number of feature channels. + ada_in : torch.Tensor, optional + Adaptive input tensor. + return_features : bool, optional, default=False + If True, return (output, latent_embed) tuple for domain adaptation. + + Outputs + ------- + torch.Tensor or Tuple[torch.Tensor, torch.Tensor] + Model output tensor of shape :math:`(B, n_{out}, C_{out})`, or + (output, features) tuple if ``return_features=True``. Features are the + latent embedding from FNO blocks of shape :math:`(B, C, H, W)` for 2D. + """ + + def __init__(self, model: nn.Module, autoregressive: bool = False): + r""" + Initialize GINO wrapper. + + Parameters + ---------- + model : nn.Module + The GINO model to wrap. If it's a PyTorch module (not PhysicsNeMo), + it will be automatically converted to a PhysicsNeMo module to enable + full PhysicsNeMo checkpoint support. + autoregressive : bool, optional, default=False + If True, enable residual connection for autoregressive time-stepping. + + Raises + ------ + ValueError + If ``model`` is None. + """ + if model is None: + raise ValueError("model cannot be None") + + # === OPTION 1: Auto-convert PyTorch modules to PhysicsNeMo modules === + # IMPORTANT: Convert BEFORE calling super().__init__() so that PhysicsNeMo's + # __new__ captures the converted model in _args, not the original PyTorch model. + # This enables full PhysicsNeMo checkpoint support without requiring + # manual conversion or workarounds. + if isinstance(model, nn.Module) and not isinstance(model, physicsnemo.models.Module): + # Convert PyTorch module to PhysicsNeMo module + try: + # Use custom wrapper that preserves forward signature and attribute access + converted_model = _create_custom_physicsnemo_wrapper(model) + + # Verify the conversion worked + if not isinstance(converted_model, physicsnemo.models.Module): + raise RuntimeError( + f"Conversion failed: result is not a PhysicsNeMo Module, " + f"got {type(converted_model)}" + ) + + # Replace model with converted version BEFORE super().__init__() + # This ensures PhysicsNeMo's __new__ captures the converted model in _args + model = converted_model + + # Optional: Log the conversion (only in debug mode to avoid noise) + import logging + logger = logging.getLogger(__name__) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"Auto-converted PyTorch model {model.inner_model.__class__.__name__} " + f"to PhysicsNeMo Module for checkpoint compatibility" + ) + + except Exception as e: + # Fallback: Keep original model but warn user + # The checkpoint saving code will handle this case by saving as PyTorch state_dict + warnings.warn( + f"Failed to convert PyTorch model {model.__class__.__name__} to PhysicsNeMo Module: {e}. " + f"Model will be saved as PyTorch state_dict instead of using PhysicsNeMo's Module.save(). " + f"This is a fallback and should work, but full PhysicsNeMo checkpoint features won't be available.", + UserWarning, + stacklevel=2 + ) + # Continue with original model - the save_checkpoint will handle it + + # Now call super().__init__() - PhysicsNeMo's __new__ has already captured + # the arguments (including the converted model if conversion happened) + # But we need to update _args to ensure the converted model is stored + super().__init__(meta=ModelMetaData(name="GINOWrapper")) + + # Update _args to ensure the converted model is stored (in case conversion + # happened after __new__ captured the original model) + # This is a safety measure - ideally conversion happens before __new__, + # but this ensures _args is correct + if hasattr(self, '_args') and 'model' in self._args.get('__args__', {}): + # Check if the stored model is PyTorch but we have a converted one + stored_model = self._args['__args__']['model'] + if isinstance(stored_model, nn.Module) and not isinstance(stored_model, physicsnemo.models.Module): + # Update _args with the converted model + self._args['__args__']['model'] = model + + # Register as a submodule so it's properly tracked by PyTorch + # This ensures the model is properly stored and accessible + # Always store as both submodule (for PyTorch) and direct attribute (for easy access) + if isinstance(model, nn.Module): + self.add_module('gino', model) + # Also store as direct attribute for easy access (works for both Module and non-Module) + self.gino = model + self.autoregressive = autoregressive + + def forward( + self, + input_geom: torch.Tensor, + latent_queries: torch.Tensor, + output_queries: torch.Tensor, + x: Optional[torch.Tensor] = None, + latent_features: Optional[torch.Tensor] = None, + ada_in: Optional[torch.Tensor] = None, + return_features: bool = False, + **kwargs + ): + r""" + Forward pass through wrapped GINO model with enhanced features. + + This method replicates the new GINO's forward logic but adds: + 1. Autoregressive residual connection support + 2. Feature extraction support (return_features) + + Parameters + ---------- + input_geom : torch.Tensor + Input geometry tensor of shape :math:`(n_{in}, D)` or :math:`(1, n_{in}, D)`. + latent_queries : torch.Tensor + Latent query points of shape :math:`(H, W, D)` or :math:`(1, H, W, D)`. + output_queries : torch.Tensor or Dict[str, torch.Tensor] + Output query points of shape :math:`(n_{out}, D)` or :math:`(1, n_{out}, D)`, + or a dictionary of such tensors. + x : torch.Tensor, optional + Input features of shape :math:`(B, n_{in}, C_{in})`. + latent_features : torch.Tensor, optional + Latent features of shape :math:`(B, H, W, C_{feat})`. + ada_in : torch.Tensor, optional + Adaptive input tensor. + return_features : bool, optional, default=False + If True, return (output, latent_embed) tuple. + **kwargs + Additional arguments (filtered out, including 'y'). + + Returns + ------- + torch.Tensor or Tuple[torch.Tensor, torch.Tensor] + Model output tensor of shape :math:`(B, n_{out}, C_{out})`, or + (output, features) tuple if ``return_features=True``. Features are the + latent embedding from FNO blocks of shape :math:`(B, C, H, W)` for 2D. + """ + ### Input validation + # Skip validation when running under torch.compile for performance + if not torch.compiler.is_compiling(): + # Validate input_geom shape + if input_geom.ndim not in [2, 3]: + raise ValueError( + f"Expected input_geom to be 2D or 3D tensor, got {input_geom.ndim}D tensor " + f"with shape {tuple(input_geom.shape)}" + ) + if input_geom.ndim == 3 and input_geom.shape[0] != 1: + raise ValueError( + f"Expected input_geom with batch dim to have shape (1, n_in, D), " + f"got {tuple(input_geom.shape)}" + ) + + # Validate latent_queries shape + if latent_queries.ndim not in [3, 4]: + raise ValueError( + f"Expected latent_queries to be 3D or 4D tensor, got {latent_queries.ndim}D tensor " + f"with shape {tuple(latent_queries.shape)}" + ) + if latent_queries.ndim == 4 and latent_queries.shape[0] != 1: + raise ValueError( + f"Expected latent_queries with batch dim to have shape (1, H, W, D), " + f"got {tuple(latent_queries.shape)}" + ) + + # Validate x shape if provided + if x is not None: + if x.ndim != 3: + raise ValueError( + f"Expected x to be 3D tensor (B, n_in, C_in), got {x.ndim}D tensor " + f"with shape {tuple(x.shape)}" + ) + # Check consistency with input_geom + n_in_geom = input_geom.shape[-2] if input_geom.ndim == 2 else input_geom.shape[1] + if x.shape[1] != n_in_geom: + raise ValueError( + f"Expected x.shape[1] ({x.shape[1]}) to match input_geom n_in ({n_in_geom})" + ) + + # Filter out unexpected kwargs (e.g., 'y' target) - GINO forward doesn't accept these + # These are handled separately in the training loop + + # Determine batch size (matching new GINO logic) + if x is None: + batch_size = 1 + else: + batch_size = x.shape[0] + + # Handle latent_features validation (matching new GINO) + # Access gino attributes - works for both converted and unconverted models + # Use safe attribute access to avoid triggering __getattr__ recursion + try: + # Try to get inner_model directly (for converted models) + gino_model = object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + # If inner_model doesn't exist, use gino directly (unconverted model) + gino_model = self.gino + if latent_features is not None: + if gino_model.latent_feature_channels is None: + raise ValueError("if passing latent features, latent_feature_channels must be set.") + if latent_features.shape[-1] != gino_model.latent_feature_channels: + raise ValueError(f"latent_features.shape[-1] must equal latent_feature_channels") + if latent_features.ndim != gino_model.gno_coord_dim + 2: + raise ValueError( + f"Latent features must be of shape (batch, n_gridpts_1, ...n_gridpts_n, feat_dim), " + f"got {latent_features.shape}" + ) + if latent_features.shape[0] != batch_size: + if latent_features.shape[0] == 1: + latent_features = latent_features.repeat(batch_size, *[1]*(latent_features.ndim-1)) + + # Handle input geometry and queries (matching new GINO: squeeze(0)) + input_geom = input_geom.squeeze(0) + latent_queries = latent_queries.squeeze(0) + + # Pass through input GNOBlock (matching new GINO exactly) + in_p = gino_model.gno_in( + y=input_geom, + x=latent_queries.view((-1, latent_queries.shape[-1])), + f_y=x + ) + + # Reshape to grid format (matching new GINO) + grid_shape = latent_queries.shape[:-1] # (H, W) for 2D + in_p = in_p.reshape((batch_size, *grid_shape, -1)) + + # Concatenate latent features if provided (matching new GINO) + if latent_features is not None: + in_p = torch.cat((in_p, latent_features), dim=-1) + + # Get latent embedding (this is what we need for feature extraction) + # This matches new GINO's latent_embedding call + # Handle nn.Identity() which doesn't accept keyword arguments + if isinstance(gino_model.latent_embedding, torch.nn.Identity): + latent_embed = gino_model.latent_embedding(in_p) + else: + latent_embed = gino_model.latent_embedding(in_p=in_p, ada_in=ada_in) + # latent_embed shape: (B, channels, H, W) for 2D - keep this for feature return + + # Prepare for output GNO (matching new GINO exactly) + # latent_embed shape (b, c, n_1, n_2, ..., n_k) + batch_size = latent_embed.shape[0] + # permute to (b, n_1, n_2, ...n_k, c) then reshape to (b, n_1 * n_2 * ...n_k, c) + # Note: new GINO reassigns latent_embed here, but we keep both versions + latent_embed_flat = latent_embed.permute( + 0, *gino_model.in_coord_dim_reverse_order, 1 + ).reshape(batch_size, -1, gino_model.fno_hidden_channels) + + # Apply tanh if needed (matching new GINO - applied after reshape to flattened version) + if gino_model.out_gno_tanh in ["latent_embed", "both"]: + latent_embed_flat = torch.tanh(latent_embed_flat) + # Also apply to unflattened latent_embed for feature return consistency + latent_embed = torch.tanh(latent_embed) + + # Handle output_queries (matching new GINO logic) + if isinstance(output_queries, dict): + # Multiple output queries - handle each separately + out = {} + for key, out_p in output_queries.items(): + out_p = out_p.squeeze(0) + + sub_output = gino_model.gno_out( + y=latent_queries.reshape((-1, latent_queries.shape[-1])), + x=out_p, + f_y=latent_embed_flat, + ) + sub_output = sub_output.permute(0, 2, 1) # (B, channels, n_out) -> (B, n_out, channels) + sub_output = gino_model.projection(sub_output) # (B, n_out, channels) -> (B, out_channels, n_out) + # Transpose to get (B, n_out, out_channels) to match expected shape + sub_output = sub_output.permute(0, 2, 1) # (B, out_channels, n_out) -> (B, n_out, out_channels) + + # Apply residual connection if autoregressive (NEW FUNCTIONALITY) + # For autoregressive mode, add the previous time step's output + # The previous step should come from the last timestep of dynamic history in x + # Note: This only works if output_queries match input_geom spatial locations + if self.autoregressive and (x is not None): + # sub_output shape: (B, n_out, out_channels) + # x shape: (B, n_in, in_channels) + # For autoregressive to work, n_out must equal n_in (same spatial locations) + if sub_output.shape[0] == x.shape[0] and sub_output.shape[1] == x.shape[1]: + # Spatial dimensions match - can apply autoregressive connection + # Extract last 3 channels from x (last timestep of dynamic: WD, VX, VY) + # x contains [static, boundary, dynamic] where dynamic ends with last 3 channels + if x.shape[2] >= 3: + prev_step = x[..., -3:] # (B, n_in, 3) = (B, n_out, 3) + sub_output = sub_output + prev_step + # If spatial dimensions don't match, skip autoregressive (output_queries != input_geom) + + out[key] = sub_output + else: + # Single output query (matching new GINO) + output_queries = output_queries.squeeze(0) + + out = gino_model.gno_out( + y=latent_queries.reshape((-1, latent_queries.shape[-1])), + x=output_queries, + f_y=latent_embed_flat, + ) + out = out.permute(0, 2, 1) # (B, channels, n_out) -> (B, n_out, channels) + out = gino_model.projection(out) # (B, n_out, channels) -> (B, out_channels, n_out) + # Transpose to get (B, n_out, out_channels) to match expected shape + out = out.permute(0, 2, 1) # (B, out_channels, n_out) -> (B, n_out, out_channels) + + # Apply residual connection if autoregressive (NEW FUNCTIONALITY) + # For autoregressive mode, add the previous time step's output + # The previous step should come from the last timestep of dynamic history in x + # Note: This only works if output_queries match input_geom spatial locations + if self.autoregressive and (x is not None): + # out shape: (B, n_out, out_channels) + # x shape: (B, n_in, in_channels) + # For autoregressive to work, n_out must equal n_in (same spatial locations) + if out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]: + # Spatial dimensions match - can apply autoregressive connection + # Extract last 3 channels from x (last timestep of dynamic: WD, VX, VY) + # x contains [static, boundary, dynamic] where dynamic ends with last 3 channels + if x.shape[2] >= 3: + prev_step = x[..., -3:] # (B, n_in, 3) = (B, n_out, 3) + out = out + prev_step + # If spatial dimensions don't match, skip autoregressive (output_queries != input_geom) + + # Return features if requested (NEW FUNCTIONALITY) + if return_features: + return out, latent_embed + else: + return out + + @property + def fno_hidden_channels(self) -> int: + r"""Expose GINO's fno_hidden_channels for domain classifier.""" + # Handle both converted (has inner_model) and unconverted models + try: + gino_model = object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + gino_model = self.gino + return gino_model.fno_hidden_channels + + @property + def model(self) -> nn.Module: + r"""Alias for the wrapped model.""" + # Return the actual GINO model (unwrap if converted) + try: + return object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + return self.gino + + @property + def device(self) -> torch.device: + r"""Return the device of the model (device of first parameter).""" + # Get the actual GINO model (unwrap if converted) + try: + gino_model = object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + gino_model = self.gino + + # Return device of first parameter (standard PyTorch pattern) + if hasattr(gino_model, 'device'): + return gino_model.device + # Fallback: get device from first parameter + try: + first_param = next(gino_model.parameters()) + return first_param.device + except StopIteration: + # No parameters, return CPU as default + return torch.device('cpu') + + def __getattr__(self, name): + r""" + Delegate attribute access to the underlying GINO model. + + This allows the wrapper to be used as a drop-in replacement. + Handles both converted (PhysicsNeMo) and unconverted (PyTorch) models. + """ + # 'gino' and 'autoregressive' are wrapper attributes, should be accessible directly + # If we get here, it means normal attribute lookup failed + # For 'gino', try to get from _modules (if registered as submodule) + if name == 'gino': + if 'gino' in self._modules: + return self._modules['gino'] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + if name == 'autoregressive': + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # For other attributes, delegate to gino + # Get gino (should exist as attribute or in _modules) + gino = None + if hasattr(self, 'gino'): + gino = self.gino + elif 'gino' in self._modules: + gino = self._modules['gino'] + + if gino is not None: + # If gino is a converted model (has inner_model), access through it + # Otherwise, access directly + try: + # Try to get inner_model safely (for converted models) + try: + inner_model = object.__getattribute__(gino, 'inner_model') + # Try inner_model first (the actual GINO model) + return getattr(inner_model, name) + except (AttributeError, TypeError): + # No inner_model, so gino is the actual model + return getattr(gino, name) + except AttributeError: + # If not found, try the wrapper itself (for wrapper-specific attributes) + try: + inner_model = object.__getattribute__(gino, 'inner_model') + # Try wrapper attributes + return getattr(gino, name) + except (AttributeError, TypeError): + pass + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True) -> Any: + r""" + Load state dictionary, filtering out non-model keys like _metadata. + + PhysicsNeMo's Module.save() may include metadata keys in the state_dict + that PyTorch's load_state_dict() doesn't expect. This method filters + out those keys before loading. + + Parameters + ---------- + state_dict : Dict[str, Any] + State dictionary to load. + strict : bool, optional + Whether to strictly enforce that the keys in state_dict match + the keys returned by this module's state_dict() function. + Default is True. + + Returns + ------- + Any + Missing and unexpected keys (same as PyTorch's load_state_dict). + """ + # Filter out non-model keys that PhysicsNeMo may include + # These are metadata keys that shouldn't be loaded into the model + filtered_state_dict = {} + non_model_keys = ['_metadata', 'device_buffer'] # Add other non-model keys here if needed + + # Get the actual GINO model (unwrap if converted) + try: + gino_model = object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + gino_model = self.gino + + # Check if gino_model is a PhysicsNeMo Module - if so, it should handle its own loading + if isinstance(gino_model, physicsnemo.models.Module): + # For PhysicsNeMo modules, pass the state dict through with proper key mapping + # PhysicsNeMo's Module.load_state_dict will handle the nested structure + for key, value in state_dict.items(): + # Skip non-model keys + if key in non_model_keys or (key.startswith('_') and key != '_modules'): + continue + + # Handle nested keys from PhysicsNeMo checkpoints + if key.startswith('gino.inner_model.'): + # Map to inner model keys (remove 'gino.inner_model.' prefix) + inner_key = key.replace('gino.inner_model.', '') + filtered_state_dict[inner_key] = value + elif key.startswith('gino.'): + # For keys like 'gino.*', check if they should be passed to inner model + # If the inner model is a PhysicsNeMo Module, it might expect these keys + inner_key = key.replace('gino.', '') + # Only include if it's not a wrapper-specific key + if not inner_key.startswith('_'): + filtered_state_dict[inner_key] = value + else: + # Direct keys (parameters stored at root level in checkpoint) + filtered_state_dict[key] = value + + # Load into the inner GINO model (PhysicsNeMo Module) + return gino_model.load_state_dict(filtered_state_dict, strict=strict) + else: + # For regular PyTorch modules, use the original logic + for key, value in state_dict.items(): + # Skip non-model keys + if key in non_model_keys or key.startswith('_'): + continue + + # Handle nested keys from PhysicsNeMo checkpoints (e.g., 'gino.inner_model.*') + if key.startswith('gino.inner_model.'): + # Map to inner model keys + inner_key = key.replace('gino.inner_model.', '') + filtered_state_dict[inner_key] = value + elif key.startswith('gino.'): + # Skip gino wrapper keys, only load inner model + continue + else: + filtered_state_dict[key] = value + + # Load into the inner GINO model + return gino_model.load_state_dict(filtered_state_dict, strict=strict) + + def save_checkpoint(self, save_folder: str, save_name: str) -> None: + r""" + Delegate checkpoint saving to wrapped GINO model. + + Parameters + ---------- + save_folder : str + Directory to save checkpoint. + save_name : str + Name of checkpoint file. + + Raises + ------ + AttributeError + If wrapped model does not have ``save_checkpoint`` method. + """ + # Get the actual GINO model (unwrap if converted) + try: + gino_model = object.__getattribute__(self.gino, 'inner_model') + except (AttributeError, TypeError): + gino_model = self.gino + + if not hasattr(gino_model, 'save_checkpoint'): + raise AttributeError(f"Wrapped model {type(gino_model).__name__} does not have save_checkpoint method") + return gino_model.save_checkpoint(save_folder, save_name) + + @classmethod + def from_checkpoint(cls, save_folder: str, save_name: str, map_location: Optional[str] = None): + r""" + Load from checkpoint - delegate to GINO. + + Parameters + ---------- + save_folder : str + Directory containing checkpoint. + save_name : str + Name of checkpoint file. + map_location : str, optional + Device to map checkpoint to. + + Returns + ------- + GINOWrapper + GINOWrapper instance with loaded model. + """ + from neuralop.models import GINO + gino = GINO.from_checkpoint(save_folder, save_name, map_location) + return cls(gino) + + +class FloodGINODataProcessor(physicsnemo.Module): + r""" + Data processor for flood GINO model that handles preprocessing and postprocessing. + + Compatible with neuralop 2.0.0 DataProcessor interface. + + GINO batching note: + GINO supports batching only when geometry is SHARED across the batch. + - input_geom, latent_queries, output_queries: NO batch dimension (shared) + - x (features): HAS batch dimension :math:`(B, n_{in}, C_{in})` + - y (target): HAS batch dimension :math:`(B, n_{out}, C_{out})` + + Parameters + ---------- + device : str or torch.device, optional, default="cuda" + Device to move tensors to (string like "cuda" or "cpu", or torch.device object). + target_norm : nn.Module, optional + Target normalizer for inverse transform. + inverse_test : bool, optional, default=True + Whether to apply inverse transform during testing. + """ + + def __init__( + self, + device: Union[str, torch.device] = "cuda", + target_norm: Optional[nn.Module] = None, + inverse_test: bool = True + ): + r""" + Initialize flood GINO data processor. + + Parameters + ---------- + device : str or torch.device, optional, default="cuda" + Device to move tensors to (string like "cuda" or "cpu", or torch.device object). + target_norm : nn.Module, optional + Target normalizer for inverse transform. + inverse_test : bool, optional, default=True + Whether to apply inverse transform during testing. + + Raises + ------ + TypeError + If ``device`` is not a string or torch.device. + """ + super().__init__(meta=ModelMetaData(name="FloodGINODataProcessor")) + # Accept both string and torch.device objects - preserve original type + if not isinstance(device, (str, torch.device)): + raise TypeError(f"device must be a string or torch.device, got {type(device)}") + # Store device string for reference, but use super().to() to actually move module + # Note: physicsnemo.Module has a read-only 'device' property, so we can't set self.device directly + self._device_str = str(device) if isinstance(device, torch.device) else device + # Move module to device using parent class method + super().to(device) + self.model: Optional[nn.Module] = None + self.target_norm = target_norm + self.inverse_test = inverse_test + + def preprocess(self, sample: Dict, batched: bool = True) -> Dict: + r""" + Preprocess sample for GINO model input. + + GINO expects: + - input_geom: :math:`(n_{in}, 2)` - NO batch dim, shared geometry + - latent_queries: :math:`(H, W, 2)` - NO batch dim + - output_queries: :math:`(n_{out}, 2)` - NO batch dim + - x: :math:`(B, n_{in}, C_{in})` - HAS batch dim + - y: :math:`(B, n_{out}, C_{out})` - HAS batch dim (for loss) + + Parameters + ---------- + sample : Dict + Sample dictionary with geometry, static, boundary, dynamic, etc. + batched : bool, optional, default=True + Whether the data is batched (currently unused but kept for API compatibility). + + Returns + ------- + Dict + New dictionary with GINO-compatible keys + y for loss. + + Raises + ------ + KeyError + If required keys are missing from sample. + RuntimeError + If tensor shapes are incompatible. + """ + ### Input validation + # Skip validation when running under torch.compile for performance + if not torch.compiler.is_compiling(): + # Validate required keys + required_keys = ["geometry", "static", "boundary", "dynamic", "query_points"] + missing_keys = [key for key in required_keys if key not in sample] + if missing_keys: + raise KeyError(f"Missing required keys in sample: {missing_keys}") + + # Validate tensor shapes + for key in ["geometry", "static", "boundary", "dynamic", "query_points"]: + if key in sample and isinstance(sample[key], torch.Tensor): + if sample[key].ndim < 2: + raise ValueError( + f"Expected {key} to be at least 2D tensor, got {sample[key].ndim}D tensor " + f"with shape {tuple(sample[key].shape)}" + ) + + # Move all tensors to device + for k, v in sample.items(): + if isinstance(v, torch.Tensor): + sample[k] = v.to(self.device) + + # Get batch dimension info + dyn_ = sample["dynamic"] + # dynamic comes as (B, n_history, num_cells, 3) or (n_history, num_cells, 3) + if dyn_.dim() == 3: + # Single sample: (n_history, num_cells, 3) -> add batch dim + dyn_ = dyn_.unsqueeze(0) + # Now dyn_ is (B, n_history, num_cells, 3) + # Reshape to (B, num_cells, n_history * 3) + dyn_ = dyn_.permute(0, 2, 1, 3) # (B, num_cells, n_history, 3) + B, N, H, D = dyn_.shape + dyn_ = dyn_.reshape(B, N, H * D) + + # boundary: (B, n_history, num_cells, bc_dim) or (n_history, num_cells, bc_dim) + bc_ = sample["boundary"] + if bc_.dim() == 3: + bc_ = bc_.unsqueeze(0) + bc_ = bc_.permute(0, 2, 1, 3) # (B, num_cells, n_history, bc_dim) + B2, N2, H2, C2 = bc_.shape + bc_ = bc_.reshape(B2, N2, H2 * C2) + + # static: (B, num_cells, static_dim) or (num_cells, static_dim) + st_ = sample["static"] + if st_.dim() == 2: + st_ = st_.unsqueeze(0) + + # Concatenate all features: [static, boundary, dynamic] + # x shape: (B, num_cells, total_features) + x_ = torch.cat([st_, bc_, dyn_], dim=2) + + # geometry: (B, num_cells, 2) or (num_cells, 2) + # GINO expects geometry WITHOUT batch dim - use first sample's geometry (shared) + geom_ = sample["geometry"] + if geom_.dim() == 3: + # Take first sample's geometry (all should be same for GINO batching) + geom_ = geom_[0] + # geom_ is now (num_cells, 2) - NO batch dim + + # target (y): (B, num_cells, 3) or (num_cells, 3) + y_ = sample.get("target", None) + if y_ is not None and y_.dim() == 2: + y_ = y_.unsqueeze(0) + + # query_points: (B, H, W, 2) or (H, W, 2) + # GINO expects latent_queries WITHOUT batch dim + q_ = sample["query_points"] + if q_.dim() == 4: + # Take first sample's query points (should be same for all) + q_ = q_[0] + # q_ is now (H, W, 2) - NO batch dim + + # Return ONLY the keys needed (GINO inputs + y for loss) + return { + "input_geom": geom_, # (n_in, 2) - NO batch + "latent_queries": q_, # (H, W, 2) - NO batch + "output_queries": geom_.clone(), # (n_out, 2) - NO batch + "x": x_, # (B, n_in, features) - HAS batch + "y": y_, # (B, n_out, 3) - HAS batch + } + + def postprocess(self, out: torch.Tensor, sample: Dict) -> Tuple[torch.Tensor, Dict]: + r""" + Postprocess model output. + + Parameters + ---------- + out : torch.Tensor + Model output tensor of shape :math:`(B, n_{out}, C_{out})`. + sample : Dict + Sample dictionary. + + Returns + ------- + Tuple[torch.Tensor, Dict] + Tuple of (postprocessed output, sample). + """ + ### Input validation + # Skip validation when running under torch.compile for performance + if not torch.compiler.is_compiling(): + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Expected out to be torch.Tensor, got {type(out)}" + ) + if out.ndim < 2: + raise ValueError( + f"Expected out to be at least 2D tensor (B, n_out, C_out), " + f"got {out.ndim}D tensor with shape {tuple(out.shape)}" + ) + if not isinstance(sample, dict): + raise ValueError( + f"Expected sample to be dict, got {type(sample)}" + ) + + if (not self.training) and self.inverse_test and (self.target_norm is not None): + out = self.target_norm.inverse_transform(out) + if sample.get("y") is not None: + sample["y"] = self.target_norm.inverse_transform(sample["y"]) + return out, sample + + def to(self, device: Union[str, torch.device]): + r""" + Move processor to device. + + Parameters + ---------- + device : str or torch.device + Target device (string like "cuda" or "cpu", or torch.device object). + + Returns + ------- + FloodGINODataProcessor + Self for method chaining. + + Raises + ------ + TypeError + If ``device`` is not a string or torch.device. + """ + # Accept both string and torch.device objects - preserve original type + if not isinstance(device, (str, torch.device)): + raise TypeError(f"device must be a string or torch.device, got {type(device)}") + + # Update device string reference + self._device_str = str(device) if isinstance(device, torch.device) else device + # Move module to device using parent class method (physicsnemo.Module has read-only device property) + super().to(device) + if self.target_norm is not None: + self.target_norm = self.target_norm.to(device) + return self + + def wrap(self, model: nn.Module): + r""" + Wrap model with this processor. + + Parameters + ---------- + model : nn.Module + Model to wrap. + + Returns + ------- + FloodGINODataProcessor + Self for method chaining. + + Raises + ------ + ValueError + If ``model`` is None. + """ + if model is None: + raise ValueError("model cannot be None") + self.model = model + return self + + def forward(self, **data_dict) -> Tuple[torch.Tensor, Dict]: + r""" + Forward pass through processor and model. + + Parameters + ---------- + **data_dict + Input data dictionary. + + Returns + ------- + Tuple[torch.Tensor, Dict] + Tuple of (output tensor, processed data dictionary). + + Raises + ------ + RuntimeError + If no model is attached. + """ + data_dict = self.preprocess(data_dict) + if self.model is None: + raise RuntimeError("No model attached. Call wrap(model).") + + model_input = { + "input_geom": data_dict["input_geom"], + "latent_queries": data_dict["latent_queries"], + "output_queries": data_dict["output_queries"], + "x": data_dict["x"], + } + out = self.model(**model_input) + + out, data_dict = self.postprocess(out, data_dict) + return out, data_dict diff --git a/examples/weather/flood_modeling/flood_forecaster/datasets/__init__.py b/examples/weather/flood_modeling/flood_forecaster/datasets/__init__.py new file mode 100644 index 0000000000..8ae57e1a90 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/datasets/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Dataset classes for flood prediction.""" + +from .flood_dataset import FloodDatasetWithQueryPoints +from .normalized_dataset import NormalizedDataset, NormalizedRolloutTestDataset +from .rollout_dataset import FloodRolloutTestDatasetNew + +__all__ = [ + "FloodDatasetWithQueryPoints", + "FloodRolloutTestDatasetNew", + "NormalizedDataset", + "NormalizedRolloutTestDataset", +] + diff --git a/examples/weather/flood_modeling/flood_forecaster/datasets/flood_dataset.py b/examples/weather/flood_modeling/flood_forecaster/datasets/flood_dataset.py new file mode 100644 index 0000000000..82d30b32ac --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/datasets/flood_dataset.py @@ -0,0 +1,362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Training dataset for flood prediction with query points. +""" + +import math +import warnings +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + + +class FloodDatasetWithQueryPoints(Dataset): + r""" + Dataset for training/one-step testing with channel order [WD, VX, VY]. + + Ensures dynamic history channels are [WD=0, VX=1, VY=2]. + + Parameters + ---------- + data_root : str or Path + Root directory containing data files. + n_history : int + Number of history timesteps. + query_res : List[int], optional, default=[64, 64] + Query resolution. + xy_file : str, optional + Filename for XY coordinates. + static_files : List[str], optional + List of static feature filenames. + dynamic_patterns : Dict[str, str], optional + Dict mapping variable names to filename patterns. + boundary_patterns : Dict[str, str], optional + Dict mapping boundary names to filename patterns. + raise_on_smaller : bool, optional, default=True + Whether to raise error if data is smaller than expected. + skip_before_timestep : int, optional, default=0 + Number of timesteps to skip at the beginning. + noise_type : str, optional, default="none" + Type of noise to apply ("none", "only_last", "correlated", "uncorrelated", "random_walk"). + noise_std : List[float], optional + List of 3 floats for noise std for [WD, VX, VY]. + + Raises + ------ + FileNotFoundError + If data root or required files are not found. + ValueError + If noise_std length is not 3 or if no valid run IDs are found. + """ + + def __init__( + self, + data_root, + n_history, + query_res=None, + xy_file=None, + static_files=None, + dynamic_patterns=None, + boundary_patterns=None, + raise_on_smaller=True, + skip_before_timestep=0, + noise_type="none", + noise_std=None, + ): + r""" + Initialize flood dataset with query points. + """ + super().__init__() + self.data_root = Path(data_root) + if not self.data_root.exists(): + raise FileNotFoundError(f"Data root not found: {self.data_root}") + self.n_history = n_history + self.query_res = query_res if query_res else [64, 64] + self.xy_file = xy_file + self.static_files = static_files if static_files else [] + self.dynamic_patterns = dynamic_patterns if dynamic_patterns else {} + self.boundary_patterns = boundary_patterns if boundary_patterns else {} + self.raise_on_smaller = raise_on_smaller + self.skip_before_timestep = skip_before_timestep + + # NOISE PARAMS + if noise_std is None or (isinstance(noise_std, (list, tuple)) and len(noise_std) == 0): + self.noise_type = "none" + self.noise_std = [0.0, 0.0, 0.0] + else: + self.noise_type = noise_type.lower() if noise_type else "none" + self.noise_std = noise_std + if len(self.noise_std) != 3: + raise ValueError("noise_std must be a list of exactly 3 floats for WD, VX, VY.") + + # Read run IDs from train.txt + train_txt = self.data_root / "train.txt" + if not train_txt.exists(): + raise FileNotFoundError(f"Expected train.txt at {train_txt}, not found!") + try: + with open(train_txt, "r", encoding="utf-8") as f: + lines = [ln.strip() for ln in f if ln.strip()] + except IOError as e: + raise IOError(f"Failed to read train.txt from {train_txt}: {e}") from e + if len(lines) == 1 and "," in lines[0]: + self.run_ids = lines[0].split(",") + else: + self.run_ids = lines + self.run_ids = [rid.strip() for rid in self.run_ids if rid.strip()] + if not self.run_ids: + raise ValueError("No valid run IDs found in train.txt") + + # Internals + self.xy_coords = None + self.static_data = None + self.dynamic_data = {} + self.boundary_data = {} + self.sample_index = [] + + # Load data + self.reference_cell_count = self._load_xy_file() + self._load_static() + self._load_all_runs() + self._build_sample_indices() + + def _load_xy_file(self): + r"""Load and normalize XY coordinates.""" + if not self.xy_file: + raise ValueError("xy_file was not provided! Please specify in config.") + xy_path = self.data_root / self.xy_file + if not xy_path.exists(): + raise FileNotFoundError(f"Reference XY file not found: {xy_path}") + + # Load raw coordinates + xy_arr = np.loadtxt(str(xy_path), delimiter="\t", dtype=np.float32) + if xy_arr.ndim != 2 or xy_arr.shape[1] != 2: + raise ValueError(f"{self.xy_file} must be shape (num_cells,2). Got {xy_arr.shape}.") + + # Unit-box normalization + min_xy = xy_arr.min(axis=0) + max_xy = xy_arr.max(axis=0) + range_xy = max_xy - min_xy + range_xy[range_xy == 0] = 1.0 + xy_arr = (xy_arr - min_xy) / range_xy + + self.xy_coords = torch.tensor(xy_arr, device='cpu') + return self.xy_coords.shape[0] + + def _load_static(self): + r"""Load static feature files.""" + if not self.static_files: + self.static_data = torch.zeros((self.reference_cell_count, 0), device='cpu') + return + static_list = [] + for fname in self.static_files: + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Static file not found: {fpath}, skipping.") + continue + arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + if arr.ndim == 1: + arr = arr[:, None] + n_file = arr.shape[0] + if n_file < self.reference_cell_count: + msg = f"Static {fname} has {n_file} < {self.reference_cell_count}" + if self.raise_on_smaller: + raise ValueError(msg) + else: + warnings.warn(msg + " -> skipping.") + continue + elif n_file > self.reference_cell_count: + arr = arr[:self.reference_cell_count, :] + static_list.append(arr) + if not static_list: + self.static_data = torch.zeros((self.reference_cell_count, 0), device='cpu') + return + combined_arr = np.concatenate(static_list, axis=1) + self.static_data = torch.tensor(combined_arr, device='cpu') + + def _load_all_runs(self): + r"""Load dynamic and boundary data for all runs.""" + for run_id in tqdm(self.run_ids, desc="Loading runs for training"): + self.dynamic_data[run_id] = {} + self.boundary_data[run_id] = {} + for dkey, pattern in self.dynamic_patterns.items(): + fname = pattern.format(run_id) + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Dynamic file not found: {fpath}, skipping {dkey}.") + continue + arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + N_file = arr.shape[1] + if N_file < self.reference_cell_count: + msg = f"{fname} has {N_file} < {self.reference_cell_count}" + if self.raise_on_smaller: + raise ValueError(msg) + else: + warnings.warn(msg + " -> skipping.") + continue + elif N_file > self.reference_cell_count: + arr = arr[:, :self.reference_cell_count] + self.dynamic_data[run_id][dkey] = torch.tensor(arr, device='cpu') + # boundary + for bc_key, bc_pattern in self.boundary_patterns.items(): + fname = bc_pattern.format(run_id) + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Boundary file not found: {fpath}, skipping {bc_key}.") + continue + bc_arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + if bc_arr.ndim == 1: + bc_arr = bc_arr[:, None] + if bc_arr.shape[1] == 2: + bc_arr = bc_arr[:, 1].reshape(-1, 1) + bc_tensor = torch.tensor(bc_arr, device='cpu') + bc_tensor = bc_tensor.expand(-1, self.reference_cell_count) + bc_tensor = bc_tensor.unsqueeze(-1) + self.boundary_data[run_id][bc_key] = bc_tensor + + def _build_sample_indices(self): + r"""Build list of (run_id, timestep) indices for valid samples.""" + for run_id in self.run_ids: + dyn_dict = self.dynamic_data[run_id] + bc_dict = self.boundary_data[run_id] + if not dyn_dict or not bc_dict: + continue + required_dyn = ["WD", "VX", "VY"] + required_bc = ["inflow"] + if not all(k in dyn_dict for k in required_dyn) or not all(k in bc_dict for k in required_bc): + warnings.warn(f"Run ID {run_id} missing required dynamic/boundary variables. Skipping.") + continue + ref_tensor = dyn_dict["WD"] + T = ref_tensor.shape[0] + start_t = max(self.n_history, self.skip_before_timestep) + for t in range(start_t, T): + self.sample_index.append((run_id, t)) + + def __len__(self): + return len(self.sample_index) + + def _apply_noise(self, dynamic_hist: torch.Tensor): + r""" + Apply noise to dynamic history tensor. + + Parameters + ---------- + dynamic_hist : torch.Tensor + Dynamic history tensor of shape :math:`(H, n_{cells}, 3)` with channels + [WD=0, VX=1, VY=2] where :math:`H` is history length and :math:`n_{cells}` + is the number of cells. + + Returns + ------- + torch.Tensor + Noisy dynamic history tensor of same shape as input. + """ + if self.noise_type == "none" or all(s <= 0.0 for s in self.noise_std): + return dynamic_hist + + n_history, num_cells, d = dynamic_hist.shape + device = dynamic_hist.device + H = n_history + + def make_noise_for_all_steps(n_steps: int, n_cells: int): + base = torch.randn((n_steps, n_cells, d), device=device) + std_tensor = torch.tensor(self.noise_std, device=device).view(1, 1, d) + return base * std_tensor + + if self.noise_type == "only_last": + step_noise = make_noise_for_all_steps(1, num_cells)[0] + dynamic_hist[-1] += step_noise + + elif self.noise_type == "correlated": + single_noise = make_noise_for_all_steps(1, num_cells)[0] + for t in range(n_history): + dynamic_hist[t] += single_noise + + elif self.noise_type == "uncorrelated": + noise_ = make_noise_for_all_steps(n_history, num_cells) + dynamic_hist += noise_ + + elif self.noise_type == "random_walk": + step_sigma = [s / math.sqrt(H) for s in self.noise_std] + step_sigma_t = torch.tensor(step_sigma, device=device).view(1, 1, 3) + offset = torch.zeros((num_cells, 3), device=device) + for t in range(n_history): + # Broadcast step_sigma_t across channels: (1, 1, 3) -> (3,) + step_n = torch.randn((num_cells, 3), device=device) * step_sigma_t.squeeze() + offset += step_n + dynamic_hist[t] += offset + + else: + warnings.warn(f"Unknown noise_type={self.noise_type}, skipping noise.") + return dynamic_hist + + def __getitem__(self, idx): + r"""Get a single sample.""" + run_id, target_t = self.sample_index[idx] + in_geom = self.xy_coords + static_feats = self.static_data + dyn_dict = self.dynamic_data[run_id] + bc_dict = self.boundary_data[run_id] + t0 = target_t - self.n_history + num_cells = in_geom.shape[0] + + # Build dynamic history with order [WD, VX, VY] + wanted_order = ["WD", "VX", "VY"] + hist_list = [] + for dkey in wanted_order: + arr_slice = dyn_dict[dkey][t0:target_t, :] + hist_list.append(arr_slice.unsqueeze(-1)) + dynamic_hist = torch.cat(hist_list, dim=-1) + + # Noise injection + dynamic_hist = self._apply_noise(dynamic_hist) + + # Boundary condition history + boundary_vars = sorted(bc_dict.keys()) + bc_list = [] + for bck in boundary_vars: + bc_slice = bc_dict[bck][t0:target_t, :] + bc_list.append(bc_slice) + if bc_list: + bc_hist = torch.cat(bc_list, dim=-1) + else: + bc_hist = torch.zeros((self.n_history, num_cells, 1)) + + # Build target [WD, VX, VY] + def safe_get(k): + if k not in dyn_dict: + return torch.zeros((num_cells,)) + return dyn_dict[k][target_t, :] + + wd = safe_get("WD").unsqueeze(-1) + vx = safe_get("VX").unsqueeze(-1) + vy = safe_get("VY").unsqueeze(-1) + target_all = torch.cat([wd, vx, vy], dim=-1) + + return { + "geometry": in_geom, + "static": static_feats, + "boundary": bc_hist, + "dynamic": dynamic_hist, + "target": target_all, + "run_id": run_id, + "time_index": target_t, + } + diff --git a/examples/weather/flood_modeling/flood_forecaster/datasets/normalized_dataset.py b/examples/weather/flood_modeling/flood_forecaster/datasets/normalized_dataset.py new file mode 100644 index 0000000000..b10d6e9b5b --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/datasets/normalized_dataset.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Normalized dataset wrappers for flood prediction. +""" + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class NormalizedDataset(Dataset): + r""" + Dataset wrapper that provides normalized data with query points. + """ + + def __init__(self, geometry, static, boundary, dynamic, target=None, query_res=None, cell_area=None): + r""" + Initialize normalized dataset. + + Parameters + ---------- + geometry : torch.Tensor + Normalized geometry tensor of shape :math:`(N, n_{cells}, 2)`. + static : torch.Tensor + Normalized static features tensor of shape :math:`(N, n_{cells}, C_{static})`. + boundary : torch.Tensor + Normalized boundary conditions tensor of shape :math:`(N, H, n_{cells}, C_{boundary})`. + dynamic : torch.Tensor + Normalized dynamic features tensor of shape :math:`(N, H, n_{cells}, C_{dynamic})`. + target : torch.Tensor, optional + Normalized target tensor of shape :math:`(N, n_{cells}, C_{target})`. + query_res : List[int], optional, default=[64, 64] + Query resolution [height, width]. If None, defaults to [64, 64]. + cell_area : torch.Tensor, optional + Cell area tensor of shape :math:`(N, n_{cells})`. + """ + self.geometry = geometry + self.static = static + self.boundary = boundary + self.dynamic = dynamic + self.target = target + # Use immutable default to avoid mutable default argument issues + self.query_res = query_res if query_res is not None else [64, 64] + self.cell_area = cell_area + + if self.geometry is not None and self.geometry.shape[0] > 0: + geom_sample = self.geometry[0].cpu().numpy() + x_vals = geom_sample[:, 0] + y_vals = geom_sample[:, 1] + min_x, max_x = x_vals.min(), x_vals.max() + min_y, max_y = y_vals.min(), y_vals.max() + tx = np.linspace(min_x, max_x, self.query_res[0], dtype=np.float32) + ty = np.linspace(min_y, max_y, self.query_res[1], dtype=np.float32) + grid_x, grid_y = np.meshgrid(tx, ty, indexing="ij") + q_pts = np.stack([grid_x, grid_y], axis=-1) + self.query_points = torch.tensor(q_pts, device='cpu') + else: + self.query_points = torch.zeros((self.query_res[0], self.query_res[1], 2), dtype=torch.float32) + + def __len__(self): + return self.geometry.shape[0] if self.geometry is not None else 0 + + def __getitem__(self, idx): + r"""Get a single normalized sample.""" + sample = { + "geometry": self.geometry[idx], + "static": self.static[idx], + "boundary": self.boundary[idx], + "dynamic": self.dynamic[idx], + "query_points": self.query_points + } + if self.target is not None: + sample["target"] = self.target[idx] + if self.cell_area is not None: + sample["cell_area"] = self.cell_area[idx] + return sample + + +class NormalizedRolloutTestDataset(Dataset): + r""" + Dataset wrapper for normalized rollout test samples. + """ + + def __init__(self, normalized_samples, query_res=None): + r""" + Initialize normalized rollout test dataset. + + Parameters + ---------- + normalized_samples : List[Dict] + List of normalized sample dictionaries. + query_res : List[int], optional, default=[64, 64] + Query resolution [height, width]. If None, defaults to [64, 64]. + """ + self.normalized_samples = normalized_samples + # Use immutable default to avoid mutable default argument issues + self.query_res = query_res if query_res is not None else [64, 64] + + if len(self.normalized_samples) > 0: + geom_sample = self.normalized_samples[0]["geometry"].cpu().numpy() + x_vals = geom_sample[:, 0] + y_vals = geom_sample[:, 1] + min_x, max_x = x_vals.min(), x_vals.max() + min_y, max_y = y_vals.min(), y_vals.max() + tx = np.linspace(min_x, max_x, self.query_res[0], dtype=np.float32) + ty = np.linspace(min_y, max_y, self.query_res[1], dtype=np.float32) + grid_x, grid_y = np.meshgrid(tx, ty, indexing="ij") + q_pts = np.stack([grid_x, grid_y], axis=-1) + self.query_points = torch.tensor(q_pts, device='cpu') + else: + self.query_points = torch.zeros((self.query_res[0], self.query_res[1], 2), dtype=torch.float32) + + def __len__(self): + return len(self.normalized_samples) + + def __getitem__(self, idx): + r"""Get a single normalized rollout sample.""" + sample = self.normalized_samples[idx].copy() + sample["query_points"] = self.query_points + return sample + diff --git a/examples/weather/flood_modeling/flood_forecaster/datasets/rollout_dataset.py b/examples/weather/flood_modeling/flood_forecaster/datasets/rollout_dataset.py new file mode 100644 index 0000000000..3547fd5686 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/datasets/rollout_dataset.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Rollout test dataset for flood prediction. +""" + +import warnings +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import Dataset +from tqdm import tqdm + + +class FloodRolloutTestDatasetNew(Dataset): + r""" + Dataset for rollout evaluation with channel order [WD, VX, VY]. + """ + + def __init__( + self, + rollout_data_root, + n_history, + rollout_length, + xy_file=None, + query_res=None, + static_files=None, + dynamic_patterns=None, + boundary_patterns=None, + raise_on_smaller=True, + skip_before_timestep=0 + ): + r""" + Initialize rollout test dataset. + + Parameters + ---------- + rollout_data_root : str or Path + Root directory containing rollout test data. + n_history : int + Number of history timesteps. + rollout_length : int + Length of rollout to evaluate. + xy_file : str, optional + Filename for XY coordinates. + query_res : List[int], optional, default=[64, 64] + Query resolution. + static_files : List[str], optional + List of static feature filenames. + dynamic_patterns : Dict[str, str], optional + Dict mapping variable names to filename patterns. + boundary_patterns : Dict[str, str], optional + Dict mapping boundary names to filename patterns. + raise_on_smaller : bool, optional, default=True + Whether to raise error if data is smaller than expected. + skip_before_timestep : int, optional, default=0 + Number of timesteps to skip at the beginning. + + Raises + ------ + FileNotFoundError + If data root or required files are not found. + ValueError + If no valid run IDs are found. + """ + super().__init__() + self.data_root = Path(rollout_data_root) + if not self.data_root.exists(): + raise FileNotFoundError(f"Rollout data root not found: {self.data_root}") + self.n_history = n_history + self.rollout_length = rollout_length + self.xy_file = xy_file + self.query_res = query_res if query_res else [64, 64] + self.static_files = static_files if static_files else [] + self.dynamic_patterns = dynamic_patterns if dynamic_patterns else {} + self.boundary_patterns = boundary_patterns if boundary_patterns else {} + self.raise_on_smaller = raise_on_smaller + self.skip_before_timestep = skip_before_timestep + + # Read run IDs from test.txt + test_txt = self.data_root / "test.txt" + if not test_txt.exists(): + raise FileNotFoundError(f"Expected test.txt at {test_txt}, not found!") + try: + with open(test_txt, "r", encoding="utf-8") as f: + lines = [ln.strip() for ln in f if ln.strip()] + except IOError as e: + raise IOError(f"Failed to read test.txt from {test_txt}: {e}") from e + if len(lines) == 1 and "," in lines[0]: + self.run_ids = lines[0].split(",") + else: + self.run_ids = lines + self.run_ids = [rid.strip() for rid in self.run_ids if rid.strip()] + if not self.run_ids: + raise ValueError("No valid run IDs found in test.txt") + + self.xy_coords = None + self.static_data = None + self.cell_area = None # For volume conservation + self.dynamic_data = {} + self.boundary_data = {} + + self.reference_cell_count = self._load_xy_file() + self._load_static() + self._load_all_runs() + + # Filter out runs that lack enough time steps + self.valid_run_ids = [] + for run_id in self.run_ids: + missing_vars = [var for var in ["WD", "VX", "VY"] if var not in self.dynamic_data[run_id]] + missing_bc = [var for var in ["inflow"] if var not in self.boundary_data[run_id]] + if missing_vars or missing_bc: + warnings.warn(f"Run ID {run_id} missing variables: {missing_vars}, bc: {missing_bc}. Skipping.") + continue + T = self.dynamic_data[run_id]["WD"].shape[0] + if T >= self.skip_before_timestep + self.n_history + self.rollout_length: + self.valid_run_ids.append(run_id) + + if not self.valid_run_ids: + raise ValueError("No hydrographs have enough time steps for rollout evaluation.") + + # Store the geometry/static from the first valid sample + # Safe to call __getitem__ here because all data has been loaded in _load_all_runs() + # and valid_run_ids has been populated, ensuring the dataset is fully initialized + if len(self.valid_run_ids) > 0: + sample0 = self.__getitem__(0) + self.geometry = sample0["geometry"] + self.static = sample0["static"] + else: + # This should never happen due to the check above, but defensive programming + raise RuntimeError("Cannot initialize geometry/static: no valid runs available") + + def _load_xy_file(self): + r"""Load and normalize XY coordinates.""" + if not self.xy_file: + raise ValueError("xy_file was not provided for rollout dataset! Please specify in config.") + xy_path = self.data_root / self.xy_file + if not xy_path.exists(): + raise FileNotFoundError(f"Rollout XY file not found: {xy_path}") + + # Load raw coordinates + xy_arr = np.loadtxt(str(xy_path), delimiter="\t", dtype=np.float32) + if xy_arr.ndim != 2 or xy_arr.shape[1] != 2: + raise ValueError(f"{self.xy_file} must be shape (num_cells,2). Got {xy_arr.shape}.") + + # Unit-box normalization + min_xy = xy_arr.min(axis=0) + max_xy = xy_arr.max(axis=0) + range_xy = max_xy - min_xy + range_xy[range_xy == 0] = 1.0 + xy_arr = (xy_arr - min_xy) / range_xy + + self.xy_coords = torch.tensor(xy_arr, device='cpu') + return self.xy_coords.shape[0] + + def _load_static(self): + r"""Load static feature files, including cell area.""" + if not self.static_files: + self.static_data = torch.zeros((self.reference_cell_count, 0), device='cpu') + return + static_list = [] + for fname in self.static_files: + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Static file not found in rollout folder: {fpath}, skipping.") + continue + arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + + if arr.ndim == 1: + arr = arr[:, None] + n_file = arr.shape[0] + if n_file < self.reference_cell_count: + msg = f"Static {fname} has {n_file} < {self.reference_cell_count}" + if self.raise_on_smaller: + raise ValueError(msg) + else: + warnings.warn(msg + " -> skipping.") + continue + elif n_file > self.reference_cell_count: + arr = arr[:self.reference_cell_count, :] + + # Capture cell area AFTER trimming to match reference cell count + if "M40_CA.txt" in str(fname): + self.cell_area = torch.from_numpy(arr.flatten()).float() + + static_list.append(arr) + + if self.cell_area is None: + warnings.warn( + "Cell Area file ('M40_CA.txt') not found in static_files. Volume conservation cannot be calculated.") + + if not static_list: + self.static_data = torch.zeros((self.reference_cell_count, 0), device='cpu') + return + + combined_arr = np.concatenate(static_list, axis=1) + self.static_data = torch.tensor(combined_arr, device='cpu') + + def _load_all_runs(self): + r"""Load dynamic and boundary data for all runs.""" + for run_id in tqdm(self.run_ids, desc="Loading runs for rollout evaluation"): + self.dynamic_data[run_id] = {} + self.boundary_data[run_id] = {} + # dynamic + for dkey, pattern in self.dynamic_patterns.items(): + fname = pattern.format(run_id) + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Dynamic file not found: {fpath}, skipping {dkey}.") + continue + arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + N_file = arr.shape[1] + if N_file < self.reference_cell_count: + msg = f"{fname} has {N_file} < {self.reference_cell_count}" + if self.raise_on_smaller: + raise ValueError(msg) + else: + warnings.warn(msg + " -> skipping.") + continue + elif N_file > self.reference_cell_count: + arr = arr[:, :self.reference_cell_count] + self.dynamic_data[run_id][dkey] = torch.tensor(arr, device='cpu') + # boundary + for bc_key, bc_pattern in self.boundary_patterns.items(): + fname = bc_pattern.format(run_id) + fpath = self.data_root / fname + if not fpath.exists(): + warnings.warn(f"Boundary file not found: {fpath}, skipping {bc_key}.") + continue + bc_arr = np.loadtxt(str(fpath), delimiter="\t", dtype=np.float32) + if bc_arr.ndim == 1: + bc_arr = bc_arr[:, None] + if bc_arr.shape[1] == 2: + bc_arr = bc_arr[:, 1].reshape(-1, 1) + bc_tensor = torch.tensor(bc_arr, device='cpu') + bc_tensor = bc_tensor.expand(-1, self.reference_cell_count) + bc_tensor = bc_tensor.unsqueeze(-1) + self.boundary_data[run_id][bc_key] = bc_tensor + + def __len__(self): + return len(self.valid_run_ids) + + def __getitem__(self, idx): + r"""Get a single rollout sample.""" + run_id = self.valid_run_ids[idx] + dynamic_vars = ["WD", "VX", "VY"] + dynamic = torch.stack([self.dynamic_data[run_id][var] for var in dynamic_vars], dim=-1) + boundary_keys = sorted(list(self.boundary_data[run_id].keys())) + if boundary_keys: + boundary = torch.cat([self.boundary_data[run_id][var] for var in boundary_keys], dim=-1) + else: + # Ensure device consistency: xy_coords is on CPU, match dynamic tensor device + boundary_device = dynamic.device if isinstance(dynamic, torch.Tensor) else self.xy_coords.device + boundary = torch.zeros((dynamic.shape[0], self.reference_cell_count, 1), device=boundary_device) + + sample = { + "run_id": run_id, + "dynamic": dynamic, + "boundary": boundary, + "geometry": self.xy_coords, + "static": self.static_data, + } + if self.cell_area is not None: + sample["cell_area"] = self.cell_area + return sample + diff --git a/examples/weather/flood_modeling/flood_forecaster/inference.py b/examples/weather/flood_modeling/flood_forecaster/inference.py new file mode 100644 index 0000000000..158d3c74d6 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/inference.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Inference script for FloodForecaster using trained GINO model. + +This script loads a trained model checkpoint and performs rollout evaluation +on test data, generating visualizations and metrics. +""" + +import sys +from pathlib import Path + +import hydra +import torch +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +from neuralop import get_model + +from physicsnemo.distributed.manager import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.launch.utils.checkpoint import load_checkpoint + +from datasets import FloodRolloutTestDatasetNew, NormalizedRolloutTestDataset +from data_processing import FloodGINODataProcessor, GINOWrapper +from inference.rollout import rollout_prediction +from utils.normalization import collect_all_fields, transform_with_existing_normalizers + + +def log_section(logger: RankZeroLoggingWrapper, title: str, char: str = "=", width: int = 60): + r"""Log a section header for visual separation.""" + separator = char * width + logger.info("") + logger.info(separator) + logger.info(title) + logger.info(separator) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def run_inference(cfg: DictConfig) -> None: + r""" + Run inference using a trained model checkpoint. + + This function loads a trained model and performs rollout evaluation: + 1. Load model from checkpoint + 2. Load and normalize test data + 3. Perform autoregressive rollout + 4. Generate visualizations and metrics + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object. + + Raises + ------ + SystemExit + If critical errors occur during execution. + """ + # Initialize distributed manager (must be called first) + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize logging + log = PythonLogger(name="flood_forecaster_inference") + log_rank_zero = RankZeroLoggingWrapper(log, dist) + + log_section(log_rank_zero, "FLOOD FORECASTER - Inference and Evaluation") + + try: + # Get device from distributed manager or config + device = dist.device if dist.device is not None else cfg.distributed.device + + # Log device information + log_rank_zero.info("=" * 50) + log_rank_zero.info(f"PyTorch version: {torch.__version__}") + log_rank_zero.info(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + log_rank_zero.info(f"CUDA version: {torch.version.cuda}") + log_rank_zero.info(f"GPU device: {torch.cuda.get_device_name(0)}") + log_rank_zero.info(f"Using device: {device}") + log_rank_zero.info(f"Distributed: rank={dist.rank}, world_size={dist.world_size}") + log_rank_zero.info("=" * 50) + + # Check checkpoint path + checkpoint_path = cfg.checkpoint.get("resume_from_adapt") or cfg.checkpoint.get("resume_from_source") + if checkpoint_path is None: + log_rank_zero.error("No checkpoint path specified in config.checkpoint.resume_from_adapt or resume_from_source") + sys.exit(1) + + checkpoint_path = Path(to_absolute_path(checkpoint_path)) + if not checkpoint_path.exists(): + log_rank_zero.error(f"Checkpoint path does not exist: {checkpoint_path}") + sys.exit(1) + + log_rank_zero.info(f"Loading model from checkpoint: {checkpoint_path}") + + # Create model (same as training) + log_rank_zero.info("Creating GINO model...") + # Convert config.model to dict to avoid struct mode issues with neuralop's get_model + # neuralop's get_model tries to pop from config, which doesn't work with struct mode + # It expects config.model to exist, so we wrap it in a new OmegaConf DictConfig + # (not in struct mode) that supports both attribute and dict access + model_config_dict = OmegaConf.to_container(cfg.model, resolve=True) + # Create a wrapper config that neuralop expects: {"model": {...}} + # Extract autoregressive parameter before passing to get_model (GINO doesn't accept it) + autoregressive = model_config_dict.pop("autoregressive", False) + + # Convert to OmegaConf DictConfig (not struct mode) so it supports attribute access + wrapper_config = OmegaConf.create({"model": model_config_dict}) + gino_model = get_model(wrapper_config) + gino_model = gino_model.to(device) + + # Create GINOWrapper first (checkpoints are saved as GINOWrapper in PhysicsNeMo format) + # Enable autoregressive residual connection if specified in config + model = GINOWrapper(gino_model, autoregressive=autoregressive) + model = model.to(device) + + # Load checkpoint into the model + # Support both PhysicsNeMo format (new) and neuralop format (old) for backward compatibility + checkpoint_loaded = False + + # Try PhysicsNeMo format first (new format) + # Checkpoints are saved as GINOWrapper (PhysicsNeMo Module), so we load into the wrapper + try: + metadata_dict = {} + load_checkpoint( + path=str(checkpoint_path), + models=model, # Load into GINOWrapper (PhysicsNeMo Module) + optimizer=None, + scheduler=None, + scaler=None, + epoch=None, # Load latest checkpoint + metadata_dict=metadata_dict, + device=device, + ) + log_rank_zero.info("Loaded checkpoint using PhysicsNeMo format") + checkpoint_loaded = True + except (FileNotFoundError, KeyError, ValueError) as e: + # Fall back to neuralop format (old format) + log_rank_zero.info(f"PhysicsNeMo checkpoint not found, trying neuralop format: {e}") + + # Try neuralop format (old format) if PhysicsNeMo format failed + # For old format, we need to load into the inner gino_model, not the wrapper + if not checkpoint_loaded: + try: + from neuralop.training.training_state import load_training_state + + # Check for neuralop checkpoint files + if (checkpoint_path / "best_model_state_dict.pt").exists(): + save_name = "best_model" + elif (checkpoint_path / "model_state_dict.pt").exists(): + save_name = "model" + else: + log_rank_zero.error(f"No checkpoint found in {checkpoint_path}") + log_rank_zero.error("Tried both PhysicsNeMo and neuralop formats") + sys.exit(1) + + # Load checkpoint using neuralop format into the inner model + # Extract the inner model from GINOWrapper for loading + inner_model = model.model if hasattr(model, 'model') else gino_model + inner_model, _, _, _, _ = load_training_state( + save_dir=checkpoint_path, + save_name=save_name, + model=inner_model, + optimizer=None, + scheduler=None, + ) + log_rank_zero.info(f"Loaded checkpoint using neuralop format: {save_name}") + checkpoint_loaded = True + except Exception as e: + log_rank_zero.error(f"Failed to load checkpoint in both formats: {e}") + sys.exit(1) + + if not checkpoint_loaded: + log_rank_zero.error("Failed to load checkpoint in any format") + sys.exit(1) + + # Load normalizers from checkpoint if available + # Normalizers are saved in both pretrain and adapt folders, so check both locations + normalizers_path = checkpoint_path / "normalizers.pt" + if not normalizers_path.exists(): + # If loading from adapt folder, also check pretrain folder as fallback + # (normalizers are saved in both places and are the same) + if checkpoint_path.name == "adapt": + pretrain_path = checkpoint_path.parent / "pretrain" / "normalizers.pt" + if pretrain_path.exists(): + normalizers_path = pretrain_path + log_rank_zero.info(f"Normalizers not found in adapt folder, checking pretrain folder...") + + if normalizers_path.exists(): + log_rank_zero.info(f"Loading normalizers from: {normalizers_path}") + normalizers = torch.load(normalizers_path, map_location=device) + log_rank_zero.info("Normalizers loaded successfully") + else: + # Fallback: recreate normalizers from source data if not saved + log_rank_zero.info("Normalizers not found in checkpoint or pretrain folder. Recreating from source data...") + from datasets import FloodDatasetWithQueryPoints, NormalizedDataset + from utils.normalization import stack_and_fit_transform + + source_full_dataset = FloodDatasetWithQueryPoints( + data_root=cfg.source_data.root, + n_history=cfg.source_data.n_history, + xy_file=cfg.source_data.get("xy_file", None), + query_res=cfg.source_data.get("query_res", [64, 64]), + static_files=cfg.source_data.get("static_files", []), + dynamic_patterns=cfg.source_data.get("dynamic_patterns", {}), + boundary_patterns=cfg.source_data.get("boundary_patterns", {}), + raise_on_smaller=True, + skip_before_timestep=cfg.source_data.get("skip_before_timestep", 0), + noise_type="none", + noise_std=None, + ) + + # Use a subset to fit normalizers (faster) + from torch.utils.data import random_split + + train_sz = min(100, int(0.9 * len(source_full_dataset))) # Use up to 100 samples + source_train_subset, _ = random_split(source_full_dataset, [train_sz, len(source_full_dataset) - train_sz]) + + geom, static, boundary, dyn, tgt = collect_all_fields(source_train_subset, True) + normalizers, _ = stack_and_fit_transform(geom, static, boundary, dyn, tgt) + log_rank_zero.info("Normalizers recreated from source data") + + # Create data processor + data_processor = FloodGINODataProcessor( + device=device, + target_norm=normalizers.get("target", None), + inverse_test=True, + ) + data_processor.wrap(model) + + # Load rollout test dataset + log_section(log_rank_zero, "Loading Rollout Test Dataset") + rollout_test_dataset = FloodRolloutTestDatasetNew( + rollout_data_root=cfg.rollout_data.root, + n_history=cfg.source_data.n_history, + rollout_length=cfg.source_data.rollout_length, + xy_file=cfg.rollout_data.get("xy_file", None), + query_res=cfg.source_data.get("query_res", [32, 32]), + static_files=cfg.rollout_data.get("static_files", []), + dynamic_patterns=cfg.rollout_data.get("dynamic_patterns", {}), + boundary_patterns=cfg.rollout_data.get("boundary_patterns", {}), + raise_on_smaller=True, + skip_before_timestep=cfg.source_data.get("skip_before_timestep", 0), + ) + log_rank_zero.info(f"Loaded {len(rollout_test_dataset)} rollout test samples") + + # Collect and normalize rollout data + ( + rollout_geom, + rollout_static, + rollout_boundary, + rollout_dyn, + _, + rollout_cell_area, + ) = collect_all_fields(rollout_test_dataset, expect_target=False) + + # Move normalizers to CPU for data transformation + for norm in normalizers.values(): + norm.to("cpu") + + transformed_rollout = transform_with_existing_normalizers( + rollout_geom, rollout_static, rollout_boundary, rollout_dyn, normalizers + ) + + normalized_rollout_samples = [ + { + "run_id": rollout_test_dataset.valid_run_ids[i], + "geometry": transformed_rollout["geometry"][i], + "static": transformed_rollout["static"][i], + "boundary": transformed_rollout["boundary"][i], + "dynamic": transformed_rollout["dynamic"][i], + "cell_area": rollout_cell_area[i], + } + for i in range(len(rollout_test_dataset)) + ] + + # Run rollout prediction + log_section(log_rank_zero, "Running Rollout Prediction") + rollout_prediction( + model=model, + rollout_dataset=NormalizedRolloutTestDataset(normalized_rollout_samples, cfg.source_data.query_res), + rollout_length=cfg.source_data.rollout_length, + history_steps=cfg.source_data.n_history, + dynamic_norm=normalizers["dynamic"], + target_norm=normalizers["target"], + boundary_norm=normalizers["boundary"], + device=device, + skip_before_timestep=cfg.source_data.get("skip_before_timestep", 0), + dt=cfg.source_data.dt, + out_dir=cfg.rollout.out_dir, + logger=log_rank_zero, + ) + + log_section(log_rank_zero, "Inference Complete!") + + except KeyboardInterrupt: + log_rank_zero.warning("Inference interrupted by user") + sys.exit(1) + except Exception as e: + import traceback + log_rank_zero.error(f"Fatal error in inference pipeline: {e}") + log_rank_zero.error(traceback.format_exc()) + raise + + +if __name__ == "__main__": + run_inference() + diff --git a/examples/weather/flood_modeling/flood_forecaster/inference/__init__.py b/examples/weather/flood_modeling/flood_forecaster/inference/__init__.py new file mode 100644 index 0000000000..29b7736751 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/inference/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Inference modules for flood prediction.""" + +from .rollout import rollout_prediction + +__all__ = ["rollout_prediction"] + diff --git a/examples/weather/flood_modeling/flood_forecaster/inference/rollout.py b/examples/weather/flood_modeling/flood_forecaster/inference/rollout.py new file mode 100644 index 0000000000..27fe9e730c --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/inference/rollout.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Rollout prediction and evaluation module. + +Compatible with neuralop 2.0.0 API. +""" + +import os +import time +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from tqdm import tqdm + +from utils.plotting import ( + create_rollout_animation, + generate_publication_maps, + generate_max_value_maps, + generate_combined_analysis_maps, + plot_volume_conservation, + plot_conditional_error_analysis, + plot_aggregated_scalar_metrics, + plot_event_magnitude_analysis, +) + + +def compute_csi(threshold, pred, gt): + r""" + Compute Critical Success Index (CSI) for binary classification. + + Parameters + ---------- + threshold : float + Threshold value for binary classification. + pred : np.ndarray + Predicted values. + gt : np.ndarray + Ground truth values. + + Returns + ------- + float + CSI score. + """ + event_pred, event_gt = pred >= threshold, gt >= threshold + TP = np.sum(event_pred & event_gt) + FP = np.sum(event_pred & (~event_gt)) + FN = np.sum((~event_pred) & event_gt) + return TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 1.0 + + +def rollout_prediction( + model, + rollout_dataset, + rollout_length, + history_steps, + dynamic_norm, + target_norm, + boundary_norm, + device, + skip_before_timestep, + dt, + out_dir="./rollout_gifs", + logger=None, +): + r""" + Performs autoregressive rollout, computing and plotting metrics for both water depth and velocity. + + Compatible with neuralop 2.0.0 API. + + Parameters + ---------- + model : nn.Module + Trained model. + rollout_dataset : Dataset + Dataset for rollout evaluation. + rollout_length : int + Length of rollout to perform. + history_steps : int + Number of history timesteps. + dynamic_norm : UnitGaussianNormalizer + Normalizer for dynamic features. + target_norm : UnitGaussianNormalizer + Normalizer for target. + boundary_norm : UnitGaussianNormalizer + Normalizer for boundary conditions. + device : str or torch.device + Device to run inference on. + skip_before_timestep : int + Number of timesteps to skip at beginning. + dt : float + Time step size in seconds. + out_dir : str, optional, default="./rollout_gifs" + Output directory for results. + logger : Any, optional + Optional logger instance. + """ + if logger is None: + # Fallback to print if no logger provided + def log_info(msg): + print(msg) + + logger = type("Logger", (), {"info": lambda self, msg: log_info(msg)})() + + logger.info(f"Starting rollout prediction on {len(rollout_dataset)} samples") + logger.info(f"Rollout length: {rollout_length}, History steps: {history_steps}") + logger.info(f"Output directory: {out_dir}") + + model = model.to(device) + model.eval() + dynamic_norm = dynamic_norm.to(device) + target_norm = target_norm.to(device) + boundary_norm = boundary_norm.to(device) + + # Initialize lists for aggregated metrics + aggregated_metrics = { + 'rmse_wd': [], 'csi_005': [], 'csi_03': [], 'rmse_vx': [], 'rmse_vy': [], + 'h_V2_rmse': [], 'fhca': [], + 'arrival_mae': [], 'duration_mae': [], + } + + # Initialize lists for event magnitude analysis + event_q_peaks = [] + event_total_volumes = [] + event_avg_rmse_wd = [] + + # Initialize list to store inference times + rollout_inference_times = [] + + for idx, sample in enumerate(tqdm(rollout_dataset, desc="Performing rollout evaluation")): + run_id = sample.get("run_id", f"sample_{idx}") + full_dynamic = sample["dynamic"].to(device) + full_boundary = sample["boundary"].to(device) + geometry = sample["geometry"] + + cell_area = sample.get("cell_area", None) + if cell_area is not None: + cell_area = cell_area.cpu().numpy() + + # Calculate hydrograph characteristics for the current event + unnormalized_boundary = boundary_norm.inverse_transform(full_boundary).squeeze(0) + inflow_hydrograph = unnormalized_boundary[:, 0, 0].cpu().numpy() + q_peak = np.max(inflow_hydrograph) + total_volume = np.sum(inflow_hydrograph) * dt + event_q_peaks.append(q_peak) + event_total_volumes.append(total_volume) + + start_pred_t = skip_before_timestep + history_steps + end_pred_t = start_pred_t + rollout_length + gt_rollout = full_dynamic[start_pred_t:end_pred_t] + gt_boundary_rollout = full_boundary[start_pred_t:end_pred_t] + + wd_pred_list, wd_gt_list = [], [] + vx_pred_list, vy_pred_list = [], [] + vx_gt_list, vy_gt_list = [], [] + run_ts_metrics = {'rmse_wd': [], 'csi_005': [], 'csi_03': [], 'rmse_vx': [], 'rmse_vy': []} + + # Record start time for the rollout + start_time = time.time() + + current_dynamic = full_dynamic[skip_before_timestep:start_pred_t].clone() + current_boundary = full_boundary[skip_before_timestep:start_pred_t].clone() + + for t in range(rollout_length): + # Prepare input tensors + dyn_flat = current_dynamic.permute(1, 0, 2).reshape(1, current_dynamic.shape[1], -1) + bc_flat = current_boundary.permute(1, 0, 2).reshape(1, current_boundary.shape[1], -1) + x = torch.cat([sample["static"].to(device).unsqueeze(0), bc_flat, dyn_flat], dim=2) + + with torch.no_grad(): + # Call model with GINO signature + pred = model( + input_geom=geometry.to(device).unsqueeze(0), + latent_queries=sample["query_points"].to(device).unsqueeze(0), + output_queries=geometry.to(device).unsqueeze(0), + x=x + ) + + # Inverse transform predictions and ground truth + inv_pred = target_norm.inverse_transform(pred) + inv_gt = dynamic_norm.inverse_transform(gt_rollout[t].unsqueeze(0)) + + # Extract water depth and velocity components + wd_pred, vx_pred, vy_pred = [ch.cpu().numpy() for ch in inv_pred[0].T] + wd_gt, vx_gt, vy_gt = [ch.cpu().numpy() for ch in inv_gt[0].T] + + wd_pred_list.append(wd_pred) + wd_gt_list.append(wd_gt) + vx_pred_list.append(vx_pred) + vx_gt_list.append(vx_gt) + vy_pred_list.append(vy_pred) + vy_gt_list.append(vy_gt) + + # Time-step metrics + run_ts_metrics['rmse_wd'].append(np.sqrt(np.mean((wd_pred - wd_gt) ** 2))) + run_ts_metrics['csi_005'].append(compute_csi(0.05, wd_pred, wd_gt)) + run_ts_metrics['csi_03'].append(compute_csi(0.3, wd_pred, wd_gt)) + run_ts_metrics['rmse_vx'].append(np.sqrt(np.mean((vx_pred - vx_gt) ** 2))) + run_ts_metrics['rmse_vy'].append(np.sqrt(np.mean((vy_pred - vy_gt) ** 2))) + + # Update current dynamic state with prediction + current_dynamic = torch.cat([current_dynamic[1:], pred.squeeze(0).unsqueeze(0)], dim=0) + current_boundary = torch.cat([current_boundary[1:], gt_boundary_rollout[t].unsqueeze(0)], dim=0) + + # Record end time and append the duration + end_time = time.time() + rollout_inference_times.append(end_time - start_time) + + # Convert lists to numpy arrays for this run + wd_pred_arr, wd_gt_arr = np.stack(wd_pred_list), np.stack(wd_gt_list) + vx_pred_arr, vy_pred_arr = np.stack(vx_pred_list), np.stack(vy_pred_list) + vx_gt_arr, vy_gt_arr = np.stack(vx_gt_list), np.stack(vy_gt_list) + + # Store the overall error for this event + avg_rmse_for_run = np.mean(run_ts_metrics['rmse_wd']) + event_avg_rmse_wd.append(avg_rmse_for_run) + + # Append run-averaged metrics to aggregated lists + for key in ['rmse_wd', 'csi_005', 'csi_03', 'rmse_vx', 'rmse_vy']: + aggregated_metrics[key].append(np.array(run_ts_metrics[key])) + + figures_path = os.path.join(out_dir, "figures_final") + os.makedirs(figures_path, exist_ok=True) + + # Generate Plots and Scalar Metrics for this Run + logger.info(f"Generating plots for run {run_id}...") + generate_publication_maps( + geometry, + wd_gt_arr, + wd_pred_arr, + vx_gt_arr, + vy_gt_arr, + vx_pred_arr, + vy_pred_arr, + [12, 24, 36, 48, 60, 72], + figures_path, + run_id, + ) + logger.info(f"Saved publication maps for run {run_id}") + generate_max_value_maps( + geometry, + wd_gt_arr, + wd_pred_arr, + vx_gt_arr, + vy_gt_arr, + vx_pred_arr, + vy_pred_arr, + figures_path, + run_id, + ) + logger.info(f"Saved max value maps for run {run_id}") + + mae_arrival, mae_duration, rmse_hv2, fhca = generate_combined_analysis_maps( + geometry, + wd_gt_arr, + wd_pred_arr, + vx_gt_arr, + vy_gt_arr, + vx_pred_arr, + vy_pred_arr, + dt, + figures_path, + run_id, + ) + logger.info(f"Saved combined analysis plot for run {run_id}") + aggregated_metrics["arrival_mae"].append(mae_arrival) + aggregated_metrics["duration_mae"].append(mae_duration) + aggregated_metrics["h_V2_rmse"].append(rmse_hv2) + aggregated_metrics["fhca"].append(fhca) + + # Generate Volume Conservation Plot + plot_volume_conservation(wd_gt_arr, wd_pred_arr, cell_area, dt, figures_path, run_id) + logger.info(f"Saved volume conservation plot for run {run_id}") + + # Generate Conditional Error Plot + plot_conditional_error_analysis( + wd_gt_arr, + wd_pred_arr, + vx_gt_arr, + vy_gt_arr, + vx_pred_arr, + vy_pred_arr, + figures_path, + run_id, + ) + logger.info(f"Saved conditional error plot for run {run_id}") + + # Generate Rollout Animation (GIF) + create_rollout_animation( + geometry, + wd_gt_arr, + wd_pred_arr, + vx_gt_arr, + vy_gt_arr, + vx_pred_arr, + vy_pred_arr, + run_id=run_id, + out_dir=out_dir, + filename_prefix="rollout", + dt_seconds=dt, + ) + logger.info(f"Saved rollout animation for run {run_id}") + + if aggregated_metrics['rmse_wd']: + # Process and plot aggregated metrics across all runs + ts_metrics = { + k: np.stack(v) for k, v in aggregated_metrics.items() + if k in ['rmse_wd', 'csi_005', 'csi_03', 'rmse_vx', 'rmse_vy'] + } + ts_stats = {key: {'mean': arr.mean(axis=0), 'std': arr.std(axis=0)} for key, arr in ts_metrics.items()} + time_hours = (np.arange(1, rollout_length + 1) * dt) / 3600.0 + + fig, axs = plt.subplots(3, 2, figsize=(16, 18), tight_layout=True) + axs = axs.flatten() + fig.suptitle("Time-Series Metrics During Rollout", fontsize=18) + plot_info = { + 0: ('rmse_wd', 'RMSE (Depth)', 'RMSE (m)'), + 1: ('rmse_vx', r'RMSE ($V_{x}$)', 'RMSE (m/s)'), + 2: ('rmse_vy', r'RMSE ($V_{y}$)', 'RMSE (m/s)'), + 3: ('csi_005', 'CSI (0.05m)', 'CSI'), + 4: ('csi_03', 'CSI (0.3m)', 'CSI') + } + for i, ax in enumerate(axs): + if i in plot_info: + key, title, ylabel = plot_info[i] + mean, std = ts_stats[key]['mean'], ts_stats[key]['std'] + ax.plot(time_hours, mean, label=f'{title} Mean', marker='o', markersize=4) + ax.fill_between(time_hours, mean - std, mean + std, alpha=0.3, label='+/-1 Std Dev') + ax.set_title(title) + ax.set_xlabel("Time (hour)") + ax.set_ylabel(ylabel) + ax.legend() + ax.grid(True, linestyle='--') + else: + ax.set_visible(False) + plt.savefig(os.path.join(out_dir, "rollout_metrics_summary.png")) + plt.close(fig) + logger.info(f"Saved aggregated rollout metrics plot to: {os.path.join(out_dir, 'rollout_metrics_summary.png')}") + + # Scalar metrics plotting + scalar_metrics_for_plot = { + 'h_V2_rmse': aggregated_metrics['h_V2_rmse'], + 'fhca': aggregated_metrics['fhca'], + 'arrival_mae_hrs': np.array(aggregated_metrics['arrival_mae']) / 3600.0, + 'duration_mae_hrs': np.array(aggregated_metrics['duration_mae']) / 3600.0, + } + plot_aggregated_scalar_metrics(scalar_metrics_for_plot, out_dir) + logger.info("Saved aggregated scalar metrics plots") + + # Call the event magnitude analysis plotting function + plot_event_magnitude_analysis( + q_peaks=event_q_peaks, + total_volumes=event_total_volumes, + avg_rmses_wd=event_avg_rmse_wd, + out_dir=out_dir, + ) + logger.info("Saved event magnitude analysis plots") + + # Calculate and log timing statistics + if rollout_inference_times: + mean_inference_time = np.mean(rollout_inference_times) + std_inference_time = np.std(rollout_inference_times) + logger.info("=" * 60) + logger.info("Inference Time Summary") + logger.info("=" * 60) + logger.info(f"Time per full rollout (averaged over {len(rollout_inference_times)} hydrographs):") + logger.info(f" Mean: {mean_inference_time:.4f} seconds") + logger.info(f" Std Dev: {std_inference_time:.4f} seconds") + + # Final summary and data saving + logger.info("") + logger.info("=" * 60) + logger.info("Aggregated Rollout Metrics Summary") + logger.info("=" * 60) + scalar_stats = { + key: {'mean': np.nanmean(v), 'std': np.nanstd(v)} + for key, v in scalar_metrics_for_plot.items() + } + logger.info("Scalar Hydrological Metrics (averaged over all runs):") + for key, stat in scalar_stats.items(): + logger.info(f" {key:<20}: Mean={stat['mean']:.4f}, Std={stat['std']:.4f}") + + npz_data = {'time_hours': time_hours} + for key, stat_dict in ts_stats.items(): + npz_data[f'{key}_mean'] = stat_dict['mean'] + npz_data[f'{key}_std'] = stat_dict['std'] + for key, data in scalar_metrics_for_plot.items(): + npz_data[f'{key}_all_runs'] = np.array(data) + + # Add timings to saved data + if rollout_inference_times: + npz_data['rollout_inference_times'] = np.array(rollout_inference_times) + + metrics_file = os.path.join(out_dir, "rollout_metrics_data.npz") + np.savez(metrics_file, **npz_data) + logger.info(f"Saved all aggregated rollout metrics data to: {metrics_file}") diff --git a/examples/weather/flood_modeling/flood_forecaster/requirements.txt b/examples/weather/flood_modeling/flood_forecaster/requirements.txt new file mode 100644 index 0000000000..06de0fd6c1 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/requirements.txt @@ -0,0 +1,11 @@ +hydra-core>=1.2.0 +neuralop>=2.0.0 +wandb>=0.12.0 +matplotlib>=3.5.0 +tqdm>=4.62.0 +numpy>=1.21.0 +torch>=2.0.0 +omegaconf>=2.3.0 +pandas>=1.3.0 +h5py>=3.0.0 + diff --git a/examples/weather/flood_modeling/flood_forecaster/train.py b/examples/weather/flood_modeling/flood_forecaster/train.py new file mode 100644 index 0000000000..6a412d510c --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/train.py @@ -0,0 +1,444 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Training script for FloodForecaster using GINO with domain adaptation. + +This script implements a two-stage training pipeline: +1. Pretraining on source domain +2. Domain adaptation on source + target domains + +For rollout evaluation and visualization, use inference.py instead. +""" + +import os +import sys +from typing import Any, Dict, Optional + +import hydra +import torch +import wandb +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import DataLoader, random_split + +from neuralop import get_model +from neuralop.utils import get_wandb_api_key + +from physicsnemo.distributed.manager import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.launch.logging.wandb import initialize_wandb + +from datasets import ( + FloodDatasetWithQueryPoints, + NormalizedDataset, +) +from data_processing import FloodGINODataProcessor +from training.pretraining import pretrain_model +from training.domain_adaptation import adapt_model +from utils.normalization import ( + collect_all_fields, + stack_and_fit_transform, +) + + +def _register_hydra_resolvers() -> None: + r""" + Register custom Hydra resolvers if needed. + + This function can be extended to register custom resolvers for the config. + Currently, OmegaConf provides built-in resolvers like oc.env for environment + variables, so no custom registration is needed. + + Note: The config uses ${VAR:default} syntax which is Hydra's legacy + environment variable interpolation. If this causes issues, consider migrating + to ${oc.env:VAR,default} syntax in the config file. + """ + # Placeholder for future custom resolvers if needed + # oc.env is already built into OmegaConf, so no registration needed + pass + + +def safe_config_to_dict( + cfg: DictConfig, + exclude_keys: Optional[list] = None, + logger: Optional[RankZeroLoggingWrapper] = None, +) -> Dict[str, Any]: + r""" + Safely convert OmegaConf DictConfig to a Python dictionary for wandb logging. + + This function handles unresolved interpolations gracefully by: + 1. Attempting full resolution first + 2. Filtering out problematic keys if resolution fails + 3. Falling back to partial resolution if needed + + Parameters + ---------- + cfg : DictConfig + The OmegaConf configuration object to convert. + exclude_keys : list, optional + List of top-level keys to exclude from the output (e.g., ['rollout_data']). + Defaults to ['rollout_data'] since it's only needed for inference. + logger : RankZeroLoggingWrapper, optional + Logger instance for warning messages. If None, warnings are suppressed. + + Returns + ------- + Dict[str, Any] + A Python dictionary representation of the config, suitable for wandb logging. + + Examples + -------- + >>> config_dict = safe_config_to_dict(cfg, exclude_keys=['rollout_data']) + >>> wandb.init(config=config_dict) + """ + if exclude_keys is None: + exclude_keys = ["rollout_data"] # Default: exclude rollout_data (only for inference) + + # Strategy 1: Try full resolution first + try: + config_dict = OmegaConf.to_container(cfg, resolve=True) + # Remove excluded keys + for key in exclude_keys: + config_dict.pop(key, None) + return config_dict + except Exception as e: + if logger: + logger.info( + f"Config resolution encountered unresolved interpolations: {type(e).__name__}. " + f"Attempting filtered resolution..." + ) + + # Strategy 2: Filter problematic keys, then resolve + try: + # Get unresolved config as dict + config_dict_unresolved = OmegaConf.to_container(cfg, resolve=False) + + # Remove excluded keys + filtered_dict = { + key: value + for key, value in config_dict_unresolved.items() + if key not in exclude_keys + } + + # Create new config from filtered dict and try to resolve + # Use struct=False to allow modifications during resolution + filtered_cfg = OmegaConf.create(filtered_dict) + OmegaConf.set_struct(filtered_cfg, False) + + try: + config_dict = OmegaConf.to_container(filtered_cfg, resolve=True) + return config_dict + except Exception: + # If resolution still fails, try resolving each top-level key individually + partially_resolved = {} + for key, value in filtered_dict.items(): + try: + # Try to resolve this key's section + key_cfg = OmegaConf.create({key: value}) + resolved_key = OmegaConf.to_container(key_cfg, resolve=True) + partially_resolved[key] = resolved_key[key] + except Exception: + # If this key fails, include it unresolved + if isinstance(value, DictConfig): + partially_resolved[key] = OmegaConf.to_container(value, resolve=False) + else: + partially_resolved[key] = value + return partially_resolved + except Exception as e: + if logger: + logger.warning( + f"Could not fully resolve config: {type(e).__name__}. " + f"Using partially resolved config for wandb logging." + ) + + # Strategy 3: Fallback to unresolved config (last resort) + config_dict = OmegaConf.to_container(cfg, resolve=False) + for key in exclude_keys: + config_dict.pop(key, None) + return config_dict + + +def log_section(logger: RankZeroLoggingWrapper, title: str, char: str = "=", width: int = 60): + r"""Log a section header for visual separation.""" + separator = char * width + logger.info("") + logger.info(separator) + logger.info(title) + logger.info(separator) + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def train_flood_forecaster(cfg: DictConfig) -> None: + r""" + Main training pipeline for FloodForecaster. + + This function orchestrates the complete training workflow: + 1. Configuration loading and device setup + 2. Pretraining on source domain + 3. Domain adaptation on source + target domains + + After training completes, use inference.py to perform rollout evaluation + and generate visualizations. + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object. + + Raises + ------ + SystemExit + If critical errors occur during execution. + """ + # Register custom Hydra resolvers for environment variable interpolation + _register_hydra_resolvers() + + # Initialize distributed manager (must be called first) + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize logging + log = PythonLogger(name="flood_forecaster") + log_rank_zero = RankZeroLoggingWrapper(log, dist) + + log_section(log_rank_zero, "FLOOD FORECASTER - Training and Evaluation Pipeline") + + try: + # Get device from distributed manager or config + device = dist.device if dist.device is not None else cfg.distributed.device + is_logger = dist.rank == 0 + + # Log device information prominently + log_rank_zero.info("=" * 50) + log_rank_zero.info(f"PyTorch version: {torch.__version__}") + log_rank_zero.info(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + log_rank_zero.info(f"CUDA version: {torch.version.cuda}") + log_rank_zero.info(f"GPU device: {torch.cuda.get_device_name(0)}") + log_rank_zero.info( + f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" + ) + log_rank_zero.info(f"Using device: {device}") + log_rank_zero.info(f"Distributed: rank={dist.rank}, world_size={dist.world_size}") + log_rank_zero.info("=" * 50) + + if not torch.cuda.is_available(): + log_rank_zero.warning("CUDA is not available! Training will be very slow on CPU.") + log_rank_zero.warning("Please check your PyTorch installation with CUDA support.") + + # Adjust FNO modes if needed (access via OmegaConf) + if ( + hasattr(cfg, "source_data") + and hasattr(cfg.source_data, "resolution") + and hasattr(cfg.model, "fno_n_modes") + and cfg.source_data.resolution < cfg.model.fno_n_modes[0] + ): + cfg.model.fno_n_modes = [cfg.source_data.resolution] * len(cfg.model.fno_n_modes) + # Safely log debug message - PythonLogger doesn't have debug method + try: + # Check if logger has a 'logger' attribute (underlying logging.Logger) + # RankZeroLoggingWrapper wraps PythonLogger which has a 'logger' attribute + if hasattr(log_rank_zero, 'obj') and hasattr(log_rank_zero.obj, 'logger'): + log_rank_zero.obj.logger.debug(f"Adjusted FNO modes to: {cfg.model.fno_n_modes}") + elif hasattr(log_rank_zero, 'logger') and hasattr(log_rank_zero.logger, 'debug'): + log_rank_zero.logger.debug(f"Adjusted FNO modes to: {cfg.model.fno_n_modes}") + # Fallback: try direct debug method (for loggers that support it) + elif hasattr(log_rank_zero, 'debug'): + log_rank_zero.debug(f"Adjusted FNO modes to: {cfg.model.fno_n_modes}") + except (AttributeError, TypeError): + # Skip debug logging if not available (not critical) + pass + + # Initialize wandb if logging is enabled + if cfg.wandb.log and is_logger: + log_rank_zero.info("Initializing Weights & Biases logging...") + # Try to login if API key is available, but don't fail if it's not + # wandb.init() will handle authentication automatically if user has logged in via CLI + try: + api_key = get_wandb_api_key() + if api_key: + wandb.login(key=api_key) + log_rank_zero.info("W&B API key found and logged in") + except (KeyError, FileNotFoundError, Exception) as e: + # API key not found - this is OK, wandb.init() will use existing login or prompt + log_rank_zero.info( + "W&B API key not found in environment or file. " + "Will use existing wandb login or prompt for authentication." + ) + + wandb_name = ( + cfg.wandb.name + if cfg.wandb.name + else f"flood-run_{getattr(cfg.source_data, 'resolution', 64)}" + ) + + # Safely convert config to dict for wandb, handling unresolved interpolations gracefully + # This excludes 'rollout_data' by default since it's only needed for inference + wandb_config_dict = safe_config_to_dict( + cfg, + exclude_keys=["rollout_data"], # Not needed for training logging + logger=log_rank_zero, + ) + + wandb_init_args = dict( + config=wandb_config_dict, + name=wandb_name, + group=cfg.wandb.group, + project=cfg.wandb.project, + entity=cfg.wandb.entity, + ) + if cfg.wandb.sweep: + for key in wandb.config.keys(): + if hasattr(cfg, "params"): + cfg.params[key] = wandb.config[key] + wandb.init(**wandb_init_args) + log_rank_zero.info(f"W&B initialized: project={cfg.wandb.project}, name={wandb_name}") + + # Stage 1: Pretraining on source domain + log_section(log_rank_zero, "Stage 1: Pretraining on Source Domain") + model, normalizers, trainer_src = pretrain_model( + config=cfg, + device=device, + is_logger=is_logger, + source_data_config=cfg.source_data, + logger=log_rank_zero, + ) + + # Recreate source loaders for domain adaptation + log_rank_zero.info("Recreating source loaders for domain adaptation...") + source_full_dataset = FloodDatasetWithQueryPoints( + data_root=cfg.source_data.root, + n_history=cfg.source_data.n_history, + xy_file=cfg.source_data.get("xy_file", None), + query_res=cfg.source_data.get("query_res", [64, 64]), + static_files=cfg.source_data.get("static_files", []), + dynamic_patterns=cfg.source_data.get("dynamic_patterns", {}), + boundary_patterns=cfg.source_data.get("boundary_patterns", {}), + raise_on_smaller=True, + skip_before_timestep=cfg.source_data.get("skip_before_timestep", 0), + noise_type=cfg.source_data.get("noise_type", "none"), + noise_std=cfg.source_data.get("noise_std", None), + ) + train_sz_source = int(0.9 * len(source_full_dataset)) + source_train_raw, source_val_raw = random_split( + source_full_dataset, + [train_sz_source, len(source_full_dataset) - train_sz_source], + ) + + # Move normalizers to CPU for data transformation + for norm in normalizers.values(): + norm.to("cpu") + + geom_s_tr, static_s_tr, boundary_s_tr, dyn_s_tr, tgt_s_tr = collect_all_fields( + source_train_raw, True + ) + _, big_source_train = stack_and_fit_transform( + geom_s_tr, + static_s_tr, + boundary_s_tr, + dyn_s_tr, + tgt_s_tr, + normalizers=normalizers, + fit_normalizers=False, + ) + source_train_ds = NormalizedDataset( + geometry=big_source_train["geometry"], + static=big_source_train["static"], + boundary=big_source_train["boundary"], + dynamic=big_source_train["dynamic"], + target=big_source_train["target"], + query_res=cfg.source_data.query_res, + ) + source_train_loader = DataLoader( + source_train_ds, batch_size=cfg.source_data.batch_size, shuffle=True + ) + + geom_s_val, static_s_val, boundary_s_val, dyn_s_val, tgt_s_val = collect_all_fields( + source_val_raw, True + ) + _, big_source_val = stack_and_fit_transform( + geom_s_val, + static_s_val, + boundary_s_val, + dyn_s_val, + tgt_s_val, + normalizers=normalizers, + fit_normalizers=False, + ) + source_val_ds = NormalizedDataset( + geometry=big_source_val["geometry"], + static=big_source_val["static"], + boundary=big_source_val["boundary"], + dynamic=big_source_val["dynamic"], + target=big_source_val["target"], + query_res=cfg.source_data.query_res, + ) + source_val_loader = DataLoader( + source_val_ds, batch_size=cfg.source_data.batch_size, shuffle=False + ) + + # Stage 2: Domain adaptation + log_section(log_rank_zero, "Stage 2: Domain Adaptation") + data_processor = trainer_src.data_processor + + # Calculate wandb step offset to continue from pretraining + # neuralop Trainer uses step=epoch+1, so if pretraining ran for n_epochs_source, + # domain adaptation should start from step n_epochs_source + 1 + n_epochs_source = cfg.training.get("n_epochs_source", cfg.training.get("n_epochs", 100)) + wandb_step_offset = n_epochs_source if (cfg.wandb.log and is_logger) else 0 + + model, domain_classifier, trainer_adapt = adapt_model( + model=model, + normalizers=normalizers, + data_processor=data_processor, + config=cfg, + device=device, + is_logger=is_logger, + source_train_loader=source_train_loader, + source_val_loader=source_val_loader, + target_data_config=cfg.target_data, + logger=log_rank_zero, + wandb_step_offset=wandb_step_offset, + ) + + if cfg.wandb.log and is_logger: + wandb.finish() + log_rank_zero.info("W&B logging finished") + + log_section(log_rank_zero, "Training Complete!") + log_rank_zero.info("") + log_rank_zero.info("To perform rollout evaluation and generate visualizations,") + log_rank_zero.info("run: python inference.py --config-path conf --config-name config") + log_rank_zero.info("") + + except KeyboardInterrupt: + log_rank_zero.warning("Training interrupted by user") + if cfg.wandb.log and is_logger: + wandb.finish() + sys.exit(1) + except Exception as e: + import traceback + log_rank_zero.error(f"Fatal error in main pipeline: {e}") + log_rank_zero.error(traceback.format_exc()) + if cfg.wandb.log and is_logger: + wandb.finish() + raise + + +if __name__ == "__main__": + train_flood_forecaster() + diff --git a/examples/weather/flood_modeling/flood_forecaster/training/__init__.py b/examples/weather/flood_modeling/flood_forecaster/training/__init__.py new file mode 100644 index 0000000000..3559a4424e --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/training/__init__.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Training modules for flood prediction.""" + +# Lazy import pattern to handle circular dependencies between domain_adaptation and pretraining +# domain_adaptation imports create_scheduler from pretraining, but __init__ imports both +# Using lazy imports ensures modules are fully loaded before accessing functions +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # For type checking, import normally + from .domain_adaptation import adapt_model + from .pretraining import pretrain_model +else: + # At runtime, use lazy imports to avoid circular dependency issues + # Functions are imported on first access via __getattr__ + adapt_model = None + pretrain_model = None + + +def __getattr__(name: str): + """Lazy import function to resolve circular dependencies at runtime.""" + if name == "adapt_model": + from .domain_adaptation import adapt_model as _adapt_model + globals()["adapt_model"] = _adapt_model + return _adapt_model + elif name == "pretrain_model": + from .pretraining import pretrain_model as _pretrain_model + globals()["pretrain_model"] = _pretrain_model + return _pretrain_model + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["pretrain_model", "adapt_model"] + diff --git a/examples/weather/flood_modeling/flood_forecaster/training/domain_adaptation.py b/examples/weather/flood_modeling/flood_forecaster/training/domain_adaptation.py new file mode 100644 index 0000000000..6296e1c449 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/training/domain_adaptation.py @@ -0,0 +1,1381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Domain adaptation module for fine-tuning on target domain. + +This module implements adversarial domain adaptation using gradient reversal layers +to enable transfer learning from source to target domains. Compatible with neuralop 2.0.0 API +and physicsnemo framework. + +Key components: +- GradientReversal: Implements gradient reversal layer for adversarial training +- CNNDomainClassifier: CNN-based domain classifier for domain discrimination +- DomainAdaptationTrainer: Custom trainer for domain adaptation training loop +- adapt_model: High-level function to perform domain adaptation +""" + +import os +import sys +import math +import random +from timeit import default_timer +from pathlib import Path +from typing import Optional, Dict, Union, List, Tuple, Any +from itertools import cycle + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.utils.data import DataLoader, random_split +from tqdm import tqdm + +from neuralop.training import AdamW +from neuralop.losses import LpLoss + +import physicsnemo +from physicsnemo.models.meta import ModelMetaData +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.utils.checkpoint import save_checkpoint, load_checkpoint + + +def _sanitize_args_for_json(args_dict: Dict[str, Any]) -> Dict[str, Any]: + r""" + Recursively convert non-JSON-serializable objects in args_dict to serializable formats. + + This function handles: + - DictConfig -> dict (using OmegaConf.to_container) + - Other non-serializable objects are left as-is (will cause error if encountered) + + Parameters + ---------- + args_dict : Dict[str, Any] + Dictionary to sanitize (modified in-place). + + Returns + ------- + Dict[str, Any] + Sanitized dictionary (same object, modified in-place). + """ + try: + from omegaconf import DictConfig, OmegaConf + has_omegaconf = True + except ImportError: + has_omegaconf = False + + def _convert_value(value: Any) -> Any: + """Recursively convert DictConfig objects to dicts.""" + if has_omegaconf and isinstance(value, DictConfig): + # Convert DictConfig to dict, resolving nested DictConfigs + return OmegaConf.to_container(value, resolve=True) + elif isinstance(value, dict): + # Recursively process dictionary values + return {k: _convert_value(v) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + # Recursively process list/tuple elements + converted = [_convert_value(item) for item in value] + return type(value)(converted) # Preserve list/tuple type + else: + # Other types (int, float, str, bool, None) are already JSON-serializable + return value + + # Process the dictionary recursively + for key, value in list(args_dict.items()): + args_dict[key] = _convert_value(value) + + return args_dict + +from data_processing import LpLossWrapper + +# Try to import comm for distributed training, fallback if not available +try: + import neuralop.mpu.comm as comm + _has_comm = True +except ImportError: + _has_comm = False + # Fallback: create a dummy comm module + class _DummyComm: + @staticmethod + def get_local_rank(): + return 0 + comm = _DummyComm() + +from datasets import FloodDatasetWithQueryPoints, NormalizedDataset +from utils.normalization import collect_all_fields, stack_and_fit_transform +from training.pretraining import create_scheduler +from training.trainer import save_model_checkpoint, _has_pytorch_submodules + + +class GradientReversalFunction(Function): + r""" + Custom autograd function for gradient reversal layer. + + This function implements the gradient reversal layer (GRL) used in adversarial + domain adaptation. During forward pass, it returns the input unchanged. During + backward pass, it negates and scales the gradients by lambda. + + Attributes + ---------- + lambda_ : float + Scaling factor for gradient reversal (typically scheduled during training). + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor: + r""" + Forward pass: return input unchanged. + + Parameters + ---------- + ctx : Any + Context object to store lambda for backward pass. + x : torch.Tensor + Input tensor of arbitrary shape. + lambda_ : float + Scaling factor for gradient reversal. + + Returns + ------- + torch.Tensor + Cloned input tensor (same shape as input). + """ + ctx.lambda_ = lambda_ + return x.clone() + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + r""" + Backward pass: negate and scale gradients. + + Parameters + ---------- + ctx : Any + Context object containing lambda. + grad_output : torch.Tensor + Gradient from next layer. + + Returns + ------- + Tuple[torch.Tensor, None] + Tuple of (negated and scaled gradient, None for lambda). + """ + return grad_output.neg().mul(ctx.lambda_), None + + +class GradientReversal(physicsnemo.Module): + r""" + Gradient reversal layer module for adversarial domain adaptation. + + This module wraps the GradientReversalFunction to provide a learnable + gradient reversal layer. The lambda parameter can be dynamically updated + during training to schedule the strength of adversarial training. + + Parameters + ---------- + lambda_max : float, optional, default=1.0 + Maximum lambda value (typically 1.0). + + Forward + ------- + x : torch.Tensor + Input tensor of arbitrary shape. + + Outputs + ------- + torch.Tensor + Output tensor (same shape as input, but gradients will be reversed). + + Attributes + ---------- + lambda_ : float + Current scaling factor for gradient reversal. + """ + + def __init__(self, lambda_max: float = 1.0): + r""" + Initialize gradient reversal layer. + + Parameters + ---------- + lambda_max : float, optional, default=1.0 + Maximum lambda value (typically 1.0). + """ + super().__init__(meta=ModelMetaData(name="GradientReversal")) + self.lambda_ = lambda_max + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Forward pass through gradient reversal layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor of arbitrary shape. + + Returns + ------- + torch.Tensor + Output tensor (same shape as input, but gradients will be reversed). + """ + ### Input validation + # Skip validation when running under torch.compile for performance + if not torch.compiler.is_compiling(): + if not isinstance(x, torch.Tensor): + raise ValueError( + f"Expected input to be torch.Tensor, got {type(x)}" + ) + if x.numel() == 0: + raise ValueError( + f"Expected non-empty input tensor, got tensor with shape {tuple(x.shape)}" + ) + + return GradientReversalFunction.apply(x, self.lambda_) + + def set_lambda(self, val: float) -> None: + r""" + Update lambda value for gradient reversal. + + Parameters + ---------- + val : float + New lambda value. + """ + self.lambda_ = val + + +class CNNDomainClassifier(physicsnemo.Module): + r""" + CNN-based domain classifier for adversarial domain adaptation. + + This classifier takes latent features from the GINO model and predicts + whether they come from the source or target domain. The gradient reversal + layer ensures that the feature extractor learns domain-invariant features. + + Architecture: + - Gradient Reversal Layer (GRL) + - Convolutional layers (configurable) + - Adaptive average pooling + - Fully connected layer for binary classification + + Parameters + ---------- + in_channels : int + Number of input channels (should match fno_hidden_channels). + lambda_max : float + Maximum lambda for gradient reversal layer. + da_cfg : Dict[str, Any] + Configuration dict with keys: + - conv_layers: List of dicts with 'out_channels', 'kernel_size', 'pool_size' + - fc_dim: Output dimension of final fully connected layer + + Forward + ------- + x : torch.Tensor + Input features of shape :math:`(B, C, H, W)` where :math:`B` is batch size, + :math:`C` is channels, and :math:`H, W` are spatial dimensions. + + Outputs + ------- + torch.Tensor + Logits for binary classification of shape :math:`(B, D_{fc})` where + :math:`D_{fc}` is the fully connected layer dimension. + """ + + def __init__(self, in_channels: int, lambda_max: float, da_cfg: Dict[str, Any]): + r""" + Initialize domain classifier. + + Parameters + ---------- + in_channels : int + Number of input channels (should match fno_hidden_channels). + lambda_max : float + Maximum lambda for gradient reversal layer. + da_cfg : Dict[str, Any] + Configuration dict with keys: + - conv_layers: List of dicts with 'out_channels', 'kernel_size', 'pool_size' + - fc_dim: Output dimension of final fully connected layer + + Raises + ------ + ValueError + If required keys are missing from ``da_cfg`` or if conv_layers is empty. + """ + # Convert DictConfig to regular dict if needed (for JSON serialization) + # This must be done BEFORE super().__init__() so PhysicsNeMo's __new__ captures + # the converted dict, not the DictConfig + # OmegaConf.to_container recursively converts nested DictConfigs to dicts + try: + from omegaconf import DictConfig, OmegaConf + if isinstance(da_cfg, DictConfig): + # Convert to regular dict, resolving all nested DictConfigs recursively + da_cfg = OmegaConf.to_container(da_cfg, resolve=True) + except ImportError: + # OmegaConf not available, assume da_cfg is already a dict + pass + except Exception: + # If conversion fails for any reason, try to manually convert + # This is a fallback in case OmegaConf.to_container doesn't work + if hasattr(da_cfg, '__dict__'): + # Try to convert manually + da_cfg = dict(da_cfg) + + super().__init__(meta=ModelMetaData(name="CNNDomainClassifier")) + + # CRITICAL: Sanitize _args to convert any DictConfig objects to regular dicts + # PhysicsNeMo's __new__ captures arguments before __init__ runs, so _args + # may contain DictConfig objects that need to be converted for JSON serialization + # We need to sanitize the entire _args structure, not just __args__ + if hasattr(self, '_args'): + # Recursively sanitize the entire _args dictionary + # This will convert any DictConfig objects anywhere in the structure + _sanitize_args_for_json(self._args) + if not da_cfg.get("conv_layers"): + raise ValueError("da_cfg must contain 'conv_layers' list") + if "fc_dim" not in da_cfg: + raise ValueError("da_cfg must contain 'fc_dim'") + + self.grl = GradientReversal(lambda_max=lambda_max) + layers = [] + c_in = in_channels + + for layer_spec in da_cfg["conv_layers"]: + out_channels = layer_spec["out_channels"] + kernel_size = layer_spec.get("kernel_size", 3) + pool_size = layer_spec.get("pool_size", 2) + + layers.extend([ + nn.Conv2d( + c_in, out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2 + ), + nn.ReLU(inplace=True), + nn.MaxPool2d(pool_size) + ]) + c_in = out_channels + + layers.append(nn.AdaptiveAvgPool2d((1, 1))) + self.conv_net = nn.Sequential(*layers) + self.fc = nn.Linear(c_in, da_cfg["fc_dim"]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Forward pass through domain classifier. + + Parameters + ---------- + x : torch.Tensor + Input features of shape :math:`(B, C, H, W)` where :math:`B` is batch size, + :math:`C` is channels, and :math:`H, W` are spatial dimensions. + + Returns + ------- + torch.Tensor + Logits for binary classification of shape :math:`(B, D_{fc})` where + :math:`D_{fc}` is the fully connected layer dimension. + """ + ### Input validation + # Skip validation when running under torch.compile for performance + if not torch.compiler.is_compiling(): + if x.ndim != 4: + raise ValueError( + f"Expected 4D input tensor (B, C, H, W), got {x.ndim}D tensor " + f"with shape {tuple(x.shape)}" + ) + + # Apply gradient reversal, then conv layers, then flatten and classify + x = self.grl(x) # (B, C, H, W) + x = self.conv_net(x) # (B, C, 1, 1) after adaptive pooling + x = x.view(x.size(0), -1) # (B, C) + return self.fc(x) # (B, fc_dim) + + +class DomainAdaptationTrainer: + r""" + Custom trainer for domain adaptation compatible with neuralop 2.0.0. + + Implements adversarial domain adaptation using gradient reversal layers (GRL). + The training process alternates between: + 1. Task loss: Regression loss on both source and target domains + 2. Adversarial loss: Domain classification loss with reversed gradients + + The GRL lambda is scheduled during training to gradually increase the strength + of adversarial training. + + Parameters + ---------- + model : nn.Module + The main model to train (should support return_features=True). + data_processor : nn.Module, optional + Data processor for preprocessing/postprocessing. + domain_classifier : nn.Module + Domain classifier module. + device : str or torch.device, optional, default="cuda" + Device to train on ('cuda', 'cpu', or torch.device). + verbose : bool, optional, default=True + Whether to print training progress. + logger : Any, optional + Optional logger instance (if None, uses print when verbose=True). + + Attributes + ---------- + model : nn.Module + The main GINO model (wrapped with GINOWrapper). + data_processor : nn.Module, optional + Data processor for preprocessing/postprocessing. + domain_classifier : nn.Module + CNN-based domain classifier. + device : str or torch.device + Device to train on. + verbose : bool + Whether to print training progress. + _eval_interval : int + Interval for evaluation (default: 1). + """ + + def __init__( + self, + model: nn.Module, + data_processor: Optional[nn.Module], + domain_classifier: nn.Module, + device: Union[str, torch.device] = "cuda", + verbose: bool = True, + logger: Optional[Any] = None, + wandb_step_offset: int = 0, + ): + r""" + Initialize domain adaptation trainer. + + Parameters + ---------- + model : nn.Module + The main model to train (should support return_features=True). + data_processor : nn.Module, optional + Optional data processor for preprocessing. + domain_classifier : nn.Module + Domain classifier module. + device : str or torch.device, optional, default="cuda" + Device to train on ('cuda', 'cpu', or torch.device). + verbose : bool, optional, default=True + Whether to print training progress. + logger : Any, optional + Optional logger instance (if None, uses print when verbose=True). + wandb_step_offset : int, optional, default=0 + Step offset for wandb logging to continue from pretraining step count. + """ + self.model = model + self.data_processor = data_processor + self.domain_classifier = domain_classifier + self.device = device + self.verbose = verbose + self.logger = logger + self.wandb_step_offset = wandb_step_offset + self._eval_interval = 1 + + def train_domain_adaptation( + self, + src_loader: DataLoader, + tgt_loader: Union[DataLoader, List[DataLoader]], + optimizer, + scheduler, + training_loss, + class_loss_weight: float = 0.1, + adaptation_epochs: int = 100, + save_every: int = None, + save_dir: Union[str, Path] = "./ckpt", + resume_from_dir: Union[str, Path] = None, + resume_classifier_from_dir: Union[str, Path] = None, + val_loaders: Optional[Dict[str, DataLoader]] = None, + ): + r""" + Domain-adaptation training loop with adversarial classifier. + + Handles both a single target DataLoader and a list of target DataLoaders. + This implementation exactly matches the original neuralop trainer's + train_domain_adaptation method. + + Parameters + ---------- + src_loader : DataLoader + Source domain dataloader. + tgt_loader : DataLoader or List[DataLoader] + Target domain dataloader (single or list). + optimizer : torch.optim.Optimizer + Optimizer for both model and classifier. + scheduler : torch.optim.lr_scheduler._LRScheduler + Learning rate scheduler. + training_loss : callable + Loss function for main task. + class_loss_weight : float, optional, default=0.1 + Weight for domain classification loss. + adaptation_epochs : int, optional, default=100 + Number of epochs to train. + save_every : int, optional + Interval at which to save checkpoints. + save_dir : str or Path, optional, default="./ckpt" + Directory to save checkpoints. + resume_from_dir : str or Path, optional + Directory to resume training from. + resume_classifier_from_dir : str or Path, optional + Directory to resume classifier from. + val_loaders : Dict[str, DataLoader], optional + Dict of validation dataloaders. + + Returns + ------- + nn.Module + Trained model. + """ + self.model = self.model.to(self.device) + self.domain_classifier = self.domain_classifier.to(self.device) + if self.data_processor is not None: + self.data_processor = self.data_processor.to(self.device) + + # Domain classification loss (binary cross-entropy) + adv_criterion = nn.BCEWithLogitsLoss() + + # Optionally resume model and classifier state + start_epoch = 0 + if resume_from_dir is not None: + start_epoch = self._resume_from_checkpoint(resume_from_dir, optimizer, scheduler) + + # Optionally resume classifier from separate directory (fallback mechanism) + if resume_classifier_from_dir is not None: + classifier_loaded = False + resume_classifier_dir = Path(resume_classifier_from_dir) + + # Try PhysicsNeMo format first (classifier saved as second model) + try: + metadata_dict = {} + load_checkpoint( + path=str(resume_classifier_dir), + models=[self.domain_classifier], + optimizer=None, + scheduler=None, + scaler=None, + epoch=None, # Load latest + metadata_dict=metadata_dict, + device=self.device, + ) + msg = f"Loaded classifier from PhysicsNeMo checkpoint: {resume_classifier_dir}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + classifier_loaded = True + except (FileNotFoundError, KeyError, ValueError): + # Fall back to old format (separate classifier_state_dict.pt file) + ckpt = resume_classifier_dir / "classifier_state_dict.pt" + if ckpt.exists(): + self.domain_classifier.load_state_dict( + torch.load(str(ckpt), map_location=self.device) + ) + msg = f"Loaded classifier from neuralop checkpoint: {ckpt}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + classifier_loaded = True + + if not classifier_loaded: + msg = f"Warning: Could not load classifier from {resume_classifier_from_dir}" + if self.logger: + self.logger.warning(msg) + elif self.verbose: + print(msg) + + val_loaders = val_loaders or {} + + # Handle both single loader and list of loaders + if not isinstance(tgt_loader, list): + tgt_loaders = [tgt_loader] + else: + tgt_loaders = tgt_loader + + # Determine iteration strategy based on source loader + base_batches = len(src_loader) + total_iters = adaptation_epochs * base_batches + + # Create cycling iterators for all loaders + src_iter = cycle(src_loader) + tgt_iters = [cycle(loader) for loader in tgt_loaders] + + msg1 = f"Starting domain adaptation training for {adaptation_epochs} epochs" + msg2 = f"Source samples: {len(src_loader.dataset)}, Target loaders: {len(tgt_loaders)}" + if self.logger: + self.logger.info(msg1) + self.logger.info(msg2) + elif self.verbose: + print(msg1) + print(msg2) + + for epoch in range(start_epoch + 1, adaptation_epochs): + self.on_epoch_start(epoch) + self.model.train() + self.domain_classifier.train() + if self.data_processor is not None: + self.data_processor.train() + + total_reg, total_adv = 0.0, 0.0 + + # Progress bar + pbar = tqdm( + range(base_batches), + desc=f"DA Epoch {epoch}/{adaptation_epochs}", + disable=not self.verbose, + file=sys.stdout + ) + + for batch_idx in pbar: + # Update GRL lambda (scheduled from 0 to lambda_max) + # Formula: lambda_val = 2.0 / (1.0 + exp(-10 * p)) - 1.0 + # where p = (epoch * base_batches + batch_idx) / total_iters + p = (epoch * base_batches + batch_idx) / total_iters + lambda_val = 2.0 / (1.0 + math.exp(-10 * p)) - 1.0 + self.domain_classifier.grl.set_lambda(lambda_val) + + # Randomly select one target domain from the list for this training step + chosen_tgt_iter = random.choice(tgt_iters) + + src_batch = next(src_iter) + tgt_batch = next(chosen_tgt_iter) + + # Preprocess batches + if self.data_processor is not None: + s = self.data_processor.preprocess(src_batch) + t = self.data_processor.preprocess(tgt_batch) + else: + s = {k: v.to(self.device) for k, v in src_batch.items() if torch.is_tensor(v)} + t = {k: v.to(self.device) for k, v in tgt_batch.items() if torch.is_tensor(v)} + + # Forward pass with feature extraction + # Extract features using return_features=True + # Features should be in shape (batch, channels, H, W) for 2D + try: + out_s, f_s = self.model(**s, return_features=True) + out_t, f_t = self.model(**t, return_features=True) + except TypeError as e: + # Fallback if model doesn't support return_features + raise RuntimeError( + "Model must support return_features=True for domain adaptation. " + "Ensure model is wrapped with GINOWrapper." + ) from e + + # Postprocess outputs (after feature extraction, before loss) + if self.data_processor is not None: + out_s, s = self.data_processor.postprocess(out_s, s) + out_t, t = self.data_processor.postprocess(out_t, t) + + # Regression loss on source and target + # Note: training_loss expects (y_pred, **sample) where sample contains 'y' + reg_loss = training_loss(out_s, **s) + training_loss(out_t, **t) + + # Prepare features for domain classifier + # Features from GINOWrapper are already in shape (batch, channels, H, W) + # Concatenate source and target features along batch dimension + if f_s.dim() != 4 or f_t.dim() != 4: + raise ValueError( + f"Expected 4D features (B, C, H, W), got f_s.shape={f_s.shape}, f_t.shape={f_t.shape}. " + "Ensure GINOWrapper returns features in correct format." + ) + feats = torch.cat([f_s, f_t], dim=0) + + # Domain classification adversarial loss + logits = self.domain_classifier(feats).squeeze(1) + labels = torch.cat([ + torch.ones(f_s.size(0), device=self.device), + torch.zeros(f_t.size(0), device=self.device) + ], dim=0).float() + adv_loss = adv_criterion(logits, labels) + + # Combined loss + loss = reg_loss + class_loss_weight * adv_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_reg += reg_loss.item() + total_adv += adv_loss.item() + + # Update progress bar + if self.verbose: + pbar.set_postfix({ + 'loss': f'{loss.item():.4f}', + 'reg': f'{reg_loss.item():.4f}', + 'adv': f'{adv_loss.item():.4f}', + 'lambda': f'{lambda_val:.3f}' + }) + + # Step scheduler + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + scheduler.step(total_reg + total_adv) + else: + scheduler.step() + + avg_reg = total_reg / base_batches + avg_adv = total_adv / base_batches + msg = f"[DA Epoch {epoch}] reg={avg_reg:.4f}, adv={avg_adv:.4f}, lambda={lambda_val:.3f}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + + # Validation (if val_loaders provided) + if val_loaders and (epoch % self.eval_interval == 0 or epoch == adaptation_epochs - 1): + self._evaluate(val_loaders, training_loss, epoch) + + # Optional checkpointing using PhysicsNeMo checkpoint system + if save_every is not None and (epoch % save_every == 0): + # Only save on rank 0 in distributed training + should_save = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + should_save = (dist_manager.rank == 0) + elif _has_comm: + should_save = (comm.get_local_rank() == 0) + + if should_save: + sd = Path(save_dir) + sd.mkdir(parents=True, exist_ok=True) + + # Determine model parallel rank + model_parallel_rank = 0 + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if "model_parallel" in dist_manager.group_names: + model_parallel_rank = dist_manager.group_rank("model_parallel") + + # Save model(s) - handle PyTorch submodules if needed + model_to_save = self.model + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model_to_save = self.model.module + + model_saved_separately = save_model_checkpoint( + model=model_to_save, + save_dir=sd, + epoch=epoch, + model_parallel_rank=model_parallel_rank, + ) + + # Prepare models list for PhysicsNeMo (includes domain classifier) + models_to_save = [] + if not model_saved_separately: + models_to_save.append(model_to_save) + # Always add domain classifier as second model + models_to_save.append(self.domain_classifier) + + # Save checkpoint with both models and training state + save_checkpoint( + path=str(sd), + models=models_to_save if not model_saved_separately else [self.domain_classifier], + optimizer=optimizer, + scheduler=scheduler, + scaler=None, + epoch=epoch, + metadata={"stage": "domain_adaptation", "epoch": epoch}, + ) + + msg = f"Saved DA checkpoint at epoch {epoch}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + + # Save final checkpoint (outside epoch loop) using PhysicsNeMo checkpoint system + should_save = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + should_save = (dist_manager.rank == 0) + elif _has_comm: + should_save = (comm.get_local_rank() == 0) + + if should_save: + sd = Path(save_dir) + sd.mkdir(parents=True, exist_ok=True) + + # Determine model parallel rank + model_parallel_rank = 0 + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if "model_parallel" in dist_manager.group_names: + model_parallel_rank = dist_manager.group_rank("model_parallel") + + # Save model(s) - handle PyTorch submodules if needed + model_to_save = self.model + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model_to_save = self.model.module + + final_epoch = adaptation_epochs - 1 + model_saved_separately = save_model_checkpoint( + model=model_to_save, + save_dir=sd, + epoch=final_epoch, + model_parallel_rank=model_parallel_rank, + ) + + # Prepare models list for PhysicsNeMo (includes domain classifier) + models_to_save = [] + if not model_saved_separately: + models_to_save.append(model_to_save) + # Always add domain classifier as second model + models_to_save.append(self.domain_classifier) + + # Save checkpoint with both models and training state + save_checkpoint( + path=str(sd), + models=models_to_save if not model_saved_separately else [self.domain_classifier], + optimizer=optimizer, + scheduler=scheduler, + scaler=None, + epoch=final_epoch, + metadata={"stage": "domain_adaptation", "final_epoch": True, "epoch": final_epoch}, + ) + + msg = "Saved final DA checkpoint using PhysicsNeMo format" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + + msg = "Domain adaptation training completed!" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + return self.model + + def on_epoch_start(self, epoch): + r"""Stub called at the start of each epoch.""" + self.epoch = epoch + return None + + @property + def eval_interval(self): + r"""Evaluation interval (default 1).""" + return getattr(self, '_eval_interval', 1) + + @eval_interval.setter + def eval_interval(self, value): + self._eval_interval = value + + def _evaluate( + self, + val_loaders: Dict[str, DataLoader], + loss_fn: Any, + epoch: int + ) -> None: + r""" + Evaluate model on validation loaders. + + Parameters + ---------- + val_loaders : Dict[str, DataLoader] + Dictionary of validation dataloaders. + loss_fn : callable + Loss function to use for evaluation. + epoch : int + Current epoch number (for logging). + """ + self.model.eval() + if self.data_processor is not None: + self.data_processor.eval() + + with torch.no_grad(): + for name, loader in val_loaders.items(): + total_loss = 0.0 + n_samples = 0 + + for sample in loader: + try: + if self.data_processor is not None: + sample = self.data_processor.preprocess(sample) + else: + sample = {k: v.to(self.device) for k, v in sample.items() if torch.is_tensor(v)} + + out = self.model(**sample) + + if self.data_processor is not None: + out, sample = self.data_processor.postprocess(out, sample) + + # Loss function expects (y_pred, **sample) where sample contains 'y' + loss = loss_fn(out, **sample) + total_loss += loss.item() + n_samples += sample.get("y", out).shape[0] if isinstance(sample.get("y"), torch.Tensor) else out.shape[0] + except Exception as e: + msg = f"Error evaluating on {name}: {e}" + if self.logger: + self.logger.error(msg) + elif self.verbose: + print(msg) + raise + + avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0 + msg = f" Eval {name}: loss={avg_loss:.6f}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + + def _save_checkpoint( + self, + save_dir: Union[str, Path], + optimizer: torch.optim.Optimizer, + scheduler: Any, + epoch: int, + save_classifier: bool = False + ) -> None: + r""" + Save training checkpoint using PhysicsNeMo checkpoint system. + + This method saves both the main model and domain classifier (if requested) + using PhysicsNeMo's checkpoint format. + + Parameters + ---------- + save_dir : str or Path + Directory to save checkpoint. + optimizer : torch.optim.Optimizer + Optimizer instance. + scheduler : Any + Scheduler instance. + epoch : int + Current epoch number. + save_classifier : bool, optional, default=False + Whether to save classifier state dict. + """ + save_dir = Path(save_dir) + + # Only save on rank 0 in distributed training + should_save = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + should_save = (dist_manager.rank == 0) + elif _has_comm: + should_save = (comm.get_local_rank() == 0) + + if not should_save: + return + + try: + save_dir.mkdir(parents=True, exist_ok=True) + + # Determine model parallel rank + model_parallel_rank = 0 + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if "model_parallel" in dist_manager.group_names: + model_parallel_rank = dist_manager.group_rank("model_parallel") + + # Save model(s) - handle PyTorch submodules if needed + model_to_save = self.model + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model_to_save = self.model.module + + model_saved_separately = save_model_checkpoint( + model=model_to_save, + save_dir=save_dir, + epoch=epoch, + model_parallel_rank=model_parallel_rank, + ) + + # Prepare models list for PhysicsNeMo + models_to_save = [] + if not model_saved_separately: + models_to_save.append(model_to_save) + if save_classifier: + models_to_save.append(self.domain_classifier) + + # Save checkpoint with models and training state + save_checkpoint( + path=str(save_dir), + models=models_to_save if models_to_save else None, + optimizer=optimizer, + scheduler=scheduler, + scaler=None, + epoch=epoch, + metadata={"stage": "domain_adaptation", "epoch": epoch}, + ) + + msg = f"Saved checkpoint to {save_dir}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + except Exception as e: + msg = f"Error saving checkpoint to {save_dir}: {e}" + if self.logger: + self.logger.error(msg) + elif self.verbose: + print(msg) + raise + + def _resume_from_checkpoint( + self, + resume_dir: Union[str, Path], + optimizer: torch.optim.Optimizer, + scheduler: Any + ) -> int: + r""" + Resume training from checkpoint using PhysicsNeMo checkpoint system. + + Supports both PhysicsNeMo format (new) and neuralop format (old) for backward compatibility. + + Parameters + ---------- + resume_dir : str or Path + Directory containing checkpoint. + optimizer : torch.optim.Optimizer + Optimizer instance (will be updated with checkpoint state). + scheduler : Any + Scheduler instance (will be updated with checkpoint state). + + Returns + ------- + int + Epoch number to resume from (0 if no checkpoint found). + """ + resume_dir = Path(resume_dir) + + if not resume_dir.exists(): + msg = f"Resume directory does not exist: {resume_dir}" + if self.logger: + self.logger.warning(msg) + elif self.verbose: + print(msg) + return 0 + + try: + # Try PhysicsNeMo format first (new format) + checkpoint_loaded = False + resume_epoch = 0 + metadata_dict = {} + + try: + # Try to load using PhysicsNeMo format + # Prepare models list (main model + domain classifier if available) + models_to_load = [self.model] + if hasattr(self, 'domain_classifier') and self.domain_classifier is not None: + models_to_load.append(self.domain_classifier) + + resume_epoch = load_checkpoint( + path=str(resume_dir), + models=models_to_load, + optimizer=optimizer, + scheduler=scheduler, + scaler=None, + epoch=None, # Load latest + metadata_dict=metadata_dict, + device=self.device, + ) + + if self.logger: + self.logger.info("Loaded checkpoint using PhysicsNeMo format") + checkpoint_loaded = True + except (FileNotFoundError, KeyError, ValueError) as e: + # Fall back to neuralop format (old format) + if self.logger: + self.logger.info(f"PhysicsNeMo checkpoint not found, trying neuralop format: {e}") + + # Check for neuralop checkpoint files + if (resume_dir / "best_model_state_dict.pt").exists(): + save_name = "best_model" + elif (resume_dir / "model_state_dict.pt").exists(): + save_name = "model" + else: + msg = f"No checkpoint found in {resume_dir} (tried both formats)" + if self.logger: + self.logger.warning(msg) + elif self.verbose: + print(msg) + return 0 + + # Load using neuralop format + from neuralop.training.training_state import load_training_state + + self.model, optimizer, scheduler, _, resume_epoch = load_training_state( + save_dir=resume_dir, + save_name=save_name, + model=self.model, + optimizer=optimizer, + scheduler=scheduler, + ) + + # Try to load domain classifier if it exists + classifier_path = resume_dir / "classifier_state_dict.pt" + if classifier_path.exists() and hasattr(self, 'domain_classifier') and self.domain_classifier is not None: + self.domain_classifier.load_state_dict( + torch.load(str(classifier_path), map_location=self.device) + ) + if self.logger: + self.logger.info("Loaded domain classifier from neuralop checkpoint") + + if self.logger: + self.logger.info(f"Loaded checkpoint using neuralop format: {save_name}") + checkpoint_loaded = True + + if checkpoint_loaded and resume_epoch is not None: + msg = f"Resumed from epoch {resume_epoch}" + if self.logger: + self.logger.info(msg) + elif self.verbose: + print(msg) + return resume_epoch + else: + return 0 + + except Exception as e: + msg = f"Error loading checkpoint from {resume_dir}: {e}" + if self.logger: + self.logger.error(msg) + elif self.verbose: + print(msg) + raise + + +def adapt_model( + model: nn.Module, + normalizers: Dict[str, Any], + data_processor: Optional[nn.Module], + config: Any, + device: Union[str, torch.device], + is_logger: bool, + source_train_loader: DataLoader, + source_val_loader: DataLoader, + target_data_config: Any, + logger: Optional[Any] = None, + wandb_step_offset: int = 0, +) -> Tuple[nn.Module, nn.Module, "DomainAdaptationTrainer"]: + r""" + Perform domain adaptation on target domain data. + + This function orchestrates the domain adaptation training process: + 1. Loads and normalizes target domain data + 2. Creates domain classifier + 3. Sets up optimizer and scheduler + 4. Trains model with adversarial domain adaptation + + Parameters + ---------- + model : nn.Module + Pretrained model (should be wrapped with GINOWrapper). + normalizers : Dict[str, Any] + Dictionary of normalizers from pretraining. + data_processor : nn.Module, optional + Data processor instance. + config : Any + Configuration object (OmegaConf DictConfig). + device : str or torch.device + Device to train on ('cuda', 'cpu', or torch.device). + is_logger : bool + Whether this process is the logger (for distributed training). + source_train_loader : DataLoader + Source domain training dataloader. + source_val_loader : DataLoader + Source domain validation dataloader. + target_data_config : Any + Target data configuration (OmegaConf DictConfig). + logger : Any, optional + Optional logger instance (physicsnemo PythonLogger or compatible). + wandb_step_offset : int, optional, default=0 + Step offset for wandb logging to continue from pretraining step count. + Currently not used but reserved for future wandb integration. + + Returns + ------- + Tuple[nn.Module, nn.Module, DomainAdaptationTrainer] + Tuple of (adapted_model, domain_classifier, trainer). + + Raises + ------ + ValueError + If required config keys are missing. + RuntimeError + If model doesn't support return_features. + """ + if logger is None: + # Fallback to print if no logger provided (for backward compatibility) + def log_info(msg: str) -> None: + print(msg) + + logger = type("Logger", (), {"info": lambda self, msg: log_info(msg)})() + + logger.info("Starting domain adaptation on source + target...") + + # Validate inputs + if not hasattr(model, 'fno_hidden_channels'): + raise AttributeError( + "Model must have 'fno_hidden_channels' attribute. " + "Ensure model is a GINO model or wrapped with GINOWrapper." + ) + + # Create target dataset + logger.info(f"Loading target dataset from: {target_data_config.root}") + try: + target_full_dataset = FloodDatasetWithQueryPoints( + data_root=target_data_config.root, + n_history=target_data_config.n_history, + xy_file=getattr(target_data_config, "xy_file", None), + query_res=getattr(target_data_config, "query_res", [64, 64]), + static_files=getattr(target_data_config, "static_files", []), + dynamic_patterns=getattr(target_data_config, "dynamic_patterns", {}), + boundary_patterns=getattr(target_data_config, "boundary_patterns", {}), + raise_on_smaller=True, + skip_before_timestep=getattr(target_data_config, "skip_before_timestep", 0), + noise_type=getattr(target_data_config, "noise_type", "none"), + noise_std=getattr(target_data_config, "noise_std", None) + ) + except Exception as e: + raise RuntimeError(f"Failed to load target dataset from {target_data_config.root}: {e}") from e + + # Split into train/val + train_sz_target = int(0.9 * len(target_full_dataset)) + target_train_raw, target_val_raw = random_split( + target_full_dataset, + [train_sz_target, len(target_full_dataset) - train_sz_target] + ) + logger.info(f"Target domain: total={len(target_full_dataset)}, train={train_sz_target}, val={len(target_val_raw)}") + + # Move normalizers to CPU temporarily + for nm in normalizers.values(): + nm.to('cpu') + + # Collect and normalize target training data + logger.info("Collecting and normalizing target training data...") + geom_t_tr, static_t_tr, boundary_t_tr, dyn_t_tr, tgt_t_tr = collect_all_fields(target_train_raw, True) + _, big_target_train = stack_and_fit_transform( + geom_t_tr, static_t_tr, boundary_t_tr, dyn_t_tr, tgt_t_tr, + normalizers=normalizers, fit_normalizers=False + ) + target_train_ds = NormalizedDataset( + geometry=big_target_train["geometry"], + static=big_target_train["static"], + boundary=big_target_train["boundary"], + dynamic=big_target_train["dynamic"], + target=big_target_train["target"], + query_res=target_data_config.query_res + ) + + # Collect and normalize target validation data + logger.info("Collecting and normalizing target validation data...") + geom_t_val, static_t_val, boundary_t_val, dyn_t_val, tgt_t_val = collect_all_fields(target_val_raw, True) + _, big_target_val = stack_and_fit_transform( + geom_t_val, static_t_val, boundary_t_val, dyn_t_val, tgt_t_val, + normalizers=normalizers, fit_normalizers=False + ) + target_val_ds = NormalizedDataset( + geometry=big_target_val["geometry"], + static=big_target_val["static"], + boundary=big_target_val["boundary"], + dynamic=big_target_val["dynamic"], + target=big_target_val["target"], + query_res=target_data_config.query_res + ) + target_val_loader = DataLoader( + target_val_ds, batch_size=target_data_config.batch_size, shuffle=False + ) + + # Create domain classifier + logger.info("Creating domain classifier...") + try: + da_cfg = config.training.get("da_classifier", {}) + if not da_cfg: + raise ValueError("config.training.da_classifier is required for domain adaptation") + domain_classifier = CNNDomainClassifier( + model.fno_hidden_channels, + config.training.get("da_lambda_max", 1.0), + da_cfg, + ).to(device) + except (AttributeError, KeyError) as e: + raise ValueError( + f"Invalid domain adaptation configuration: {e}. " + "Ensure config.training.da_classifier contains 'conv_layers' and 'fc_dim'." + ) from e + + # Create optimizer and scheduler + adapt_lr = config.training.get("adapt_learning_rate", config.training.get("learning_rate", 1e-4)) + weight_decay = config.training.get("weight_decay", 1e-4) + optimizer_adapt = AdamW( + list(model.parameters()) + list(domain_classifier.parameters()), + lr=adapt_lr, + weight_decay=weight_decay, + ) + logger.info(f"Optimizer: AdamW (lr={adapt_lr}, weight_decay={weight_decay})") + scheduler_adapt = create_scheduler(optimizer_adapt, config, logger) + + # Create loss - wrap with LpLossWrapper to filter out unexpected kwargs + # Get loss type from config, default to 'l2' + def create_loss(loss_type_str, default="l2"): + """Helper function to create loss function from string.""" + loss_type_str = loss_type_str.lower() + if loss_type_str == "l1": + return LpLossWrapper(LpLoss(d=2, p=1)), "l1" + elif loss_type_str == "l2": + return LpLossWrapper(LpLoss(d=2, p=2)), "l2" + else: + if logger: + logger.warning(f"Unknown loss type '{loss_type_str}', defaulting to '{default}'") + return LpLossWrapper(LpLoss(d=2, p=2)), default + + training_loss_type = config.training.get("training_loss", "l2") + training_loss_fn, training_loss_name = create_loss(training_loss_type) + if logger: + logger.info(f"Using {training_loss_name.upper()} loss for domain adaptation training") + + # Use testing_loss for evaluation if specified, otherwise use training_loss + # Note: Currently domain adaptation uses the same loss for both training and evaluation + # but we create it here for consistency and future extensibility + testing_loss_type = config.training.get("testing_loss", training_loss_type) + eval_loss_fn, eval_loss_name = create_loss(testing_loss_type, default=training_loss_name) + if testing_loss_type.lower() != training_loss_type.lower() and logger: + logger.info(f"Note: testing_loss specified but domain adaptation currently uses training_loss for evaluation") + + # Create custom domain adaptation trainer + trainer_adapt = DomainAdaptationTrainer( + model=model, + data_processor=data_processor, + domain_classifier=domain_classifier, + device=device, + verbose=is_logger, + logger=logger, + wandb_step_offset=wandb_step_offset, + ) + trainer_adapt.eval_interval = 1 # Evaluate every epoch + + # Train with domain adaptation + save_dir = os.path.join(config.checkpoint.get("save_dir", "./checkpoints"), "adapt") + logger.info(f"Starting training... Checkpoints will be saved to: {save_dir}") + logger.info(f"Starting domain adaptation training for {config.training.get('n_epochs_adapt', 50)} epochs") + trainer_adapt.train_domain_adaptation( + src_loader=source_train_loader, + tgt_loader=DataLoader(target_train_ds, batch_size=target_data_config.batch_size, shuffle=True), + optimizer=optimizer_adapt, + scheduler=scheduler_adapt, + training_loss=training_loss_fn, + # da_class_loss_weight controls adversarial training strength + # Default 0.0 disables adversarial training (standard fine-tuning) + # Set to positive value (e.g., 0.1) to enable domain adaptation + class_loss_weight=config.training.get("da_class_loss_weight", 0.0), + adaptation_epochs=config.training.get("n_epochs_adapt", 50), + save_every=None, # Save at end only, or set to save interval + save_dir=save_dir, + resume_from_dir=config.checkpoint.get("resume_from_adapt", None), + resume_classifier_from_dir=config.checkpoint.get("resume_from_adapt", None), + val_loaders={"source_val": source_val_loader, "target_val": target_val_loader}, + ) + + return model, domain_classifier, trainer_adapt diff --git a/examples/weather/flood_modeling/flood_forecaster/training/pretraining.py b/examples/weather/flood_modeling/flood_forecaster/training/pretraining.py new file mode 100644 index 0000000000..cbac61765c --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/training/pretraining.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Pretraining module for source domain training. + +Compatible with neuralop 2.0.0 API. +""" + +import os +from typing import Optional + +import torch +from torch.utils.data import DataLoader, random_split + +from neuralop.training import AdamW +from neuralop.losses import LpLoss +from neuralop import get_model + +from physicsnemo.launch.utils.checkpoint import save_checkpoint +from training.trainer import save_model_checkpoint + +from datasets import FloodDatasetWithQueryPoints, NormalizedDataset +from data_processing import FloodGINODataProcessor, GINOWrapper, LpLossWrapper +from utils.normalization import collect_all_fields, stack_and_fit_transform +from training.trainer import NeuralOperatorTrainer + + +def create_scheduler(optimizer, config, logger=None): + r""" + Create learning rate scheduler based on config. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + Optimizer instance. + config : Any + Configuration object. + logger : Any, optional + Optional logger instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + Scheduler instance. + + Raises + ------ + ValueError + If scheduler name is unknown. + """ + scheduler_name = config.training.get("scheduler", "StepLR") + if scheduler_name == "ReduceLROnPlateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=config.training.get("gamma", 0.5), + patience=config.training.get("scheduler_patience", 5), + mode="min", + ) + elif scheduler_name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=config.training.get("scheduler_T_max", 200) + ) + elif scheduler_name == "StepLR": + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=config.training.get("step_size", 50), + gamma=config.training.get("gamma", 0.5), + ) + else: + raise ValueError(f"Unknown scheduler {scheduler_name}") + + if logger: + # Safely log debug message - PythonLogger doesn't have debug method + # Access underlying logging.Logger if available (PythonLogger wraps it) + try: + # Handle RankZeroLoggingWrapper which wraps PythonLogger + if hasattr(logger, 'obj') and hasattr(logger.obj, 'logger'): + logger.obj.logger.debug(f"Created {scheduler_name} scheduler") + # Handle direct PythonLogger + elif hasattr(logger, 'logger') and hasattr(logger.logger, 'debug'): + logger.logger.debug(f"Created {scheduler_name} scheduler") + # Fallback: try direct debug method (for loggers that support it) + elif hasattr(logger, 'debug'): + logger.debug(f"Created {scheduler_name} scheduler") + except (AttributeError, TypeError): + # Skip debug logging if not available (not critical for scheduler creation) + pass + return scheduler + + +def pretrain_model(config, device, is_logger, source_data_config, logger=None): + r""" + Pretrain model on source domain data. + + Compatible with neuralop 2.0.0 Trainer API. + + Parameters + ---------- + config : Any + Configuration object. + device : str or torch.device + Device to train on. + is_logger : bool + Whether this process is the logger. + source_data_config : Any + Source data configuration. + logger : Any, optional + Optional logger instance. + + Returns + ------- + Tuple[nn.Module, Dict[str, Any], Any] + Tuple of (model, normalizers, trainer). + """ + if logger is None: + # Fallback to print if no logger provided + def log_info(msg): + print(msg) + + def log_debug(msg): + pass + + logger = type("Logger", (), {"info": lambda self, msg: log_info(msg), "debug": lambda self, msg: log_debug(msg)})() + + logger.info("Starting pretraining on source domain...") + + # Create source dataset + logger.info(f"Loading source dataset from: {source_data_config.root}") + source_full_dataset = FloodDatasetWithQueryPoints( + data_root=source_data_config.root, + n_history=source_data_config.n_history, + xy_file=getattr(source_data_config, "xy_file", None), + query_res=getattr(source_data_config, "query_res", [64, 64]), + static_files=getattr(source_data_config, "static_files", []), + dynamic_patterns=getattr(source_data_config, "dynamic_patterns", {}), + boundary_patterns=getattr(source_data_config, "boundary_patterns", {}), + raise_on_smaller=True, + skip_before_timestep=getattr(source_data_config, "skip_before_timestep", 0), + noise_type=getattr(source_data_config, "noise_type", "none"), + noise_std=getattr(source_data_config, "noise_std", None) + ) + + # Split into train/val + train_sz_source = int(0.9 * len(source_full_dataset)) + source_train_raw, source_val_raw = random_split( + source_full_dataset, + [train_sz_source, len(source_full_dataset) - train_sz_source] + ) + logger.info(f"Source domain: total={len(source_full_dataset)}, train={train_sz_source}, val={len(source_val_raw)}") + + # Collect and normalize training data + logger.info("Collecting and normalizing training data...") + geom_s_tr, static_s_tr, boundary_s_tr, dyn_s_tr, tgt_s_tr = collect_all_fields(source_train_raw, True) + normalizers, big_source_train = stack_and_fit_transform( + geom_s_tr, static_s_tr, boundary_s_tr, dyn_s_tr, tgt_s_tr + ) + source_train_ds = NormalizedDataset( + geometry=big_source_train["geometry"], + static=big_source_train["static"], + boundary=big_source_train["boundary"], + dynamic=big_source_train["dynamic"], + target=big_source_train["target"], + query_res=source_data_config.query_res + ) + source_train_loader = DataLoader( + source_train_ds, batch_size=source_data_config.batch_size, shuffle=True + ) + + # Collect and normalize validation data + logger.info("Collecting and normalizing validation data...") + geom_s_val, static_s_val, boundary_s_val, dyn_s_val, tgt_s_val = collect_all_fields(source_val_raw, True) + _, big_source_val = stack_and_fit_transform( + geom_s_val, static_s_val, boundary_s_val, dyn_s_val, tgt_s_val, + normalizers=normalizers, fit_normalizers=False + ) + source_val_ds = NormalizedDataset( + geometry=big_source_val["geometry"], + static=big_source_val["static"], + boundary=big_source_val["boundary"], + dynamic=big_source_val["dynamic"], + target=big_source_val["target"], + query_res=source_data_config.query_res + ) + source_val_loader = DataLoader( + source_val_ds, batch_size=source_data_config.batch_size, shuffle=False + ) + + # Create model + logger.info("Creating GINO model...") + # Convert config.model to dict to avoid struct mode issues with neuralop's get_model + # neuralop's get_model tries to pop from config, which doesn't work with struct mode + # It expects config.model to exist, so we wrap it in a new OmegaConf DictConfig + # (not in struct mode) that supports both attribute and dict access + from omegaconf import OmegaConf + model_config_dict = OmegaConf.to_container(config.model, resolve=True) + + # Extract autoregressive parameter before passing to get_model (GINO doesn't accept it) + autoregressive = model_config_dict.pop("autoregressive", False) + + # Create a wrapper config that neuralop expects: {"model": {...}} + # Convert to OmegaConf DictConfig (not struct mode) so it supports attribute access + wrapper_config = OmegaConf.create({"model": model_config_dict}) + model = get_model(wrapper_config) + n_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Model created with {n_params:,} parameters") + + # Wrap model to filter out unexpected kwargs (like 'y') from Trainer + # Enable autoregressive residual connection if specified in config + model = GINOWrapper(model, autoregressive=autoregressive) + + # Create optimizer and scheduler + lr = config.training.get("learning_rate", 1e-4) + weight_decay = config.training.get("weight_decay", 1e-4) + optimizer_src = AdamW( + model.parameters(), + lr=lr, + weight_decay=weight_decay, + ) + logger.info(f"Optimizer: AdamW (lr={lr}, weight_decay={weight_decay})") + scheduler_src = create_scheduler(optimizer_src, config, logger) + + # Create loss and data processor + # Get loss type from config, default to 'l2' + def create_loss(loss_type_str, default="l2"): + """Helper function to create loss function from string.""" + loss_type_str = loss_type_str.lower() + if loss_type_str == "l1": + return LpLossWrapper(LpLoss(d=2, p=1)), "l1" + elif loss_type_str == "l2": + return LpLossWrapper(LpLoss(d=2, p=2)), "l2" + else: + logger.warning(f"Unknown loss type '{loss_type_str}', defaulting to '{default}'") + return LpLossWrapper(LpLoss(d=2, p=2)), default + + training_loss_type = config.training.get("training_loss", "l2") + training_loss_fn, training_loss_name = create_loss(training_loss_type) + logger.info(f"Using {training_loss_name.upper()} loss for training") + + # Use testing_loss for evaluation if specified, otherwise use training_loss + testing_loss_type = config.training.get("testing_loss", training_loss_type) + eval_loss_fn, eval_loss_name = create_loss(testing_loss_type, default=training_loss_name) + if testing_loss_type.lower() != training_loss_type.lower(): + logger.info(f"Using {eval_loss_name.upper()} loss for evaluation (different from training)") + data_processor = FloodGINODataProcessor( + device=device, + target_norm=normalizers.get("target", None), + inverse_test=True + ) + data_processor.wrap(model) + + # Create trainer using PhysicsNeMo-style trainer + n_epochs = config.training.get("n_epochs_source", config.training.get("n_epochs", 100)) + logger.info(f"Creating NeuralOperatorTrainer for {n_epochs} epochs...") + trainer_src = NeuralOperatorTrainer( + model=model, + n_epochs=n_epochs, + data_processor=data_processor, + device=device, + wandb_log=config.wandb.get("log", False), + verbose=is_logger, + logger=logger if hasattr(logger, 'info') else None, + ) + + # Train using neuralop 2.0.0 API + save_dir = os.path.join(config.checkpoint.get("save_dir", "./checkpoints"), "pretrain") + logger.info(f"Starting training... Checkpoints will be saved to: {save_dir}") + logger.info(f"Training samples: {len(source_train_ds)}, Validation samples: {len(source_val_ds)}") + + # Get checkpoint saving options from config + save_best = config.checkpoint.get("save_best", None) + save_every = config.checkpoint.get("save_every", None) + + trainer_src.train( + train_loader=source_train_loader, + test_loaders={"source_val": source_val_loader}, + optimizer=optimizer_src, + scheduler=scheduler_src, + training_loss=training_loss_fn, + eval_losses={eval_loss_name: eval_loss_fn}, + save_dir=save_dir, + save_best=save_best, # Save best model based on validation metric (from config) + save_every=save_every, # Save checkpoint every N epochs (from config) + resume_from_dir=config.checkpoint.get("resume_from_source", None), + ) + + # Explicitly save final pretrained model checkpoint using PhysicsNeMo checkpoint system + # Ensure directory exists before saving + os.makedirs(save_dir, exist_ok=True) + logger.info(f"Saving final pretrained model checkpoint to {save_dir}") + + # Save model using helper function that handles PyTorch submodules + # If it returns True, model was saved separately (as PyTorch model) + model_saved_separately = save_model_checkpoint( + model=model, + save_dir=save_dir, + epoch=n_epochs - 1, # Final epoch (0-indexed) + metadata={"stage": "pretrain", "final_epoch": True}, + ) + + # Save optimizer, scheduler, and metadata using PhysicsNeMo + # Include model only if it wasn't saved separately + save_checkpoint( + path=save_dir, + models=None if model_saved_separately else model, + optimizer=optimizer_src, + scheduler=scheduler_src, + scaler=None, + epoch=n_epochs - 1, + metadata={"stage": "pretrain", "final_epoch": True}, + ) + logger.info("Saved pretrained model checkpoint using PhysicsNeMo format") + + # Save normalizers to checkpoint directory + normalizers_path = os.path.join(save_dir, "normalizers.pt") + torch.save(normalizers, normalizers_path) + logger.info(f"Saved normalizers to {normalizers_path}") + + logger.info("Pretraining completed!") + return model, normalizers, trainer_src diff --git a/examples/weather/flood_modeling/flood_forecaster/training/trainer.py b/examples/weather/flood_modeling/flood_forecaster/training/trainer.py new file mode 100644 index 0000000000..e84b46dc19 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/training/trainer.py @@ -0,0 +1,1171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +PhysicsNeMo-style Trainer for neural operator training. + +This module provides a Trainer class rewritten from neuralop's Trainer to follow +PhysicsNeMo patterns and conventions. It integrates with PhysicsNeMo's checkpointing, +logging, and distributed training infrastructure. + +Key features: +- PhysicsNeMo checkpoint system (save_checkpoint/load_checkpoint) +- DistributedManager for distributed training +- PhysicsNeMo logging patterns +- Support for data processors, regularizers, and mixed precision +- Best model tracking and interval checkpointing +- Autoregressive evaluation support +""" + +from __future__ import annotations + +import sys +import warnings +from pathlib import Path +from timeit import default_timer +from typing import Any, Dict, Literal, Optional, Union + +import torch +import torch.nn as nn +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from tqdm import tqdm + +import physicsnemo +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper +from physicsnemo.launch.utils.checkpoint import load_checkpoint, save_checkpoint + +import fsspec + +# Optional wandb import +try: + import wandb + + _WANDB_AVAILABLE = True +except ImportError: + _WANDB_AVAILABLE = False + + +def _has_pytorch_submodules(model: nn.Module) -> bool: + r""" + Check if a PhysicsNeMo Module contains PyTorch submodules that would prevent saving. + + PhysicsNeMo's Module.save() doesn't support saving modules that contain + PyTorch submodules (they must be converted using Module.from_torch). + This helper detects such cases so we can save them as PyTorch models instead. + + Note: With Option 1 implementation, GINOWrapper now auto-converts PyTorch models + at initialization, so this check is mainly for backward compatibility and + other edge cases. + + Parameters + ---------- + model : nn.Module + Model to check. + + Returns + ------- + bool + True if model is a PhysicsNeMo Module containing PyTorch submodules. + """ + if not isinstance(model, physicsnemo.models.Module): + return False + + # Check if any direct submodules are PyTorch modules (not PhysicsNeMo modules) + # Skip checking inner_model of converted wrappers (they're intentionally PyTorch) + for name, child in model.named_children(): + # Skip inner_model - it's a PyTorch model wrapped by PhysicsNeMo, which is fine + if name == 'inner_model': + continue + if isinstance(child, torch.nn.Module) and not isinstance(child, physicsnemo.models.Module): + return True + return False + + +def save_model_checkpoint( + model: nn.Module, + save_dir: Union[str, Path], + epoch: int, + metadata: Optional[Dict[str, Any]] = None, + model_parallel_rank: int = 0, +) -> bool: + r""" + Save a model checkpoint, handling both PhysicsNeMo modules and wrappers with PyTorch submodules. + + This function intelligently saves models: + - Pure PhysicsNeMo modules: Returns False (caller should use save_checkpoint normally) + - Wrappers with PyTorch submodules: Saves state_dict as PyTorch model, returns True + + Parameters + ---------- + model : nn.Module + Model to save. Can be a PhysicsNeMo Module or a wrapper. + save_dir : str or Path + Directory to save checkpoint. + epoch : int + Epoch number for checkpoint filename. + metadata : Dict[str, Any], optional + Additional metadata (not used here, but kept for API consistency). + model_parallel_rank : int, optional + Model parallel rank for distributed training. Default is 0. + + Returns + ------- + bool + True if model was saved as PyTorch model (caller should skip model in save_checkpoint), + False if model should be saved via save_checkpoint normally. + """ + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + # Handle DDP-wrapped models + if isinstance(model, DDP): + model = model.module + + # Check if we need to save as PyTorch model (due to PyTorch submodules) + save_as_pytorch = _has_pytorch_submodules(model) + + if save_as_pytorch: + # Save model state_dict manually as PyTorch model + # This bypasses PhysicsNeMo's Module.save() which doesn't support PyTorch submodules + model_name = model.__class__.__name__ + + # Create filename matching PhysicsNeMo format: {model_name}.{rank}.{epoch}.pt + model_filename = f"{model_name}.{model_parallel_rank}.{epoch}.pt" + model_path = save_dir / model_filename + + # Save model state_dict + protocol = fsspec.utils.get_protocol(str(save_dir)) + fs = fsspec.filesystem(protocol) + with fs.open(str(model_path), "wb") as fp: + torch.save(model.state_dict(), fp) + return True # Indicate model was saved separately + + return False # Model should be saved via save_checkpoint + + +class NeuralOperatorTrainer: + r""" + A Trainer class for neural operators following PhysicsNeMo patterns. + + This trainer provides a comprehensive training loop for neural operator models + with support for: + - Multiple evaluation loaders with different metrics + - Autoregressive evaluation modes + - Best model checkpointing based on validation metrics + - Interval-based checkpointing + - Mixed precision training + - Data preprocessing/postprocessing via data processors + - Regularizers (e.g., L1/L2 regularization) + - Distributed training via PhysicsNeMo's DistributedManager + + The trainer expects datasets to provide batches as key-value dictionaries, + e.g., ``{'x': x, 'y': y}``, that are keyed to the arguments expected by + models and losses. + + Parameters + ---------- + model : nn.Module + The neural operator model to train. + n_epochs : int + Total number of training epochs. + device : str or torch.device, optional + Device to train on. If None, uses DistributedManager.device if available, + otherwise defaults to 'cpu'. Default is None. + mixed_precision : bool, optional + Whether to use mixed precision training with torch.autocast. + Default is False. + data_processor : nn.Module, optional + Data processor module to transform data before/after model forward pass. + If provided, data is preprocessed with ``data_processor.preprocess()`` + before model forward, and postprocessed with ``data_processor.postprocess()`` + after model forward. Default is None. + eval_interval : int, optional + Frequency (in epochs) to evaluate model on validation sets. + Default is 1 (evaluate every epoch). + log_output : bool, optional + If True and wandb_log is True, log output images to wandb. + Default is False. + wandb_log : bool, optional + Whether to log results to wandb. Only logs if wandb is installed + and a wandb run is active. Default is False. + verbose : bool, optional + Whether to print training progress to stdout. Default is False. + logger : PythonLogger or RankZeroLoggingWrapper, optional + Optional logger instance. If None, creates a default logger. + Default is None. + scaler : GradScaler, optional + Optional gradient scaler for mixed precision training. If None and + mixed_precision is True, creates a new scaler. Default is None. + + Examples + -------- + >>> from neuralop import get_model + >>> from neuralop.training import AdamW + >>> from neuralop.losses import LpLoss + >>> + >>> model = get_model(config) + >>> trainer = NeuralOperatorTrainer( + ... model=model, + ... n_epochs=100, + ... device="cuda", + ... wandb_log=True, + ... verbose=True + ... ) + >>> + >>> trainer.train( + ... train_loader=train_loader, + ... test_loaders={"val": val_loader}, + ... optimizer=optimizer, + ... scheduler=scheduler, + ... training_loss=LpLoss(d=2), + ... eval_losses={"l2": LpLoss(d=2)}, + ... save_dir="./checkpoints", + ... save_best="val_l2" + ... ) + """ + + def __init__( + self, + *, + model: nn.Module, + n_epochs: int, + device: Optional[Union[str, torch.device]] = None, + mixed_precision: bool = False, + data_processor: Optional[nn.Module] = None, + eval_interval: int = 1, + log_output: bool = False, + wandb_log: bool = False, + verbose: bool = False, + logger: Optional[Union[PythonLogger, RankZeroLoggingWrapper]] = None, + scaler: Optional[GradScaler] = None, + ) -> None: + # Model and training configuration + self.model = model + self.n_epochs = n_epochs + self.eval_interval = eval_interval + self.log_output = log_output + self.verbose = verbose + self.data_processor = data_processor + + # Mixed precision configuration + self.mixed_precision = mixed_precision + if mixed_precision and scaler is None: + self.scaler = GradScaler() + else: + self.scaler = scaler + + # Device configuration + if device is None: + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + self.device = dist_manager.device if dist_manager.device else torch.device("cpu") + else: + self.device = torch.device("cpu") + elif isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + # Determine autocast device type + if isinstance(self.device, torch.device): + self.autocast_device_type = self.device.type + else: + self.autocast_device_type = "cuda" if "cuda" in str(self.device) else "cpu" + + # Wandb logging (only if available and run is active) + self.wandb_log = False + if _WANDB_AVAILABLE and wandb_log and wandb.run is not None: + self.wandb_log = True + + # Logging setup + if logger is None: + self.logger = PythonLogger(name="neural_operator_trainer") + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + self.logger = RankZeroLoggingWrapper(self.logger, dist_manager) + else: + self.logger = logger + + # Training state + self.start_epoch = 0 + self.epoch = 0 + self.optimizer: Optional[Optimizer] = None + self.scheduler: Optional[_LRScheduler] = None + self.regularizer: Optional[Any] = None + + # Checkpointing configuration + self.save_every: Optional[int] = None + self.save_best: Optional[str] = None + self.best_metric_value: float = float("inf") + + # Metrics accumulation for wandb + self.wandb_epoch_metrics: Optional[Dict[str, Any]] = None + + def train( + self, + train_loader: DataLoader, + test_loaders: Dict[str, DataLoader], + optimizer: Optimizer, + scheduler: _LRScheduler, + regularizer: Optional[Any] = None, + training_loss: Optional[Any] = None, + eval_losses: Optional[Dict[str, Any]] = None, + eval_modes: Optional[Dict[str, Literal["single_step", "autoregression"]]] = None, + save_every: Optional[int] = None, + save_best: Optional[str] = None, + save_dir: Union[str, Path] = "./checkpoints", + resume_from_dir: Optional[Union[str, Path]] = None, + max_autoregressive_steps: Optional[int] = None, + ) -> Dict[str, Any]: + r""" + Train the model on the given dataset. + + This method implements the main training loop with support for: + - Training on a training dataloader + - Evaluation on multiple test dataloaders with different metrics + - Best model checkpointing based on validation metrics + - Interval-based checkpointing + - Resuming from checkpoints + + Parameters + ---------- + train_loader : DataLoader + Training dataloader providing batches for training. + test_loaders : Dict[str, DataLoader] + Dictionary of test/validation dataloaders keyed by name. + Each loader will be evaluated with all metrics in eval_losses. + optimizer : Optimizer + Optimizer to use during training. + scheduler : _LRScheduler + Learning rate scheduler to use during training. + regularizer : Any, optional + Optional regularizer (e.g., L1/L2) to add to training loss. + Must have a ``loss`` attribute and ``reset()`` method. + Default is None. + training_loss : Any, optional + Loss function for training. Must be callable as ``loss(pred, **kwargs)``. + If None, defaults to LpLoss(d=2). Default is None. + eval_losses : Dict[str, Any], optional + Dictionary of loss functions for evaluation, keyed by loss name. + Each loss will be evaluated on all test_loaders, with metrics + named as ``{loader_name}_{loss_name}``. If None, uses training_loss + as "l2". Default is None. + eval_modes : Dict[str, Literal["single_step", "autoregression"]], optional + Optional mapping from loader name to evaluation mode. + - "single_step": Predict one input-output pair and evaluate loss. + - "autoregression": Autoregressively predict output using last step's + output as input for multiple steps. Requires data processor with + step-aware preprocess/postprocess methods. + If not provided, defaults to "single_step" for all loaders. + Default is None. + save_every : int, optional + Interval (in epochs) at which to save checkpoints. + If None, no interval-based checkpointing is performed. + Default is None. + save_best : str, optional + Metric name (format: ``{loader_name}_{loss_name}``) to monitor + for best model saving. When this metric improves, a checkpoint is saved. + Overrides save_every when set. Default is None. + save_dir : str or Path, optional + Directory to save training checkpoints. Default is "./checkpoints". + resume_from_dir : str or Path, optional + Directory containing checkpoint to resume from. If provided, loads + model, optimizer, scheduler, and regularizer states and resumes + training from the saved epoch. Default is None. + max_autoregressive_steps : int, optional + Maximum number of autoregressive steps to perform during evaluation. + Only used when eval_mode is "autoregression". If None, runs full rollout. + Default is None. + + Returns + ------- + Dict[str, Any] + Dictionary of metrics from the last validation epoch, keyed as + ``{loader_name}_{loss_name}`` for each test loader and loss combination. + + Raises + ------ + ValueError + If save_best metric name is not found in available metrics. + FileNotFoundError + If resume_from_dir is provided but checkpoint files are not found. + """ + # Store training components + self.optimizer = optimizer + self.scheduler = scheduler + self.regularizer = regularizer + + # Default training loss + if training_loss is None: + from neuralop.losses import LpLoss + + training_loss = LpLoss(d=2) + + # Warn if training loss reduces across batch dimension + if hasattr(training_loss, "reduction") and training_loss.reduction == "mean": + warnings.warn( + f"Training loss has reduction='mean'. Trainer expects losses " + f"to sum across batch dimension, not average.", + UserWarning, + stacklevel=2, + ) + + # Default evaluation losses + if eval_losses is None: + eval_losses = {"l2": training_loss} + + # Default evaluation modes + if eval_modes is None: + eval_modes = {} + + # Checkpointing configuration + self.save_every = save_every + self.save_best = save_best + + # Resume from checkpoint if provided + if resume_from_dir is not None: + self._resume_from_checkpoint(resume_from_dir) + + # Move model to device + self.model = self.model.to(self.device) + + # Setup distributed training if available + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if dist_manager.distributed: + self.model = DDP( + self.model, + device_ids=[dist_manager.local_rank], + output_device=dist_manager.local_rank, + ) + if self.verbose and dist_manager.rank == 0: + self.logger.info(f"Using distributed training (rank {dist_manager.rank})") + + # Move data processor to device + if self.data_processor is not None: + self.data_processor = self.data_processor.to(self.device) + + # Validate save_best metric exists + if self.save_best is not None: + available_metrics = [] + for loader_name in test_loaders.keys(): + for loss_name in eval_losses.keys(): + available_metrics.append(f"{loader_name}_{loss_name}") + if self.save_best not in available_metrics: + raise ValueError( + f"save_best metric '{self.save_best}' not found in available metrics. " + f"Available metrics: {available_metrics}" + ) + self.best_metric_value = float("inf") + # Best model saving overrides interval saving + self.save_every = None + + # Log training setup + if self.verbose: + self.logger.info(f"Training on {len(train_loader.dataset)} samples") + self.logger.info( + f"Testing on {[len(loader.dataset) for loader in test_loaders.values()]} samples " + f"on loaders {list(test_loaders.keys())}" + ) + + # Initialize epoch_metrics in case loop doesn't execute + epoch_metrics = {} + + # Main training loop + # Only show progress bar on rank 0 in distributed training + is_rank_zero = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + is_rank_zero = dist_manager.rank == 0 + epoch_range = range(self.start_epoch, self.n_epochs) + if is_rank_zero and self.verbose: + epoch_range = tqdm(epoch_range, desc="Training", unit="epoch") + + for epoch in epoch_range: + self.epoch = epoch + + # Train for one epoch + train_metrics = self._train_one_epoch(epoch, train_loader, training_loss) + epoch_metrics = train_metrics.copy() + + # Evaluate if at eval interval + if epoch % self.eval_interval == 0: + eval_metrics = self._evaluate_all( + epoch=epoch, + eval_losses=eval_losses, + test_loaders=test_loaders, + eval_modes=eval_modes, + max_autoregressive_steps=max_autoregressive_steps, + ) + epoch_metrics.update(eval_metrics) + + # Save best model if metric improved + if save_best is not None and eval_metrics[save_best] < self.best_metric_value: + self.best_metric_value = eval_metrics[save_best] + self._save_checkpoint(save_dir, is_best=True) + + # Save checkpoint at interval + if self.save_every is not None and epoch % self.save_every == 0: + self._save_checkpoint(save_dir, is_best=False) + + return epoch_metrics + + def _train_one_epoch( + self, epoch: int, train_loader: DataLoader, training_loss: Any + ) -> Dict[str, Any]: + r""" + Train the model for one epoch. + + Parameters + ---------- + epoch : int + Current epoch number. + train_loader : DataLoader + Training dataloader. + training_loss : Any + Training loss function. + + Returns + ------- + Dict[str, Any] + Dictionary containing training metrics: + - train_err: Average training error per batch + - avg_loss: Average loss per sample + - avg_lasso_loss: Average regularizer loss (if regularizer exists) + - epoch_train_time: Time taken for epoch + """ + self.model.train() + if self.data_processor is not None: + self.data_processor.train() + + avg_loss = 0.0 + avg_lasso_loss = 0.0 + train_err = 0.0 + n_samples = 0 + + t1 = default_timer() + + # Only show progress bar on rank 0 in distributed training + is_rank_zero = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + is_rank_zero = dist_manager.rank == 0 + loader_iter = train_loader + if is_rank_zero and self.verbose: + loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.n_epochs}", unit="batch", leave=False) + + for idx, sample in enumerate(loader_iter): + loss = self._train_one_batch(idx, sample, training_loss) + + # Track number of samples in batch + if isinstance(sample.get("y"), torch.Tensor): + n_samples += sample["y"].shape[0] + else: + n_samples += 1 + + # Backward pass with optional mixed precision + if self.mixed_precision and self.scaler is not None: + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + else: + loss.backward() + self.optimizer.step() + + train_err += loss.item() + with torch.no_grad(): + avg_loss += loss.item() + if self.regularizer is not None: + avg_lasso_loss += self.regularizer.loss + + # Update progress bar with current loss + if is_rank_zero and self.verbose and hasattr(loader_iter, 'set_postfix'): + loader_iter.set_postfix({'loss': f'{loss.item():.6f}'}) + + # Update learning rate scheduler + if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(train_err) + else: + self.scheduler.step() + + epoch_train_time = default_timer() - t1 + + # Normalize metrics + train_err /= len(train_loader) + avg_loss /= n_samples if n_samples > 0 else 1 + if self.regularizer is not None: + avg_lasso_loss /= n_samples if n_samples > 0 else 1 + else: + avg_lasso_loss = None + + # Get current learning rate + lr = None + for param_group in self.optimizer.param_groups: + lr = param_group["lr"] + break + + # Log training metrics + if self.verbose and epoch % self.eval_interval == 0: + self._log_training( + epoch=epoch, + time=epoch_train_time, + avg_loss=avg_loss, + train_err=train_err, + avg_lasso_loss=avg_lasso_loss, + lr=lr, + ) + + return { + "train_err": train_err, + "avg_loss": avg_loss, + "avg_lasso_loss": avg_lasso_loss, + "epoch_train_time": epoch_train_time, + } + + def _train_one_batch(self, idx: int, sample: Dict[str, Any], training_loss: Any) -> torch.Tensor: + r""" + Train on a single batch. + + Parameters + ---------- + idx : int + Batch index. + sample : Dict[str, Any] + Batch data dictionary. + training_loss : Any + Training loss function. + + Returns + ------- + torch.Tensor + Training loss tensor. + """ + self.optimizer.zero_grad(set_to_none=True) + if self.regularizer is not None: + self.regularizer.reset() + + # Preprocess data + if self.data_processor is not None: + sample = self.data_processor.preprocess(sample) + else: + # Move tensors to device if no processor + sample = { + k: v.to(self.device) if torch.is_tensor(v) else v + for k, v in sample.items() + } + + # Forward pass with optional mixed precision + if self.mixed_precision: + # Use autocast with device_type for newer PyTorch, fallback for older versions + if hasattr(torch.amp, 'autocast') and self.autocast_device_type == 'cuda': + # PyTorch 2.0+ with torch.amp.autocast + with torch.amp.autocast(device_type=self.autocast_device_type): + out = self.model(**sample) + else: + # Older PyTorch versions or CPU - use torch.cuda.amp.autocast (CPU will be no-op) + with autocast(): + out = self.model(**sample) + else: + out = self.model(**sample) + + # Log output shape on first batch of first epoch + if self.epoch == 0 and idx == 0 and self.verbose and isinstance(out, torch.Tensor): + self.logger.info(f"Model output shape: {out.shape}") + + # Postprocess output + if self.data_processor is not None: + out, sample = self.data_processor.postprocess(out, sample) + + # Compute loss + if self.mixed_precision: + if hasattr(torch.amp, 'autocast') and self.autocast_device_type == 'cuda': + # PyTorch 2.0+ with torch.amp.autocast + with torch.amp.autocast(device_type=self.autocast_device_type): + loss = training_loss(out, **sample) + else: + # Older PyTorch versions or CPU + with autocast(): + loss = training_loss(out, **sample) + else: + loss = training_loss(out, **sample) + + # Add regularizer loss + if self.regularizer is not None: + loss = loss + self.regularizer.loss + + return loss + + def _evaluate_all( + self, + epoch: int, + eval_losses: Dict[str, Any], + test_loaders: Dict[str, DataLoader], + eval_modes: Dict[str, Literal["single_step", "autoregression"]], + max_autoregressive_steps: Optional[int] = None, + ) -> Dict[str, Any]: + r""" + Evaluate model on all test loaders. + + Parameters + ---------- + epoch : int + Current epoch number. + eval_losses : Dict[str, Any] + Dictionary of loss functions for evaluation. + test_loaders : Dict[str, DataLoader] + Dictionary of test dataloaders. + eval_modes : Dict[str, Literal["single_step", "autoregression"]] + Evaluation mode for each loader. + max_autoregressive_steps : int, optional + Maximum autoregressive steps. + + Returns + ------- + Dict[str, Any] + Dictionary of evaluation metrics keyed as ``{loader_name}_{loss_name}``. + """ + all_metrics = {} + for loader_name, loader in test_loaders.items(): + loader_eval_mode = eval_modes.get(loader_name, "single_step") + loader_metrics = self._evaluate( + eval_losses=eval_losses, + data_loader=loader, + log_prefix=loader_name, + mode=loader_eval_mode, + max_steps=max_autoregressive_steps, + ) + all_metrics.update(loader_metrics) + + if self.verbose: + self._log_eval(epoch=epoch, eval_metrics=all_metrics) + + return all_metrics + + def _evaluate( + self, + eval_losses: Dict[str, Any], + data_loader: DataLoader, + log_prefix: str = "", + mode: Literal["single_step", "autoregression"] = "single_step", + max_steps: Optional[int] = None, + ) -> Dict[str, Any]: + r""" + Evaluate model on a single dataloader. + + Parameters + ---------- + eval_losses : Dict[str, Any] + Dictionary of loss functions. + data_loader : DataLoader + Dataloader to evaluate on. + log_prefix : str, optional + Prefix for metric names. Default is "". + mode : Literal["single_step", "autoregression"], optional + Evaluation mode. Default is "single_step". + max_steps : int, optional + Maximum steps for autoregressive mode. Default is None. + + Returns + ------- + Dict[str, Any] + Dictionary of evaluation metrics. + """ + self.model.eval() + if self.data_processor is not None: + self.data_processor.eval() + + # Initialize error tracking + errors = {f"{log_prefix}_{loss_name}": 0.0 for loss_name in eval_losses.keys()} + + # Warn if eval losses reduce across batch + for eval_loss in eval_losses.values(): + if hasattr(eval_loss, "reduction") and eval_loss.reduction == "mean": + warnings.warn( + f"Eval loss has reduction='mean'. Trainer expects losses " + f"to sum across batch dimension.", + UserWarning, + stacklevel=2, + ) + + n_samples = 0 + # Only show progress bar on rank 0 in distributed training + is_rank_zero = True + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + is_rank_zero = dist_manager.rank == 0 + loader_iter = data_loader + if is_rank_zero and self.verbose: + loader_iter = tqdm(data_loader, desc=f"Evaluating ({log_prefix})", unit="batch", leave=False) + + with torch.no_grad(): + for idx, sample in enumerate(loader_iter): + return_output = idx == len(data_loader) - 1 + + # Track samples before processing + if "y" in sample: + if isinstance(sample["y"], torch.Tensor): + n_samples += sample["y"].shape[0] + else: + n_samples += 1 + + if mode == "single_step": + eval_step_losses, outs = self._eval_one_batch( + sample, eval_losses, return_output=return_output + ) + elif mode == "autoregression": + eval_step_losses, outs = self._eval_one_batch_autoreg( + sample, + eval_losses, + return_output=return_output, + max_steps=max_steps, + ) + else: + raise ValueError(f"Unknown evaluation mode: {mode}") + + # Accumulate losses + for loss_name, val_loss in eval_step_losses.items(): + errors[f"{log_prefix}_{loss_name}"] += val_loss + + # Normalize by number of samples + for key in errors.keys(): + errors[key] /= n_samples if n_samples > 0 else 1 + + # Log outputs to wandb if requested + if self.log_output and self.wandb_log and outs is not None: + errors[f"{log_prefix}_outputs"] = wandb.Image(outs) + + return errors + + def _eval_one_batch( + self, sample: Dict[str, Any], eval_losses: Dict[str, Any], return_output: bool = False + ) -> tuple[Dict[str, float], Optional[torch.Tensor]]: + r""" + Evaluate on a single batch (single step mode). + + Parameters + ---------- + sample : Dict[str, Any] + Batch data dictionary. + eval_losses : Dict[str, Any] + Dictionary of loss functions. + return_output : bool, optional + Whether to return model outputs. Default is False. + + Returns + ------- + tuple[Dict[str, float], Optional[torch.Tensor]] + Dictionary of losses and optional model outputs. + """ + # Preprocess data + if self.data_processor is not None: + sample = self.data_processor.preprocess(sample) + else: + sample = { + k: v.to(self.device) if torch.is_tensor(v) else v + for k, v in sample.items() + } + + # Forward pass + out = self.model(**sample) + + # Postprocess output + if self.data_processor is not None: + out, sample = self.data_processor.postprocess(out, sample) + + # Compute losses + eval_step_losses = {} + for loss_name, loss_fn in eval_losses.items(): + val_loss = loss_fn(out, **sample) + eval_step_losses[loss_name] = val_loss.item() if isinstance(val_loss, torch.Tensor) else val_loss + + if return_output: + return eval_step_losses, out + else: + return eval_step_losses, None + + def _eval_one_batch_autoreg( + self, + sample: Dict[str, Any], + eval_losses: Dict[str, Any], + return_output: bool = False, + max_steps: Optional[int] = None, + ) -> tuple[Dict[str, float], Optional[torch.Tensor]]: + r""" + Evaluate on a single batch (autoregressive mode). + + Parameters + ---------- + sample : Dict[str, Any] + Batch data dictionary. + eval_losses : Dict[str, Any] + Dictionary of loss functions. + return_output : bool, optional + Whether to return model outputs. Default is False. + max_steps : int, optional + Maximum number of autoregressive steps. Default is None. + + Returns + ------- + tuple[Dict[str, float], Optional[torch.Tensor]] + Dictionary of losses and optional model outputs. + """ + eval_step_losses = {loss_name: 0.0 for loss_name in eval_losses.keys()} + t = 0 + max_steps = max_steps if max_steps is not None else float("inf") + final_out = None + + while sample is not None and t < max_steps: + # Preprocess data with step index + if self.data_processor is not None: + sample = self.data_processor.preprocess(sample, step=t) + else: + sample = { + k: v.to(self.device) if torch.is_tensor(v) else v + for k, v in sample.items() + } + + if sample is None: + break + + # Forward pass + out = self.model(**sample) + + # Postprocess output with step index + if self.data_processor is not None: + out, sample = self.data_processor.postprocess(out, sample, step=t) + + # Accumulate losses + for loss_name, loss_fn in eval_losses.items(): + step_loss = loss_fn(out, **sample) + step_loss_val = step_loss.item() if isinstance(step_loss, torch.Tensor) else step_loss + eval_step_losses[loss_name] += step_loss_val + + final_out = out + t += 1 + + # Average over steps + if t > 0: + for loss_name in eval_step_losses.keys(): + eval_step_losses[loss_name] /= t + + if return_output: + return eval_step_losses, final_out + else: + return eval_step_losses, None + + def _log_training( + self, + epoch: int, + time: float, + avg_loss: float, + train_err: float, + avg_lasso_loss: Optional[float] = None, + lr: Optional[float] = None, + ) -> None: + r""" + Log training metrics. + + Parameters + ---------- + epoch : int + Current epoch. + time : float + Training time for epoch. + avg_loss : float + Average loss per sample. + train_err : float + Training error per batch. + avg_lasso_loss : float, optional + Average regularizer loss. + lr : float, optional + Current learning rate. + """ + msg = f"[Epoch {epoch}] time={time:.2f}s, " + msg += f"avg_loss={avg_loss:.4f}, " + msg += f"train_err={train_err:.4f}" + if avg_lasso_loss is not None: + msg += f", avg_lasso={avg_lasso_loss:.4f}" + if lr is not None: + msg += f", lr={lr:.6f}" + + self.logger.info(msg) + + # Log to wandb + if self.wandb_log: + values_to_log = { + "train_err": train_err, + "time": time, + "avg_loss": avg_loss, + "lr": lr, + } + if avg_lasso_loss is not None: + values_to_log["avg_lasso_loss"] = avg_lasso_loss + wandb.log(data=values_to_log, step=epoch + 1, commit=False) + + def _log_eval(self, epoch: int, eval_metrics: Dict[str, Any]) -> None: + r""" + Log evaluation metrics. + + Parameters + ---------- + epoch : int + Current epoch. + eval_metrics : Dict[str, Any] + Dictionary of evaluation metrics. + """ + msg = "Eval: " + values_to_log = {} + for metric, value in eval_metrics.items(): + if isinstance(value, (float, int)) or (isinstance(value, torch.Tensor) and value.numel() == 1): + val = float(value.item() if isinstance(value, torch.Tensor) else value) + msg += f"{metric}={val:.4f}, " + if self.wandb_log: + values_to_log[metric] = val + + msg = msg.rstrip(", ") + self.logger.info(msg) + + # Log to wandb + if self.wandb_log: + wandb.log(data=values_to_log, step=epoch + 1, commit=True) + + def _save_checkpoint(self, save_dir: Union[str, Path], is_best: bool = False) -> None: + r""" + Save training checkpoint using PhysicsNeMo checkpoint system. + + This method handles both pure PhysicsNeMo modules and wrapper modules + that contain PyTorch submodules (like GINOWrapper). For wrappers with + PyTorch submodules, it saves the model's state_dict as a PyTorch model + instead of using PhysicsNeMo's Module.save(). + + Parameters + ---------- + save_dir : str or Path + Directory to save checkpoint. + is_best : bool, optional + Whether this is the best model checkpoint. Default is False. + """ + # Only save on rank 0 in distributed training + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if dist_manager.rank != 0: + return + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + # Prepare metadata + metadata = { + "epoch": self.epoch, + "is_best": is_best, + "best_metric_value": self.best_metric_value if is_best else None, + } + + # Determine model parallel rank (for distributed training compatibility) + model_parallel_rank = 0 + if DistributedManager.is_initialized(): + dist_manager = DistributedManager() + if "model_parallel" in dist_manager.group_names: + model_parallel_rank = dist_manager.group_rank("model_parallel") + + # Use actual epoch number for checkpoint filename + # Best model is tracked via metadata, not filename + save_epoch = self.epoch + + # Handle model saving separately if it contains PyTorch submodules + model_to_save = self.model + if isinstance(self.model, DDP): + model_to_save = self.model.module + + save_as_pytorch = _has_pytorch_submodules(model_to_save) + + if save_as_pytorch: + # Save model state_dict manually as PyTorch model + model_name = model_to_save.__class__.__name__ + model_filename = f"{model_name}.{model_parallel_rank}.{save_epoch}.pt" + model_path = save_dir / model_filename + + protocol = fsspec.utils.get_protocol(str(save_dir)) + fs = fsspec.filesystem(protocol) + with fs.open(str(model_path), "wb") as fp: + torch.save(model_to_save.state_dict(), fp) + + if self.verbose: + self.logger.info(f"Saved model state_dict as PyTorch model: {model_path}") + + # Save training state (optimizer, scheduler, etc.) using PhysicsNeMo + # Include model only if it's not a PyTorch wrapper + save_checkpoint( + path=str(save_dir), + models=None if save_as_pytorch else model_to_save, + optimizer=self.optimizer, + scheduler=self.scheduler, + scaler=self.scaler, + epoch=save_epoch, + metadata=metadata, + ) + + if self.verbose: + checkpoint_type = "best model" if is_best else "checkpoint" + self.logger.info(f"Saved {checkpoint_type} to {save_dir} (epoch {self.epoch})") + + def _resume_from_checkpoint(self, resume_dir: Union[str, Path]) -> None: + r""" + Resume training from checkpoint. + + Parameters + ---------- + resume_dir : str or Path + Directory containing checkpoint. + + Raises + ------ + FileNotFoundError + If checkpoint directory or files are not found. + """ + resume_dir = Path(resume_dir) + if not resume_dir.exists(): + raise FileNotFoundError(f"Checkpoint directory not found: {resume_dir}") + + # Load checkpoint using PhysicsNeMo system + # Load latest checkpoint (epoch=None loads most recent) + metadata_dict = {} + resume_epoch = load_checkpoint( + path=str(resume_dir), + models=self.model, + optimizer=self.optimizer, + scheduler=self.scheduler, + scaler=self.scaler, + epoch=None, # Load latest + metadata_dict=metadata_dict, + device=self.device, + ) + + # Update training state + if resume_epoch is not None and resume_epoch > self.start_epoch: + self.start_epoch = resume_epoch + 1 # Resume from next epoch + if self.verbose: + self.logger.info(f"Resuming training from epoch {resume_epoch}") + + # Extract best metric value if available + if "best_metric_value" in metadata_dict: + self.best_metric_value = metadata_dict["best_metric_value"] + diff --git a/examples/weather/flood_modeling/flood_forecaster/utils/__init__.py b/examples/weather/flood_modeling/flood_forecaster/utils/__init__.py new file mode 100644 index 0000000000..7826f56923 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/utils/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Utility modules for flood prediction. + +This module provides normalization utilities specific to FloodForecaster. +For logging, use physicsnemo.launch.logging (PythonLogger, RankZeroLoggingWrapper). +For configuration, use Hydra (already integrated in train.py and inference.py). +""" + +from .normalization import ( + collect_all_fields, + stack_and_fit_transform, + transform_with_existing_normalizers, +) + +__all__ = [ + "collect_all_fields", + "stack_and_fit_transform", + "transform_with_existing_normalizers", +] + diff --git a/examples/weather/flood_modeling/flood_forecaster/utils/normalization.py b/examples/weather/flood_modeling/flood_forecaster/utils/normalization.py new file mode 100644 index 0000000000..362b6497c1 --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/utils/normalization.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Normalization utilities for flood prediction datasets. +""" + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from neuralop.data.transforms.normalizers import UnitGaussianNormalizer + + +def collect_all_fields( + dataset, + expect_target: bool = True +) -> Union[ + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]]], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[Optional[torch.Tensor]], List[torch.Tensor]] +]: + r""" + Collect all fields from a dataset into lists. + + Parameters + ---------- + dataset : Dataset + Dataset to collect fields from. + expect_target : bool, optional, default=True + Whether to expect target field. + + Returns + ------- + Tuple[List[torch.Tensor], ...] + Tuple of lists: (geometry, static, boundary, dynamic, target, [cell_area]). + If cell_area is found, returns 6-tuple, otherwise 5-tuple. + + Raises + ------ + KeyError + If required fields are missing. + """ + geometry_list = [] + static_list = [] + boundary_list = [] + dynamic_list = [] + target_list = [] + cell_area_list = [] + + for i in range(len(dataset)): + sample = dataset[i] + # Validate required fields + required_fields = ["geometry", "static", "boundary", "dynamic"] + missing_fields = [field for field in required_fields if field not in sample] + if missing_fields: + raise KeyError(f"Sample {i} missing required fields: {missing_fields}") + + geometry_list.append(sample["geometry"]) + static_list.append(sample["static"]) + boundary_list.append(sample["boundary"]) + dynamic_list.append(sample["dynamic"]) + if expect_target: + target_list.append(sample.get("target", None)) + if "cell_area" in sample: + cell_area_list.append(sample["cell_area"]) + + # Return cell_area if it was found + if cell_area_list: + return geometry_list, static_list, boundary_list, dynamic_list, target_list, cell_area_list + else: + return geometry_list, static_list, boundary_list, dynamic_list, target_list + + +def stack_and_fit_transform( + geom_list: List[torch.Tensor], + static_list: List[torch.Tensor], + boundary_list: List[torch.Tensor], + dyn_list: List[torch.Tensor], + tgt_list: List[Optional[torch.Tensor]], + normalizers: Optional[Dict[str, UnitGaussianNormalizer]] = None, + fit_normalizers: bool = True +) -> Tuple[Dict[str, UnitGaussianNormalizer], Dict[str, torch.Tensor]]: + r""" + Stack field lists into tensors and apply normalization. + + Parameters + ---------- + geom_list : List[torch.Tensor] + List of geometry tensors. + static_list : List[torch.Tensor] + List of static feature tensors. + boundary_list : List[torch.Tensor] + List of boundary condition tensors. + dyn_list : List[torch.Tensor] + List of dynamic feature tensors. + tgt_list : List[Optional[torch.Tensor]] + List of target tensors. + normalizers : Dict[str, UnitGaussianNormalizer], optional + Dict of existing normalizers (if fit_normalizers=False). + fit_normalizers : bool, optional, default=True + Whether to fit new normalizers. + + Returns + ------- + Tuple[Dict[str, UnitGaussianNormalizer], Dict[str, torch.Tensor]] + Tuple of (normalizers dict, big_tensors dict). + + Raises + ------ + ValueError + If lists are empty or have incompatible shapes. + """ + # Filter out None values before stacking to avoid errors + geom_list = [g for g in geom_list if g is not None] if geom_list else [] + static_list = [s for s in static_list if s is not None] if static_list else [] + boundary_list = [b for b in boundary_list if b is not None] if boundary_list else [] + dyn_list = [d for d in dyn_list if d is not None] if dyn_list else [] + tgt_list = [t for t in tgt_list if t is not None] if tgt_list else [] + + geometry_big = torch.stack(geom_list, dim=0) if geom_list else None + static_big = torch.stack(static_list, dim=0) if static_list else None + boundary_big = torch.stack(boundary_list, dim=0) if boundary_list else None + dynamic_big = torch.stack(dyn_list, dim=0) if dyn_list else None + target_big = torch.stack(tgt_list, dim=0) if tgt_list else None + + if normalizers is None: + normalizers = {} + + if geometry_big is not None: + if fit_normalizers: + geometry_norm = UnitGaussianNormalizer(dim=[0, 1]) + geometry_norm.fit(geometry_big) + geometry_big = geometry_norm.transform(geometry_big) + normalizers["geometry"] = geometry_norm + else: + geometry_big = normalizers["geometry"].transform(geometry_big) + + if static_big is not None: + if fit_normalizers: + static_norm = UnitGaussianNormalizer(dim=[0, 1]) + static_norm.fit(static_big) + static_big = static_norm.transform(static_big) + normalizers["static"] = static_norm + else: + static_big = normalizers["static"].transform(static_big) + + if boundary_big is not None: + if fit_normalizers: + boundary_norm = UnitGaussianNormalizer(dim=[0, 1, 2]) + boundary_norm.fit(boundary_big) + boundary_big = boundary_norm.transform(boundary_big) + normalizers["boundary"] = boundary_norm + else: + boundary_big = normalizers["boundary"].transform(boundary_big) + + if target_big is not None: + if fit_normalizers: + target_norm = UnitGaussianNormalizer(dim=[0, 1, 2]) + target_norm.fit(target_big) + target_big = target_norm.transform(target_big) + normalizers["target"] = target_norm + else: + target_big = normalizers["target"].transform(target_big) + + if dynamic_big is not None: + if "target" in normalizers and normalizers["target"] is not None: + # Use target normalizer for dynamic fields if available + dynamic_big = normalizers["target"].transform(dynamic_big) + normalizers["dynamic"] = normalizers["target"] + elif fit_normalizers: + # Create separate normalizer for dynamic if target is unavailable + dynamic_norm = UnitGaussianNormalizer(dim=[0, 1, 2]) + dynamic_norm.fit(dynamic_big) + dynamic_big = dynamic_norm.transform(dynamic_big) + normalizers["dynamic"] = dynamic_norm + elif "dynamic" in normalizers: + # Use existing dynamic normalizer if available + dynamic_big = normalizers["dynamic"].transform(dynamic_big) + + big_tensors = { + "geometry": geometry_big, + "static": static_big, + "boundary": boundary_big, + "dynamic": dynamic_big, + "target": target_big, + } + return normalizers, big_tensors + + +def transform_with_existing_normalizers( + geom_list: List[torch.Tensor], + static_list: List[torch.Tensor], + boundary_list: List[torch.Tensor], + dyn_list: List[torch.Tensor], + normalizers: Dict[str, UnitGaussianNormalizer] +) -> Dict[str, torch.Tensor]: + r""" + Transform data lists using existing normalizers. + + Parameters + ---------- + geom_list : List[torch.Tensor] + List of geometry tensors. + static_list : List[torch.Tensor] + List of static feature tensors. + boundary_list : List[torch.Tensor] + List of boundary condition tensors. + dyn_list : List[torch.Tensor] + List of dynamic feature tensors. + normalizers : Dict[str, UnitGaussianNormalizer] + Dict of normalizers to use. + + Returns + ------- + Dict[str, torch.Tensor] + Dict of transformed tensors. + + Raises + ------ + KeyError + If required normalizers are missing. + ValueError + If lists are empty. + """ + if not normalizers: + raise ValueError("normalizers dict cannot be empty") + transformed = {} + data_map = {"geometry": geom_list, "static": static_list, "boundary": boundary_list, "dynamic": dyn_list} + + for key, data_list in data_map.items(): + if data_list and key in normalizers: + big_tensor = torch.stack(data_list, dim=0) + transformed[key] = normalizers[key].transform(big_tensor) + + return transformed + diff --git a/examples/weather/flood_modeling/flood_forecaster/utils/plotting.py b/examples/weather/flood_modeling/flood_forecaster/utils/plotting.py new file mode 100644 index 0000000000..822673d79a --- /dev/null +++ b/examples/weather/flood_modeling/flood_forecaster/utils/plotting.py @@ -0,0 +1,861 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Plotting and visualization utilities for flood prediction. +""" + +import os +import warnings + +import matplotlib as mpl +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np + +# Set matplotlib defaults +mpl.rcParams.update({ + "font.family": "serif", + "font.size": 14, + "axes.titlesize": 16, + "axes.labelsize": 14, + "legend.fontsize": 12, + "xtick.labelsize": 12, + "ytick.labelsize": 12, +}) + + +def create_rollout_animation( + geometry, + wd_gt, wd_pred, + vx_gt, vy_gt, + vx_pred, vy_pred, + run_id=None, + out_dir=".", + filename_prefix="rollout", + dt_seconds: float = 1200.0 +): + r""" + Creates an animation comparing Ground Truth and Predictions in a 3x2 grid. + + Parameters + ---------- + geometry : np.ndarray or torch.Tensor + Geometry coordinates of shape :math:`(n_{cells}, 2)`. + wd_gt : np.ndarray or torch.Tensor + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred : np.ndarray or torch.Tensor + Predicted water depth of shape :math:`(T, n_{cells})`. + vx_gt : np.ndarray or torch.Tensor + Ground truth x-velocity of shape :math:`(T, n_{cells})`. + vy_gt : np.ndarray or torch.Tensor + Ground truth y-velocity of shape :math:`(T, n_{cells})`. + vx_pred : np.ndarray or torch.Tensor + Predicted x-velocity of shape :math:`(T, n_{cells})`. + vy_pred : np.ndarray or torch.Tensor + Predicted y-velocity of shape :math:`(T, n_{cells})`. + run_id : str, optional + Run identifier for title. + out_dir : str, optional, default="." + Output directory for animation file. + filename_prefix : str, optional, default="rollout" + Prefix for output filename. + dt_seconds : float, optional, default=1200.0 + Time step size in seconds. + """ + # Convert inputs to numpy arrays + if not isinstance(geometry, np.ndarray) and hasattr(geometry, "cpu"): + geometry = geometry.cpu().numpy() + x_coords, y_coords = geometry[:, 0], geometry[:, 1] + + wd_gt, wd_pred = np.asarray(wd_gt), np.asarray(wd_pred) + vx_gt, vy_gt = np.asarray(vx_gt), np.asarray(vy_gt) + vx_pred, vy_pred = np.asarray(vx_pred), np.asarray(vy_pred) + rollout_length = wd_gt.shape[0] + + # Prepare figure with a 3x2 grid + fig, axes = plt.subplots(3, 2, figsize=(12, 16), constrained_layout=True) + fig.suptitle(f"Rollout Comparison (Run: {run_id or 'unknown'})", fontsize=20) + (ax_gt_wd, ax_pred_wd), (ax_gt_vx, ax_pred_vx), (ax_gt_vy, ax_pred_vy) = axes + + # Set Color Limits + depth_max = max(np.nanmax(wd_gt), np.nanmax(wd_pred)) + vx_abs_max = np.max([np.abs(vx_gt), np.abs(vx_pred)]) + vy_abs_max = np.max([np.abs(vy_gt), np.abs(vy_pred)]) + + # Row 1: Water Depth + sc_gt_wd = ax_gt_wd.scatter(x_coords, y_coords, c=wd_gt[0], vmin=0, vmax=depth_max, s=15, cmap='viridis') + ax_gt_wd.set_title("Ground Truth Depth", pad=10) + ax_gt_wd.axis('off') + fig.colorbar(sc_gt_wd, ax=ax_gt_wd, fraction=0.046, pad=0.04).set_label("Depth (m)") + + sc_pred_wd = ax_pred_wd.scatter(x_coords, y_coords, c=wd_pred[0], vmin=0, vmax=depth_max, s=15, cmap='viridis') + ax_pred_wd.set_title("Predicted Depth", pad=10) + ax_pred_wd.axis('off') + fig.colorbar(sc_pred_wd, ax=ax_pred_wd, fraction=0.046, pad=0.04).set_label("Depth (m)") + + # Row 2: X-Velocity (Vx) + sc_gt_vx = ax_gt_vx.scatter(x_coords, y_coords, c=vx_gt[0], vmin=-vx_abs_max, vmax=vx_abs_max, s=15, + cmap='coolwarm') + ax_gt_vx.set_title(r"Ground Truth $V_{x}$", pad=10) + ax_gt_vx.axis('off') + fig.colorbar(sc_gt_vx, ax=ax_gt_vx, fraction=0.046, pad=0.04).set_label(r"$V_{x}$ (m/s)") + + sc_pred_vx = ax_pred_vx.scatter(x_coords, y_coords, c=vx_pred[0], vmin=-vx_abs_max, vmax=vx_abs_max, s=15, + cmap='coolwarm') + ax_pred_vx.set_title(r"Predicted $V_{x}$", pad=10) + ax_pred_vx.axis('off') + fig.colorbar(sc_pred_vx, ax=ax_pred_vx, fraction=0.046, pad=0.04).set_label(r"$V_{x}$ (m/s)") + + # Row 3: Y-Velocity (Vy) + sc_gt_vy = ax_gt_vy.scatter(x_coords, y_coords, c=vy_gt[0], vmin=-vy_abs_max, vmax=vy_abs_max, s=15, + cmap='coolwarm') + ax_gt_vy.set_title(r"Ground Truth $V_{y}$", pad=10) + ax_gt_vy.axis('off') + fig.colorbar(sc_gt_vy, ax=ax_gt_vy, fraction=0.046, pad=0.04).set_label(r"$V_{y}$ (m/s)") + + sc_pred_vy = ax_pred_vy.scatter(x_coords, y_coords, c=vy_pred[0], vmin=-vy_abs_max, vmax=vy_abs_max, s=15, + cmap='coolwarm') + ax_pred_vy.set_title(r"Predicted $V_{y}$", pad=10) + ax_pred_vy.axis('off') + fig.colorbar(sc_pred_vy, ax=ax_pred_vy, fraction=0.046, pad=0.04).set_label(r"$V_{y}$ (m/s)") + + # Animation update function + def animate(frame_idx): + time_hours = (frame_idx + 1) * dt_seconds / 3600.0 + fig.suptitle(f"Rollout Comparison (Run: {run_id or 'unknown'}) - Time: {time_hours:.2f} hrs", fontsize=20) + sc_gt_wd.set_array(wd_gt[frame_idx]) + sc_pred_wd.set_array(wd_pred[frame_idx]) + sc_gt_vx.set_array(vx_gt[frame_idx]) + sc_pred_vx.set_array(vx_pred[frame_idx]) + sc_gt_vy.set_array(vy_gt[frame_idx]) + sc_pred_vy.set_array(vy_pred[frame_idx]) + return sc_gt_wd, sc_pred_wd, sc_gt_vx, sc_pred_vx, sc_gt_vy, sc_pred_vy + + ani = animation.FuncAnimation(fig, animate, frames=rollout_length, interval=200, blit=False) + + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, f"{filename_prefix}_{run_id or 'unknown'}.gif") + ani.save(out_path, writer="pillow", fps=5) + plt.close(fig) + # Logging handled by caller + + +def _r_squared(y_true, y_pred): + r"""Calculate R^2 score, handling NaNs.""" + mask = ~np.isnan(y_true) & ~np.isnan(y_pred) + if not np.any(mask): + return np.nan + y_true, y_pred = y_true[mask], y_pred[mask] + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) + return 1 - (ss_res / ss_tot) if ss_tot > 0 else 1.0 + + +def generate_publication_maps( + geometry, + wd_gt_array: np.ndarray, wd_pred_array: np.ndarray, + vx_gt_array: np.ndarray, vy_gt_array: np.ndarray, + vx_pred_array: np.ndarray, vy_pred_array: np.ndarray, + steps, + out_dir: str = ".", + run_id: str = None, + filename_prefix: str = "step" +): + r""" + Generates high-quality 3x4 comparison maps for specific timesteps. + + Columns: Ground Truth | Prediction | Absolute Error | Scatter Plot + + Parameters + ---------- + geometry : np.ndarray or torch.Tensor + Geometry coordinates of shape :math:`(n_{cells}, 2)`. + wd_gt_array : np.ndarray + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred_array : np.ndarray + Predicted water depth of shape :math:`(T, n_{cells})`. + vx_gt_array : np.ndarray + Ground truth x-velocity of shape :math:`(T, n_{cells})`. + vy_gt_array : np.ndarray + Ground truth y-velocity of shape :math:`(T, n_{cells})`. + vx_pred_array : np.ndarray + Predicted x-velocity of shape :math:`(T, n_{cells})`. + vy_pred_array : np.ndarray + Predicted y-velocity of shape :math:`(T, n_{cells})`. + steps : int or List[int] + Timestep(s) to generate maps for. + out_dir : str, optional, default="." + Output directory for maps. + run_id : str, optional + Run identifier for filename. + filename_prefix : str, optional, default="step" + Prefix for output filename. + """ + if isinstance(steps, int): + steps = [steps] + geo_np = geometry.cpu().numpy() if hasattr(geometry, "cpu") else np.asarray(geometry) + x, y = geo_np[:, 0], geo_np[:, 1] + rid = run_id or "unknown" + os.makedirs(out_dir, exist_ok=True) + plt.rc("font", family="serif", size=12) + + for t in steps: + if t < 0 or t >= wd_gt_array.shape[0]: + warnings.warn(f"Skipping invalid step {t}") + continue + + wd_gt, wd_pred = wd_gt_array[t], wd_pred_array[t] + vx_gt, vy_gt = vx_gt_array[t], vy_gt_array[t] + vx_pred, vy_pred = vx_pred_array[t], vy_pred_array[t] + err_wd, err_vx, err_vy = np.abs(wd_pred - wd_gt), np.abs(vx_pred - vx_gt), np.abs(vy_pred - vy_gt) + + dmax = max(np.nanmax(wd_gt), np.nanmax(wd_pred)) + emax_wd = np.nanmax(err_wd) + vx_abs_max = np.max([np.abs(vx_gt), np.abs(vx_pred)]) if np.any(vx_gt) or np.any(vx_pred) else 1.0 + vy_abs_max = np.max([np.abs(vy_gt), np.abs(vy_pred)]) if np.any(vy_gt) or np.any(vy_pred) else 1.0 + emax_vx, emax_vy = np.nanmax(err_vx), np.nanmax(err_vy) + + fig, axs = plt.subplots(3, 4, figsize=(24, 17), dpi=300, constrained_layout=True) + + # Populate Spatial Maps (First 3 columns) + map_panels = [ + ("(a) Ground Truth Depth", wd_gt, "viridis", 0.0, dmax, "Depth (m)"), + ("(b) Predicted Depth", wd_pred, "viridis", 0.0, dmax, "Depth (m)"), + ("(c) Depth Absolute Error", err_wd, "magma", 0.0, emax_wd, "Error (m)"), + (r"(d) Ground Truth $V_{x}$", vx_gt, "coolwarm", -vx_abs_max, vx_abs_max, r"$V_{x}$ (m/s)"), + (r"(e) Predicted $V_{x}$", vx_pred, "coolwarm", -vx_abs_max, vx_abs_max, r"$V_{x}$ (m/s)"), + (r"(f) $V_{x}$ Absolute Error", err_vx, "magma", 0.0, emax_vx, "Error (m/s)"), + (r"(g) Ground Truth $V_{y}$", vy_gt, "coolwarm", -vy_abs_max, vy_abs_max, r"$V_{y}$ (m/s)"), + (r"(h) Predicted $V_{y}$", vy_pred, "coolwarm", -vy_abs_max, vy_abs_max, r"$V_{y}$ (m/s)"), + (r"(i) $V_{y}$ Absolute Error", err_vy, "magma", 0.0, emax_vy, "Error (m/s)"), + ] + + for i, (title, data, cmap, vmin, vmax, cblabel) in enumerate(map_panels): + row, col = i // 3, i % 3 + ax = axs[row, col] + sc = ax.scatter(x, y, c=data, cmap=cmap, vmin=vmin, vmax=vmax, s=6, marker="s", linewidths=0, + rasterized=True) + ax.set_title(title, pad=8, fontsize=14) + ax.set_aspect("equal") + ax.axis("off") + cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.02) + cbar.set_label(cblabel, labelpad=10, fontsize=12) + cbar.ax.tick_params(labelsize=10) + + # Populate Scatter Plots (4th column) + scatter_data = [ + ("Depth", wd_gt, wd_pred), + (r"$V_{x}$", vx_gt, vx_pred), + (r"$V_{y}$", vy_gt, vy_pred) + ] + + for i, (var_name, gt, pred) in enumerate(scatter_data): + ax = axs[i, 3] + r2 = _r_squared(gt, pred) + ax.scatter(gt, pred, alpha=0.4, s=8, rasterized=True, c='royalblue', edgecolors='none') + lims = [min(np.nanmin(gt), np.nanmin(pred)), max(np.nanmax(gt), np.nanmax(pred))] + if lims[0] < lims[1]: + ax.plot(lims, lims, 'k--', alpha=0.8, zorder=10, label="1:1 Line") + ax.set_xlim(lims) + ax.set_ylim(lims) + + ax.set_aspect('equal', 'box') + ax.set_xlabel(f"Ground Truth {var_name}") + ax.set_ylabel(f"Predicted {var_name}") + ax.set_title(f"{var_name} Correlation\n$R^2 = {r2:.3f}$") + ax.grid(True, linestyle=':', alpha=0.7) + ax.legend(loc="upper left") + + fname = f"{filename_prefix}_{rid}_t{t}.png" + out_path = os.path.join(out_dir, fname) + fig.savefig(out_path, bbox_inches="tight", pad_inches=0.1) + plt.close(fig) + # Logging handled by caller + + +def generate_max_value_maps( + geometry, + wd_gt_array: np.ndarray, wd_pred_array: np.ndarray, + vx_gt_array: np.ndarray, vy_gt_array: np.ndarray, + vx_pred_array: np.ndarray, vy_pred_array: np.ndarray, + out_dir: str = ".", + run_id: str = None, + filename_prefix: str = "max_values" +): + r""" + Generates 3x4 comparison maps of the maximum value over time for each point. + + Columns: Ground Truth | Prediction | Absolute Error | Scatter Plot + + Parameters + ---------- + geometry : np.ndarray or torch.Tensor + Geometry coordinates of shape :math:`(n_{cells}, 2)`. + wd_gt_array : np.ndarray + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred_array : np.ndarray + Predicted water depth of shape :math:`(T, n_{cells})`. + vx_gt_array : np.ndarray + Ground truth x-velocity of shape :math:`(T, n_{cells})`. + vy_gt_array : np.ndarray + Ground truth y-velocity of shape :math:`(T, n_{cells})`. + vx_pred_array : np.ndarray + Predicted x-velocity of shape :math:`(T, n_{cells})`. + vy_pred_array : np.ndarray + Predicted y-velocity of shape :math:`(T, n_{cells})`. + out_dir : str, optional, default="." + Output directory for maps. + run_id : str, optional + Run identifier for filename. + filename_prefix : str, optional, default="max_values" + Prefix for output filename. + """ + max_wd_gt, max_wd_pred = np.max(wd_gt_array, axis=0), np.max(wd_pred_array, axis=0) + err_max_wd = np.abs(max_wd_pred - max_wd_gt) + + max_vx_gt, max_vx_pred = np.max(vx_gt_array, axis=0), np.max(vx_pred_array, axis=0) + err_max_vx = np.abs(max_vx_pred - max_vx_gt) + + max_vy_gt, max_vy_pred = np.max(vy_gt_array, axis=0), np.max(vy_pred_array, axis=0) + err_max_vy = np.abs(max_vy_pred - max_vy_gt) + + geo_np = geometry.cpu().numpy() if hasattr(geometry, "cpu") else np.asarray(geometry) + x, y = geo_np[:, 0], geo_np[:, 1] + rid = run_id or "unknown" + os.makedirs(out_dir, exist_ok=True) + plt.rc("font", family="serif", size=12) + + dmax = max(np.nanmax(max_wd_gt), np.nanmax(max_wd_pred)) + emax_wd = np.nanmax(err_max_wd) + vx_abs_max = np.max([np.abs(max_vx_gt), np.abs(max_vx_pred)]) if np.any(max_vx_gt) or np.any(max_vx_pred) else 1.0 + vy_abs_max = np.max([np.abs(max_vy_gt), np.abs(max_vy_pred)]) if np.any(max_vy_gt) or np.any(max_vy_pred) else 1.0 + emax_vx, emax_vy = np.nanmax(err_max_vx), np.nanmax(err_max_vy) + + fig, axs = plt.subplots(3, 4, figsize=(24, 17), dpi=300, constrained_layout=True) + + # Populate Spatial Maps (First 3 columns) + map_panels = [ + ("(a) Max Ground Truth Depth", max_wd_gt, "viridis", 0.0, dmax, "Depth (m)"), + ("(b) Max Predicted Depth", max_wd_pred, "viridis", 0.0, dmax, "Depth (m)"), + ("(c) Max Depth Absolute Error", err_max_wd, "magma", 0.0, emax_wd, "Error (m)"), + (r"(d) Max Ground Truth $V_{x}$", max_vx_gt, "coolwarm", -vx_abs_max, vx_abs_max, r"$V_{x}$ (m/s)"), + (r"(e) Max Predicted $V_{x}$", max_vx_pred, "coolwarm", -vx_abs_max, vx_abs_max, r"$V_{x}$ (m/s)"), + (r"(f) Max $V_{x}$ Absolute Error", err_max_vx, "magma", 0.0, emax_vx, "Error (m/s)"), + (r"(g) Max Ground Truth $V_{y}$", max_vy_gt, "coolwarm", -vy_abs_max, vy_abs_max, r"$V_{y}$ (m/s)"), + (r"(h) Max Predicted $V_{y}$", max_vy_pred, "coolwarm", -vy_abs_max, vy_abs_max, r"$V_{y}$ (m/s)"), + (r"(i) Max $V_{y}$ Absolute Error", err_max_vy, "magma", 0.0, emax_vy, "Error (m/s)"), + ] + + for i, (title, data, cmap, vmin, vmax, cblabel) in enumerate(map_panels): + row, col = i // 3, i % 3 + ax = axs[row, col] + sc = ax.scatter(x, y, c=data, cmap=cmap, vmin=vmin, vmax=vmax, s=6, marker="s", linewidths=0, rasterized=True) + ax.set_title(title, pad=8, fontsize=14) + ax.set_aspect("equal") + ax.axis("off") + cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.02) + cbar.set_label(cblabel, labelpad=10, fontsize=12) + cbar.ax.tick_params(labelsize=10) + + # Populate Scatter Plots (4th column) + scatter_data = [ + ("Max Depth", max_wd_gt, max_wd_pred), + (r"Max $V_{x}$", max_vx_gt, max_vx_pred), + (r"Max $V_{y}$", max_vy_gt, max_vy_pred) + ] + + for i, (var_name, gt, pred) in enumerate(scatter_data): + ax = axs[i, 3] + r2 = _r_squared(gt, pred) + ax.scatter(gt, pred, alpha=0.4, s=8, rasterized=True, c='royalblue', edgecolors='none') + lims = [min(np.nanmin(gt), np.nanmin(pred)), max(np.nanmax(gt), np.nanmax(pred))] + if lims[0] < lims[1]: + ax.plot(lims, lims, 'k--', alpha=0.8, zorder=10, label="1:1 Line") + ax.set_xlim(lims) + ax.set_ylim(lims) + + ax.set_aspect('equal', 'box') + ax.set_xlabel(f"Ground Truth {var_name}") + ax.set_ylabel(f"Predicted {var_name}") + ax.set_title(f"{var_name} Correlation\n$R^2 = {r2:.3f}$") + ax.grid(True, linestyle=':', alpha=0.7) + ax.legend(loc="upper left") + + fname = f"{filename_prefix}_{rid}.png" + out_path = os.path.join(out_dir, fname) + fig.savefig(out_path, bbox_inches="tight", pad_inches=0.1) + plt.close(fig) + # Logging handled by caller + + +def generate_combined_analysis_maps( + geometry, + wd_gt_array: np.ndarray, wd_pred_array: np.ndarray, + vx_gt_array: np.ndarray, vy_gt_array: np.ndarray, + vx_pred_array: np.ndarray, vy_pred_array: np.ndarray, + dt: float, + out_dir: str = ".", + run_id: str = None, + inundation_threshold: float = 0.1, +): + r""" + Calculates and plots key temporal and hazard metrics in a single 3x4 figure. + + Rows: 1. Arrival Time, 2. Inundation Duration, 3. Max Momentum Flux + Columns: Ground Truth | Predicted | Absolute Error | Scatter Plot + + Parameters + ---------- + geometry : np.ndarray or torch.Tensor + Geometry coordinates of shape :math:`(n_{cells}, 2)`. + wd_gt_array : np.ndarray + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred_array : np.ndarray + Predicted water depth of shape :math:`(T, n_{cells})`. + vx_gt_array : np.ndarray + Ground truth x-velocity of shape :math:`(T, n_{cells})`. + vy_gt_array : np.ndarray + Ground truth y-velocity of shape :math:`(T, n_{cells})`. + vx_pred_array : np.ndarray + Predicted x-velocity of shape :math:`(T, n_{cells})`. + vy_pred_array : np.ndarray + Predicted y-velocity of shape :math:`(T, n_{cells})`. + dt : float + Time step size in seconds. + out_dir : str, optional, default="." + Output directory for maps. + run_id : str, optional + Run identifier for filename. + inundation_threshold : float, optional, default=0.1 + Threshold for inundation classification (meters). + + Returns + ------- + Tuple[float, float, float, float] + Tuple of (mae_arrival, mae_duration, rmse_hv2, fhca). + """ + rid = run_id or "unknown" + os.makedirs(out_dir, exist_ok=True) + plt.rc("font", family="serif", size=12) + + # Calculate all required metrics + # Arrival Time + def calculate_arrival(arr, threshold, dt_val): # noqa: D401 + inundated_mask = arr >= threshold + arrival_times = (np.argmax(inundated_mask, axis=0)).astype(np.float64) * dt_val + never_inundated_mask = ~inundated_mask.any(axis=0) + arrival_times[never_inundated_mask] = np.nan + return arrival_times + + arrival_gt = calculate_arrival(wd_gt_array, inundation_threshold, dt) + arrival_pred = calculate_arrival(wd_pred_array, inundation_threshold, dt) + + # Inundation Duration + duration_gt = np.sum(wd_gt_array >= inundation_threshold, axis=0) * dt + duration_pred = np.sum(wd_pred_array >= inundation_threshold, axis=0) * dt + + # Maximum Momentum Flux (h*V^2) + v_gt = np.sqrt(vx_gt_array ** 2 + vy_gt_array ** 2) + v_pred = np.sqrt(vx_pred_array ** 2 + vy_pred_array ** 2) + hv2_gt, hv2_pred = wd_gt_array * (v_gt ** 2), wd_pred_array * (v_pred ** 2) + max_hv2_gt, max_hv2_pred = np.max(hv2_gt, axis=0), np.max(hv2_pred, axis=0) + + # Error Metrics + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + err_arrival = np.abs(arrival_pred - arrival_gt) + err_duration = np.abs(duration_pred - duration_gt) + err_max_hv2 = np.abs(max_hv2_pred - max_hv2_gt) + + # Scalar metrics for return + mae_arrival = np.nanmean(err_arrival) + mae_duration = np.nanmean(err_duration) + rmse_hv2 = np.sqrt(np.mean((max_hv2_pred - max_hv2_gt) ** 2)) + + # Hazard Classification (FHCA) - based on hV + hv_gt, hv_pred = wd_gt_array * v_gt, wd_pred_array * v_pred + max_hv_gt = np.max(hv_gt, axis=0) + zones = {'Low': (0, 0.5), 'Medium': (0.5, 1.5), 'High': (1.5, np.inf)} + + def classify(hv_values, zones_dict): # noqa: D401 + classes = np.zeros_like(hv_values, dtype=int) + classes[hv_values >= zones_dict['Medium'][0]] = 1 + classes[hv_values >= zones_dict['High'][0]] = 2 + return classes + + gt_class, pred_class = classify(max_hv_gt, zones), classify(np.max(hv_pred, axis=0), zones) + fhca = np.mean(gt_class == pred_class) + + # Setup Figure + geo_np = geometry.cpu().numpy() if hasattr(geometry, "cpu") else np.asarray(geometry) + x, y = geo_np[:, 0], geo_np[:, 1] + fig, axs = plt.subplots(3, 4, figsize=(24, 17), dpi=300, constrained_layout=True) + + # Populate Panels + to_hours = lambda sec: sec / 3600.0 if sec is not None else None + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + vmax_arrival = to_hours(np.nanmax([arrival_gt, arrival_pred])) + emax_arrival = to_hours(np.nanmean(err_arrival) * 2) + vmax_duration = to_hours(np.nanmax([duration_gt, duration_pred])) + emax_duration = to_hours(np.nanmean(err_duration) * 2) + vmax_hv2 = max(np.nanmax(max_hv2_gt), np.nanmax(max_hv2_pred)) + emax_hv2 = np.nanmax(err_max_hv2) + + # Panel Data Definitions + map_panels = [ + # Row 1: Arrival Time + ("Ground Truth Arrival Time", to_hours(arrival_gt), "plasma", 0.0, vmax_arrival, "Time (hours)"), + ("Predicted Arrival Time", to_hours(arrival_pred), "plasma", 0.0, vmax_arrival, "Time (hours)"), + ("Arrival Time Absolute Error", to_hours(err_arrival), "magma", 0.0, emax_arrival, "Error (hours)"), + # Row 2: Duration + ("Ground Truth Inundation Duration", to_hours(duration_gt), "cividis", 0.0, vmax_duration, "Time (hours)"), + ("Predicted Inundation Duration", to_hours(duration_pred), "cividis", 0.0, vmax_duration, "Time (hours)"), + ("Duration Absolute Error", to_hours(err_duration), "magma", 0.0, emax_duration, "Error (hours)"), + # Row 3: Momentum Flux + (r"Ground Truth max($h \cdot V^2$)", max_hv2_gt, "YlOrRd", 0.0, vmax_hv2, r"Momentum Flux ($m^3/s^2$)"), + (r"Predicted max($h \cdot V^2$)", max_hv2_pred, "YlOrRd", 0.0, vmax_hv2, r"Momentum Flux ($m^3/s^2$)"), + (r"max($h \cdot V^2$) Absolute Error", err_max_hv2, "Blues", 0.0, emax_hv2, r"Error ($m^3/s^2$)"), + ] + + scatter_panels = [ + ("Arrival Time", arrival_gt, arrival_pred, "hours"), + ("Inundation Duration", duration_gt, duration_pred, "hours"), + (r"max($h \cdot V^2$)", max_hv2_gt, max_hv2_pred, r"$m^3/s^2$") + ] + + # Plotting Loops + # Plot Maps + for i, (title, data, cmap, vmin, vmax, cblabel) in enumerate(map_panels): + row, col = i // 3, i % 3 + ax = axs[row, col] + if data is not None and np.any(data[~np.isnan(data)]) and np.nanmax(data) > 0: + sc = ax.scatter(x, y, c=data, cmap=cmap, vmin=vmin, vmax=vmax, s=6, marker="s", linewidths=0, + rasterized=True) + cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.02) + cbar.set_label(cblabel, labelpad=10, fontsize=12) + cbar.ax.tick_params(labelsize=10) + else: + ax.text(0.5, 0.5, 'No Data', ha='center', va='center', transform=ax.transAxes) + ax.set_title(title, pad=8, fontsize=14) + ax.set_aspect("equal") + ax.axis("off") + + # Plot Scatters + for i, (var_name, gt_vals, pred_vals, unit) in enumerate(scatter_panels): + ax_scatter = axs[i, 3] + valid_indices = ~np.isnan(gt_vals) & ~np.isnan(pred_vals) + gt_plot, pred_plot = gt_vals[valid_indices], pred_vals[valid_indices] + + r2 = _r_squared(gt_plot, pred_plot) + title = f"{var_name} Correlation\n$R^2 = {r2:.3f}$" + + if len(gt_plot) > 0: + plot_gt = gt_plot / 3600.0 if unit == "hours" else gt_plot + plot_pred = pred_plot / 3600.0 if unit == "hours" else pred_plot + + ax_scatter.scatter(plot_gt, plot_pred, alpha=0.5, s=10, rasterized=True, c='blue') + lims = [min(np.min(plot_gt), np.min(plot_pred)), max(np.max(plot_gt), np.max(plot_pred))] + if lims[0] < lims[1]: + ax_scatter.plot(lims, lims, 'k--', alpha=0.75, zorder=0, label="1:1 Line") + ax_scatter.set_xlim(lims) + ax_scatter.set_ylim(lims) + ax_scatter.set_aspect('equal', 'box') + ax_scatter.set_xlabel(f"Ground Truth ({unit})") + ax_scatter.set_ylabel(f"Prediction ({unit})") + ax_scatter.set_title(title, pad=8, fontsize=14) + ax_scatter.legend(loc="upper left") + ax_scatter.grid(True, linestyle=':') + else: + ax_scatter.text(0.5, 0.5, 'No valid data points', ha='center', va='center', transform=ax_scatter.transAxes) + ax_scatter.set_title(title, pad=8, fontsize=14) + + # Save and return + out_path = os.path.join(out_dir, f"combined_analysis_{rid}.png") + fig.savefig(out_path, bbox_inches="tight") + plt.close(fig) + # Logging handled by caller + + return mae_arrival, mae_duration, rmse_hv2, fhca + + +def plot_volume_conservation(wd_gt_array, wd_pred_array, cell_area, dt, out_dir, run_id): + r""" + Calculates and plots the total inundated volume over time. + + Parameters + ---------- + wd_gt_array : np.ndarray + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred_array : np.ndarray + Predicted water depth of shape :math:`(T, n_{cells})`. + cell_area : np.ndarray or torch.Tensor + Cell area of shape :math:`(n_{cells},)`. + dt : float + Time step size in seconds. + out_dir : str + Output directory for plot. + run_id : str + Run identifier for filename. + """ + if cell_area is None: + warnings.warn("Skipping volume conservation plot: Cell area not available.") + return + + # Ensure the cell_area array matches the number of cells in the simulation data. + num_cells_in_sim = wd_gt_array.shape[1] + if cell_area.shape[0] != num_cells_in_sim: + warnings.warn( + f"Cell area array shape ({cell_area.shape[0]}) does not match simulation cell count ({num_cells_in_sim}). " + f"Trimming cell area array to match. Please check input data consistency." + ) + cell_area = cell_area[:num_cells_in_sim] + + # Calculate total volume at each time step + volume_gt = np.sum(wd_gt_array * cell_area, axis=1) + volume_pred = np.sum(wd_pred_array * cell_area, axis=1) + + # Create time axis + time_hours = np.arange(len(volume_gt)) * dt / 3600.0 + + # Plotting + fig, ax = plt.subplots(figsize=(12, 7), dpi=150) + ax.plot(time_hours, volume_gt, label='Ground Truth', color='black', linestyle='-') + ax.plot(time_hours, volume_pred, label='Prediction', color='red', linestyle='--') + + ax.set_xlabel('Time (hours)', fontsize=14) + ax.set_ylabel(r'Total Volume ($m^3$)', fontsize=14) + ax.legend(fontsize=12) + ax.grid(True, linestyle=':', alpha=0.7) + + plt.tight_layout() + save_path = os.path.join(out_dir, f"Total_Volume_vs_Time_{run_id}.png") + plt.savefig(save_path) + plt.close(fig) + # Logging handled by caller + + +def plot_aggregated_scalar_metrics(scalar_metrics, out_dir): + r""" + Creates and saves a box plot summary of scalar metrics aggregated over the entire test dataset. + + Parameters + ---------- + scalar_metrics : Dict[str, List[float]] + Dictionary of scalar metrics with keys: 'h_V2_rmse', 'fhca', 'arrival_mae_hrs', 'duration_mae_hrs'. + out_dir : str + Output directory for plot. + """ + labels = { + 'h_V2_rmse': r'Max $h \cdot V^2$ RMSE' + '\n' + r'($m^3/s^2$)', 'fhca': 'FHCA', + 'arrival_mae_hrs': 'Arrival MAE\n(hours)', 'duration_mae_hrs': 'Duration MAE\n(hours)', + } + + # Prepare data for boxplots + hazard_data = [ + np.array(scalar_metrics.get('h_V2_rmse', [])), + np.array(scalar_metrics.get('fhca', [])) + ] + hazard_labels = [labels['h_V2_rmse'], labels['fhca']] + + temporal_data = [ + np.array(scalar_metrics.get('arrival_mae_hrs', [])), + np.array(scalar_metrics.get('duration_mae_hrs', [])) + ] + temporal_labels = [labels['arrival_mae_hrs'], labels['duration_mae_hrs']] + + fig, axs = plt.subplots(1, 2, figsize=(12, 7), dpi=150) + fig.suptitle("Aggregated Model Performance Across All Test Simulations", fontsize=20, y=1.0) + + # Boxplot for Hazard Metrics + bp1 = axs[0].boxplot(hazard_data, vert=True, patch_artist=True, whis=1.5, labels=hazard_labels) + axs[0].set_title('Hazard Metrics', fontsize=16) + axs[0].grid(True, linestyle='--', alpha=0.6) + + # Boxplot for Temporal MAE + bp2 = axs[1].boxplot(temporal_data, vert=True, patch_artist=True, labels=temporal_labels) + axs[1].set_title('Temporal Characteristics MAE', fontsize=16) + axs[1].set_ylabel('Mean Absolute Error (hours)') + axs[1].grid(True, linestyle='--', alpha=0.6) + + # Coloring + colors = ['lightblue', 'lightgreen'] + for patch in bp1['boxes']: + patch.set_facecolor(colors[0]) + for patch in bp2['boxes']: + patch.set_facecolor(colors[1]) + + plt.tight_layout(rect=[0, 0.03, 1, 0.95]) + save_path = os.path.join(out_dir, "rollout_scalar_metrics_boxplot.png") + plt.savefig(save_path) + plt.close(fig) + # Logging handled by caller + + +def plot_conditional_error_analysis( + wd_gt_array: np.ndarray, wd_pred_array: np.ndarray, + vx_gt_array: np.ndarray, vy_gt_array: np.ndarray, + vx_pred_array: np.ndarray, vy_pred_array: np.ndarray, + out_dir: str, + run_id: str +): + r""" + Creates and saves plots for conditional error analysis. + + 1. Absolute Depth Error vs. True Water Depth + 2. Absolute Velocity Magnitude Error vs. True Velocity Magnitude + + Parameters + ---------- + wd_gt_array : np.ndarray + Ground truth water depth of shape :math:`(T, n_{cells})`. + wd_pred_array : np.ndarray + Predicted water depth of shape :math:`(T, n_{cells})`. + vx_gt_array : np.ndarray + Ground truth x-velocity of shape :math:`(T, n_{cells})`. + vy_gt_array : np.ndarray + Ground truth y-velocity of shape :math:`(T, n_{cells})`. + vx_pred_array : np.ndarray + Predicted x-velocity of shape :math:`(T, n_{cells})`. + vy_pred_array : np.ndarray + Predicted y-velocity of shape :math:`(T, n_{cells})`. + out_dir : str + Output directory for plot. + run_id : str + Run identifier for filename. + """ + # Logging handled by caller + + # Calculate required values + wd_gt_flat = wd_gt_array.flatten() + wd_pred_flat = wd_pred_array.flatten() + + # Absolute error for water depth + err_wd_abs = np.abs(wd_pred_flat - wd_gt_flat) + + # Calculate velocity magnitudes + v_mag_gt = np.sqrt(vx_gt_array ** 2 + vy_gt_array ** 2).flatten() + v_mag_pred = np.sqrt(vx_pred_array ** 2 + vy_pred_array ** 2).flatten() + + # Absolute error for velocity magnitude + err_v_mag_abs = np.abs(v_mag_pred - v_mag_gt) + + # Create the plots + fig, axs = plt.subplots(1, 2, figsize=(16, 7), dpi=150) + + # Plot 1: Depth Error vs. True Depth + ax1 = axs[0] + mask_depth = wd_gt_flat > 0.01 + ax1.scatter(wd_gt_flat[mask_depth], err_wd_abs[mask_depth], + alpha=0.1, s=5, c='blue', rasterized=True, edgecolors='none') + ax1.set_xlabel("Ground Truth Water Depth (m)", fontsize=14) + ax1.set_ylabel("Absolute Error (m)", fontsize=14) + ax1.set_title("(a) Depth Error vs. True Depth", fontsize=16) + ax1.grid(True, linestyle=':', alpha=0.7) + ax1.set_yscale('log') + ax1.set_xscale('log') + + # Plot 2: Velocity Error vs. True Velocity + ax2 = axs[1] + mask_vel = v_mag_gt > 0.01 + ax2.scatter(v_mag_gt[mask_vel], err_v_mag_abs[mask_vel], + alpha=0.1, s=5, c='green', rasterized=True, edgecolors='none') + ax2.set_xlabel("Ground Truth Velocity Magnitude (m/s)", fontsize=14) + ax2.set_ylabel("Absolute Error (m/s)", fontsize=14) + ax2.set_title("(b) Velocity Magnitude Error vs. True Velocity", fontsize=16) + ax2.grid(True, linestyle=':', alpha=0.7) + ax2.set_yscale('log') + ax2.set_xscale('log') + + plt.tight_layout(rect=[0, 0, 1, 0.95]) + save_path = os.path.join(out_dir, f"conditional_error_analysis_{run_id}.png") + plt.savefig(save_path) + plt.close(fig) + # Logging handled by caller + + +def plot_event_magnitude_analysis( + q_peaks: list, + total_volumes: list, + avg_rmses_wd: list, + out_dir: str +): + r""" + Creates and saves two separate scatter plots correlating model error with + hydrograph characteristics. + + Parameters + ---------- + q_peaks : List[float] + List of peak discharge values. + total_volumes : List[float] + List of total volume values. + avg_rmses_wd : List[float] + List of average RMSE values for water depth. + out_dir : str + Output directory for plots. + + The function creates two plots: + 1. Overall RMSE vs. Peak Inflow (Q_peak) + 2. Overall RMSE vs. Total Inflow Volume + """ + # Logging handled by caller + os.makedirs(out_dir, exist_ok=True) + + # Convert lists to numpy arrays for easier plotting + q_peaks_arr = np.array(q_peaks) + total_volumes_arr = np.array(total_volumes) + avg_rmses_arr = np.array(avg_rmses_wd) + + # Figure 1: RMSE vs. Peak Inflow + fig1, ax1 = plt.subplots(figsize=(8, 6), dpi=150) + ax1.scatter(q_peaks_arr, avg_rmses_arr, alpha=0.7, c='coral', edgecolors='k', s=50) + ax1.set_xlabel("Peak Inflow ($Q_{peak}$, $m^3/s$)", fontsize=14) + ax1.set_ylabel("Time-Averaged Water Depth RMSE (m)", fontsize=14) + ax1.grid(True, linestyle=':', alpha=0.7) + + # Add trend line + z1 = np.polyfit(q_peaks_arr, avg_rmses_arr, 1) + p1 = np.poly1d(z1) + ax1.plot(q_peaks_arr, p1(q_peaks_arr), "k--", alpha=0.8, label=f"Trend (slope={z1[0]:.4f})") + ax1.legend() + + plt.tight_layout() + save_path1 = os.path.join(out_dir, "rmse_vs_peak_inflow.png") + plt.savefig(save_path1) + plt.close(fig1) + # Logging handled by caller + + # Figure 2: RMSE vs. Total Volume + fig2, ax2 = plt.subplots(figsize=(8, 6), dpi=150) + ax2.scatter(total_volumes_arr, avg_rmses_arr, alpha=0.7, c='deepskyblue', edgecolors='k', s=50) + ax2.set_xlabel("Total Inflow Volume ($m^3$)", fontsize=14) + ax2.set_ylabel("Time-Averaged Water Depth RMSE (m)", fontsize=14) + ax2.grid(True, linestyle=':', alpha=0.7) + + # Add trend line + z2 = np.polyfit(total_volumes_arr, avg_rmses_arr, 1) + p2 = np.poly1d(z2) + ax2.plot(total_volumes_arr, p2(total_volumes_arr), "k--", alpha=0.8, label=f"Trend (slope={z2[0]:.4g})") + ax2.legend() + + plt.tight_layout() + save_path2 = os.path.join(out_dir, "rmse_vs_total_volume.png") + plt.savefig(save_path2) + plt.close(fig2) + # Logging handled by caller + diff --git a/test/datapipes/test_flood_forecaster_datasets.py b/test/datapipes/test_flood_forecaster_datasets.py new file mode 100644 index 0000000000..51eed82864 --- /dev/null +++ b/test/datapipes/test_flood_forecaster_datasets.py @@ -0,0 +1,500 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster dataset classes. +""" + +import sys +from pathlib import Path + +import pytest +import torch +from torch.utils.data import Dataset + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +from datasets import ( + FloodDatasetWithQueryPoints, + FloodRolloutTestDatasetNew, + NormalizedDataset, + NormalizedRolloutTestDataset, +) + +from . import common + +Tensor = torch.Tensor + + +@pytest.fixture +def sample_tensors(): + """Create sample tensors for dataset.""" + n_samples = 10 + n_cells = 100 + n_history = 3 + + return { + "geometry": torch.rand(n_samples, n_cells, 2), + "static": torch.rand(n_samples, n_cells, 7), + "boundary": torch.rand(n_samples, n_history, n_cells, 1), + "dynamic": torch.rand(n_samples, n_history, n_cells, 3), + "target": torch.rand(n_samples, n_cells, 3), + } + + +@pytest.mark.parametrize("device", _DEVICES) +def test_normalized_dataset_constructor(sample_tensors, device): + """Test NormalizedDataset constructor and basic properties.""" + ds = NormalizedDataset( + geometry=sample_tensors["geometry"], + static=sample_tensors["static"], + boundary=sample_tensors["boundary"], + dynamic=sample_tensors["dynamic"], + target=sample_tensors["target"], + query_res=[8, 8], + ) + + common.check_datapipe_iterable(ds) + assert len(ds) == 10 + assert isinstance(ds, Dataset) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_normalized_dataset_getitem(sample_tensors, device): + """Test NormalizedDataset __getitem__ returns correct structure.""" + ds = NormalizedDataset( + geometry=sample_tensors["geometry"], + static=sample_tensors["static"], + boundary=sample_tensors["boundary"], + dynamic=sample_tensors["dynamic"], + target=sample_tensors["target"], + query_res=[8, 8], + ) + + sample = ds[0] + + assert isinstance(sample, dict) + assert "geometry" in sample + assert "static" in sample + assert "boundary" in sample + assert "dynamic" in sample + assert "target" in sample + assert "query_points" in sample + + # Check shapes + assert sample["geometry"].shape == (100, 2) + assert sample["static"].shape == (100, 7) + assert sample["boundary"].shape == (3, 100, 1) + assert sample["dynamic"].shape == (3, 100, 3) + assert sample["target"].shape == (100, 3) + assert sample["query_points"].shape == (8, 8, 2) + + +@pytest.mark.parametrize("query_res", [[4, 4], [8, 8], [16, 16]]) +def test_normalized_dataset_query_points(sample_tensors, query_res): + """Test query points generation for different resolutions.""" + ds = NormalizedDataset( + geometry=sample_tensors["geometry"], + static=sample_tensors["static"], + boundary=sample_tensors["boundary"], + dynamic=sample_tensors["dynamic"], + target=sample_tensors["target"], + query_res=query_res, + ) + + sample = ds[0] + query_points = sample["query_points"] + + assert query_points.shape == (*query_res, 2) + # Values should be in [0, 1] range (normalized coordinates) + assert query_points.min() >= 0 + assert query_points.max() <= 1 + + +@pytest.fixture +def rollout_samples(): + """Create sample rollout data.""" + n_samples = 5 + n_cells = 100 + n_timesteps = 20 + + return [ + { + "run_id": f"run_{i}", + "geometry": torch.rand(n_cells, 2), + "static": torch.rand(n_cells, 7), + "boundary": torch.rand(n_timesteps, n_cells, 1), + "dynamic": torch.rand(n_timesteps, n_cells, 3), + "cell_area": torch.rand(n_cells), + } + for i in range(n_samples) + ] + + +@pytest.mark.parametrize("device", _DEVICES) +def test_normalized_rollout_dataset_constructor(rollout_samples, device): + """Test NormalizedRolloutTestDataset constructor.""" + ds = NormalizedRolloutTestDataset(rollout_samples, query_res=[8, 8]) + + common.check_datapipe_iterable(ds) + assert len(ds) == 5 + assert isinstance(ds, Dataset) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_normalized_rollout_dataset_getitem(rollout_samples, device): + """Test NormalizedRolloutTestDataset __getitem__ returns correct structure.""" + ds = NormalizedRolloutTestDataset(rollout_samples, query_res=[8, 8]) + sample = ds[0] + + assert isinstance(sample, dict) + assert "run_id" in sample + assert "geometry" in sample + assert "static" in sample + assert "boundary" in sample + assert "dynamic" in sample + assert "query_points" in sample + assert "cell_area" in sample + + # Check run_id is preserved + assert sample["run_id"] == "run_0" + assert sample["cell_area"].shape == (100,) + + +# Tests for file-based dataset classes +@pytest.fixture +def temp_data_dir(tmp_path): + """Create a temporary data directory with mock data files.""" + import numpy as np + + data_dir = tmp_path / "test_data" + data_dir.mkdir() + + # Create train.txt + train_file = data_dir / "train.txt" + train_file.write_text("run_001\nrun_002\n") + + # Create XY file + n_cells = 100 + xy_data = np.random.rand(n_cells, 2).astype(np.float32) + xy_file = data_dir / "M40_XY.txt" + np.savetxt(xy_file, xy_data, delimiter="\t") + + # Create static files + static_files = ["M40_CA.txt", "M40_CE.txt", "M40_CS.txt", "M40_FA.txt", "M40_A.txt", "M40_CU.txt"] + for fname in static_files: + static_data = np.random.rand(n_cells, 1).astype(np.float32) + np.savetxt(data_dir / fname, static_data, delimiter="\t") + + # Create dynamic files for each run + n_timesteps = 20 + for run_id in ["run_001", "run_002"]: + for var in ["WD", "VX", "VY"]: + # Shape: (n_timesteps, n_cells) + dyn_data = np.random.rand(n_timesteps, n_cells).astype(np.float32) + fname = f"M40_{var}_{run_id}.txt" + np.savetxt(data_dir / fname, dyn_data, delimiter="\t") + + # Create boundary files + for run_id in ["run_001", "run_002"]: + # Shape: (n_timesteps, 1) or (n_timesteps, 2) + bc_data = np.random.rand(n_timesteps, 1).astype(np.float32) + fname = f"M40_US_InF_{run_id}.txt" + np.savetxt(data_dir / fname, bc_data, delimiter="\t") + + return data_dir + + +@pytest.fixture +def temp_rollout_dir(tmp_path): + """Create a temporary rollout data directory with mock data files.""" + import numpy as np + + data_dir = tmp_path / "rollout_data" + data_dir.mkdir() + + # Create test.txt + test_file = data_dir / "test.txt" + test_file.write_text("run_001\nrun_002\n") + + # Create XY file + n_cells = 100 + xy_data = np.random.rand(n_cells, 2).astype(np.float32) + xy_file = data_dir / "M40_XY.txt" + np.savetxt(xy_file, xy_data, delimiter="\t") + + # Create static files including cell area + static_files = ["M40_CA.txt", "M40_CE.txt", "M40_CS.txt", "M40_FA.txt", "M40_A.txt", "M40_CU.txt"] + for fname in static_files: + static_data = np.random.rand(n_cells, 1).astype(np.float32) + np.savetxt(data_dir / fname, static_data, delimiter="\t") + + # Create dynamic files for each run (need enough timesteps for rollout) + n_timesteps = 30 # Enough for n_history + rollout_length + for run_id in ["run_001", "run_002"]: + for var in ["WD", "VX", "VY"]: + dyn_data = np.random.rand(n_timesteps, n_cells).astype(np.float32) + fname = f"M40_{var}_{run_id}.txt" + np.savetxt(data_dir / fname, dyn_data, delimiter="\t") + + # Create boundary files + for run_id in ["run_001", "run_002"]: + bc_data = np.random.rand(n_timesteps, 1).astype(np.float32) + fname = f"M40_US_InF_{run_id}.txt" + np.savetxt(data_dir / fname, bc_data, delimiter="\t") + + return data_dir + + +@pytest.mark.parametrize("device", _DEVICES) +def test_flood_dataset_with_query_points_init(temp_data_dir, device): + """Test FloodDatasetWithQueryPoints initialization.""" + # Note: xy_file should not be included in static_files to avoid duplication + # The xy_file is loaded separately for geometry, while static_files are for additional features + dataset = FloodDatasetWithQueryPoints( + data_root=str(temp_data_dir), + n_history=3, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt", "M40_CE.txt"], # Exclude xy_file to avoid duplication + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + ) + + common.check_datapipe_iterable(dataset) + assert len(dataset) > 0 + assert dataset.xy_coords is not None + assert len(dataset.run_ids) == 2 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_flood_dataset_with_query_points_getitem(temp_data_dir, device): + """Test FloodDatasetWithQueryPoints __getitem__ returns correct structure.""" + dataset = FloodDatasetWithQueryPoints( + data_root=str(temp_data_dir), + n_history=3, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt", "M40_CE.txt"], + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + ) + + if len(dataset) > 0: + sample = dataset[0] + + assert isinstance(sample, dict) + assert "geometry" in sample + assert "static" in sample + assert "boundary" in sample + assert "dynamic" in sample + assert "target" in sample + assert "run_id" in sample + assert "time_index" in sample + + # Check shapes + assert sample["geometry"].shape == (100, 2) + assert sample["static"].shape == (100, 2) # 2 static files + assert sample["boundary"].shape == (3, 100, 1) # n_history, n_cells, bc_dim + assert sample["dynamic"].shape == (3, 100, 3) # n_history, n_cells, 3 channels + assert sample["target"].shape == (100, 3) # n_cells, 3 channels + + +@pytest.mark.parametrize("device", _DEVICES) +def test_flood_dataset_noise_types(temp_data_dir, device): + """Test different noise types in FloodDatasetWithQueryPoints.""" + noise_types = ["none", "only_last", "correlated", "uncorrelated", "random_walk"] + + for noise_type in noise_types: + dataset = FloodDatasetWithQueryPoints( + data_root=str(temp_data_dir), + n_history=3, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt"], + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + noise_type=noise_type, + noise_std=[0.01, 0.001, 0.001], + ) + + if len(dataset) > 0: + sample = dataset[0] + # Should not crash + assert "dynamic" in sample + + +def test_flood_dataset_missing_data_root(): + """Test FloodDatasetWithQueryPoints fails with missing data root.""" + with pytest.raises(FileNotFoundError): + FloodDatasetWithQueryPoints( + data_root="/nonexistent/path", + n_history=3, + xy_file="M40_XY.txt", + ) + + +def test_flood_dataset_missing_train_file(tmp_path): + """Test FloodDatasetWithQueryPoints fails with missing train.txt.""" + data_dir = tmp_path / "test_data" + data_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="train.txt"): + FloodDatasetWithQueryPoints( + data_root=str(data_dir), + n_history=3, + xy_file="M40_XY.txt", + ) + + +def test_flood_dataset_invalid_noise_std(temp_data_dir): + """Test FloodDatasetWithQueryPoints validates noise_std length.""" + with pytest.raises(ValueError, match="exactly 3 floats"): + FloodDatasetWithQueryPoints( + data_root=str(temp_data_dir), + n_history=3, + xy_file="M40_XY.txt", + noise_std=[0.01, 0.001], # Only 2 values + ) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_flood_rollout_test_dataset_new_init(temp_rollout_dir, device): + """Test FloodRolloutTestDatasetNew initialization.""" + dataset = FloodRolloutTestDatasetNew( + rollout_data_root=str(temp_rollout_dir), + n_history=3, + rollout_length=10, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt", "M40_CE.txt"], + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + ) + + common.check_datapipe_iterable(dataset) + assert len(dataset) > 0 + assert dataset.xy_coords is not None + assert len(dataset.valid_run_ids) > 0 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_flood_rollout_test_dataset_new_getitem(temp_rollout_dir, device): + """Test FloodRolloutTestDatasetNew __getitem__ returns correct structure.""" + dataset = FloodRolloutTestDatasetNew( + rollout_data_root=str(temp_rollout_dir), + n_history=3, + rollout_length=10, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt", "M40_CE.txt"], + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + ) + + if len(dataset) > 0: + sample = dataset[0] + + assert isinstance(sample, dict) + assert "run_id" in sample + assert "geometry" in sample + assert "static" in sample + assert "boundary" in sample + assert "dynamic" in sample + + # Check shapes + assert sample["geometry"].shape == (100, 2) + assert sample["static"].shape == (100, 2) # 2 static files + assert sample["dynamic"].shape == (30, 100, 3) # Full time series + assert sample["boundary"].shape == (30, 100, 1) # Full time series + + +def test_flood_rollout_test_dataset_missing_data_root(): + """Test FloodRolloutTestDatasetNew fails with missing data root.""" + with pytest.raises(FileNotFoundError): + FloodRolloutTestDatasetNew( + rollout_data_root="/nonexistent/path", + n_history=3, + rollout_length=10, + xy_file="M40_XY.txt", + ) + + +def test_flood_rollout_test_dataset_insufficient_timesteps(tmp_path): + """Test that runs with insufficient timesteps are filtered out.""" + import numpy as np + + data_dir = tmp_path / "rollout_data" + data_dir.mkdir() + + test_file = data_dir / "test.txt" + test_file.write_text("run_001\n") + + xy_file = data_dir / "M40_XY.txt" + np.savetxt(xy_file, np.random.rand(100, 2), delimiter="\t") + + # Create static files + for fname in ["M40_CA.txt", "M40_CE.txt"]: + np.savetxt(data_dir / fname, np.random.rand(100, 1), delimiter="\t") + + # Create dynamic files with insufficient timesteps + n_timesteps = 5 # Too few for n_history=3 + rollout_length=10 + for var in ["WD", "VX", "VY"]: + dyn_data = np.random.rand(n_timesteps, 100).astype(np.float32) + np.savetxt(data_dir / f"M40_{var}_run_001.txt", dyn_data, delimiter="\t") + + # Create boundary file + bc_data = np.random.rand(n_timesteps, 1).astype(np.float32) + np.savetxt(data_dir / "M40_US_InF_run_001.txt", bc_data, delimiter="\t") + + with pytest.raises(ValueError, match="No hydrographs have enough time steps"): + FloodRolloutTestDatasetNew( + rollout_data_root=str(data_dir), + n_history=3, + rollout_length=10, + xy_file="M40_XY.txt", + static_files=["M40_CA.txt"], + dynamic_patterns={ + "WD": "M40_WD_{}.txt", + "VX": "M40_VX_{}.txt", + "VY": "M40_VY_{}.txt", + }, + boundary_patterns={"inflow": "M40_US_InF_{}.txt"}, + ) + diff --git a/test/inference/test_flood_forecaster_inference.py b/test/inference/test_flood_forecaster_inference.py new file mode 100644 index 0000000000..71c29e685b --- /dev/null +++ b/test/inference/test_flood_forecaster_inference.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster inference module. + +This module tests the inference pipeline including checkpoint loading, +rollout prediction, and metric computation. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch +import tempfile +import shutil + +import pytest +import torch +import torch.nn as nn +import numpy as np + +import physicsnemo + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +# Import rollout functions - need to set up module structure first +import importlib.util + +# Set up inference package +if "inference" not in sys.modules: + import types + inference_pkg = types.ModuleType("inference") + sys.modules["inference"] = inference_pkg + +# Import rollout module +spec = importlib.util.spec_from_file_location("inference.rollout", _examples_dir / "inference" / "rollout.py") +rollout_module = importlib.util.module_from_spec(spec) +sys.modules["inference.rollout"] = rollout_module +spec.loader.exec_module(rollout_module) + +# Make rollout available as inference.rollout attribute +sys.modules["inference"].rollout = rollout_module + +from inference.rollout import compute_csi, rollout_prediction +from data_processing import GINOWrapper + + +@pytest.fixture +def mock_gino_model(): + """Create a mock GINO model.""" + model = MagicMock(spec=nn.Module) + model.fno_hidden_channels = 64 + model.out_channels = 3 + model.gno_coord_dim = 2 + model.in_coord_dim_reverse_order = [2, 3] # For 2D: permute dims 2,3 (H, W) + model.out_gno_tanh = None # No tanh activation + + # Mock latent_embedding (required by GINOWrapper.forward) + # Use Identity which just returns the input + model.latent_embedding = nn.Identity() + + # Mock gno_in method (required by GINOWrapper.forward) + def mock_gno_in(y, x, f_y=None): + # Return tensor with shape (n_points, channels) for flattened queries + # GINOWrapper expects gno_in to return (n_points, channels) which gets reshaped to (batch_size, H, W, channels) + n_points = x.shape[0] if x.ndim == 2 else x.shape[1] + return torch.rand(n_points, 64) # (n_points, channels) + + model.gno_in = MagicMock(side_effect=mock_gno_in) + + # Mock gno_out method (required by GINOWrapper.forward) + def mock_gno_out(y, x, f_y): + # f_y is (B, n_latent, channels), x is output queries (n_out, coord_dim) + # Return (B, channels, n_out) - will be permuted to (B, n_out, channels) in GINOWrapper + batch_size = f_y.shape[0] + n_out = x.shape[0] + return torch.rand(batch_size, 64, n_out) # (B, channels, n_out) + + model.gno_out = MagicMock(side_effect=mock_gno_out) + + # Mock projection method (required by GINOWrapper.forward) + def mock_projection(x): + # x is (B, n_out, channels) after permute in GINOWrapper + batch_size = x.shape[0] + n_out = x.shape[1] + return torch.rand(batch_size, n_out, 3) # (B, n_out, out_channels) + + model.projection = MagicMock(side_effect=mock_projection) + + # Mock forward to return predictions on correct device + def mock_forward(**kwargs): + batch_size = kwargs.get("x", torch.rand(1, 100, 10)).shape[0] + n_out = kwargs.get("output_queries", torch.rand(100, 2)).shape[0] + # Get device from x if available, otherwise use CPU + x_tensor = kwargs.get("x") + if x_tensor is not None: + device = x_tensor.device + else: + device = torch.device("cpu") + return torch.rand(batch_size, n_out, 3, device=device) + + model.forward = MagicMock(side_effect=mock_forward) + return model + + +@pytest.fixture +def mock_normalizer(): + """Create a mock normalizer.""" + norm = MagicMock() + norm.inverse_transform = MagicMock(side_effect=lambda x: x * 2.0) # Simple transform + norm.to = MagicMock(return_value=norm) + return norm + + +@pytest.fixture +def mock_rollout_dataset(): + """Create a mock rollout dataset.""" + class MockRolloutDataset: + def __init__(self): + self.valid_run_ids = ["run_001", "run_002"] + + def __len__(self): + return 2 + + def __getitem__(self, idx): + return { + "run_id": self.valid_run_ids[idx], + "geometry": torch.rand(100, 2), + "static": torch.rand(100, 7), + "boundary": torch.rand(20, 100, 1), + "dynamic": torch.rand(20, 100, 3), + "cell_area": torch.rand(100), + } + + return MockRolloutDataset() + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_perfect_match(device): + """Test CSI computation with perfect match.""" + pred = np.array([1.0, 1.0, 0.0, 0.0]) + gt = np.array([1.0, 1.0, 0.0, 0.0]) + threshold = 0.5 + + csi = compute_csi(threshold, pred, gt) + assert csi == 1.0 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_no_match(device): + """Test CSI computation with no match.""" + pred = np.array([1.0, 1.0, 0.0, 0.0]) + gt = np.array([0.0, 0.0, 1.0, 1.0]) + threshold = 0.5 + + csi = compute_csi(threshold, pred, gt) + assert csi == 0.0 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_partial_match(device): + """Test CSI computation with partial match.""" + pred = np.array([1.0, 1.0, 1.0, 0.0]) + gt = np.array([1.0, 1.0, 0.0, 0.0]) + threshold = 0.5 + + csi = compute_csi(threshold, pred, gt) + # TP=2, FP=1, FN=0, CSI = 2/(2+1+0) = 2/3 + assert abs(csi - 2.0/3.0) < 1e-6 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_all_zeros(device): + """Test CSI computation with all zeros (no events).""" + pred = np.array([0.0, 0.0, 0.0, 0.0]) + gt = np.array([0.0, 0.0, 0.0, 0.0]) + threshold = 0.5 + + csi = compute_csi(threshold, pred, gt) + # When TP+FP+FN=0, CSI should be 1.0 (perfect match, no events) + assert csi == 1.0 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_different_thresholds(device): + """Test CSI computation with different thresholds.""" + pred = np.array([0.3, 0.6, 0.2, 0.8]) + gt = np.array([0.1, 0.7, 0.4, 0.9]) + + # Low threshold + csi_low = compute_csi(0.05, pred, gt) + # High threshold + csi_high = compute_csi(0.3, pred, gt) + + # Both should be valid (0-1 range) + assert 0.0 <= csi_low <= 1.0 + assert 0.0 <= csi_high <= 1.0 + # Different thresholds should give different results + assert csi_low != csi_high or (csi_low == 1.0 and csi_high == 1.0) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_compute_csi_edge_cases(device): + """Test CSI computation with various edge cases.""" + # Single element + csi = compute_csi(0.5, np.array([1.0]), np.array([1.0])) + assert csi == 1.0 + + # All predictions above threshold, no ground truth + csi = compute_csi(0.5, np.array([1.0, 1.0]), np.array([0.0, 0.0])) + assert csi == 0.0 + + # All ground truth above threshold, no predictions + csi = compute_csi(0.5, np.array([0.0, 0.0]), np.array([1.0, 1.0])) + assert csi == 0.0 + + diff --git a/test/models/data/cnn_domain_classifier_custom_v1.0.pth b/test/models/data/cnn_domain_classifier_custom_v1.0.pth new file mode 100644 index 0000000000..fdd1f5f1a6 Binary files /dev/null and b/test/models/data/cnn_domain_classifier_custom_v1.0.pth differ diff --git a/test/models/data/cnn_domain_classifier_default_v1.0.pth b/test/models/data/cnn_domain_classifier_default_v1.0.pth new file mode 100644 index 0000000000..1bb41fff6d Binary files /dev/null and b/test/models/data/cnn_domain_classifier_default_v1.0.pth differ diff --git a/test/models/data/gradient_reversal_custom_v1.0.pth b/test/models/data/gradient_reversal_custom_v1.0.pth new file mode 100644 index 0000000000..4d3a8aabb9 Binary files /dev/null and b/test/models/data/gradient_reversal_custom_v1.0.pth differ diff --git a/test/models/data/gradient_reversal_default_v1.0.pth b/test/models/data/gradient_reversal_default_v1.0.pth new file mode 100644 index 0000000000..97f288f078 Binary files /dev/null and b/test/models/data/gradient_reversal_default_v1.0.pth differ diff --git a/test/models/test_flood_forecaster_data_processing.py b/test/models/test_flood_forecaster_data_processing.py new file mode 100644 index 0000000000..bef189bfb9 --- /dev/null +++ b/test/models/test_flood_forecaster_data_processing.py @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster data processing modules. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +import physicsnemo + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +from data_processing import FloodGINODataProcessor, GINOWrapper, LpLossWrapper + +from . import common + + +# Define MockGINOModelForCheckpoint at module level so it can be properly loaded from checkpoint +# Note: Name doesn't start with "Test" to avoid pytest collection +# Register it in the model registry to ensure it can be found when loading checkpoints +class MockGINOModelForCheckpoint(physicsnemo.Module): + """Simple test GINO model for checkpoint testing.""" + def __init__(self): + super().__init__(meta=physicsnemo.ModelMetaData(name="MockGINOForCheckpoint")) + self.fno_hidden_channels = 64 + self.out_channels = 3 + self.gno_coord_dim = 2 + self.latent_feature_channels = None + self.in_coord_dim_reverse_order = [2, 3] + self.out_gno_tanh = None + + # Create minimal layers for the model to work + self.gno_in = nn.Linear(2, 64) + self.gno_out = nn.Linear(64, 64) + self.projection = nn.Linear(64, 3) + self.latent_embedding = nn.Identity() + +# Register MockGINOModelForCheckpoint in the model registry so it can be loaded from checkpoints +try: + registry = physicsnemo.registry.ModelRegistry() + if "MockGINOModelForCheckpoint" not in registry.list_models(): + registry.register(MockGINOModelForCheckpoint, "MockGINOModelForCheckpoint") +except (ValueError, AttributeError): + # Already registered or registry issue - continue + pass + + +@pytest.fixture +def sample_dict(): + """Create sample dictionary for preprocessing.""" + batch_size = 2 + num_cells = 100 + n_history = 3 + # query_points should be (B, H, W, 2) or (H, W, 2) for latent queries + # Using a simple grid: (8, 8, 2) for 2D + H, W = 8, 8 + + return { + "geometry": torch.rand(batch_size, num_cells, 2), + "static": torch.rand(batch_size, num_cells, 7), + "boundary": torch.rand(batch_size, n_history, num_cells, 1), + "dynamic": torch.rand(batch_size, n_history, num_cells, 3), + "target": torch.rand(batch_size, num_cells, 3), + "query_points": torch.rand(batch_size, H, W, 2), # Required for preprocessing + } + + +@pytest.fixture +def mock_gino_model(): + """Create a mock GINO model with internal components.""" + model = MagicMock(spec=nn.Module) + model.fno_hidden_channels = 64 + model.out_channels = 3 + model.gno_coord_dim = 2 + model.latent_feature_channels = None + model.in_coord_dim_reverse_order = [2, 3] # For 2D: [2, 3] means permute dims 2,3 + model.out_gno_tanh = None + + # Mock internal components - gno_in returns flattened output + def mock_gno_in(y, x, f_y=None): + # Returns (n_points, channels) for flattened queries + # GINOWrapper reshapes this to (batch_size, H, W, channels) + n_points = x.shape[0] + return torch.rand(n_points, 64) # (n_points, channels) + + model.gno_in = MagicMock(side_effect=mock_gno_in) + + # Mock gno_out - returns (B, channels, n_out) after permute + def mock_gno_out(y, x, f_y): + # f_y is (B, n_latent, channels) + batch_size = f_y.shape[0] + n_out = x.shape[0] + return torch.rand(batch_size, 64, n_out) # (B, channels, n_out) + + model.gno_out = MagicMock(side_effect=mock_gno_out) + + # Mock projection - returns (B, n_out, out_channels) + def mock_projection(x): + # x is (B, n_out, channels) after permute in GINOWrapper + batch_size = x.shape[0] + n_out = x.shape[1] + return torch.rand(batch_size, n_out, 3) # (B, n_out, out_channels) + + model.projection = MagicMock(side_effect=mock_projection) + + # Mock latent_embedding to return a tensor + def mock_latent_embedding(in_p, ada_in=None): + # in_p shape: (B, H, W, channels) + # Return shape: (B, channels, H, W) for 2D + batch_size = in_p.shape[0] + grid_shape = in_p.shape[1:-1] # (H, W) + return torch.rand(batch_size, 64, *grid_shape) + + model.latent_embedding = mock_latent_embedding + + return model + + +@pytest.mark.parametrize("device", _DEVICES) +def test_data_processor_init(device): + """Test FloodGINODataProcessor initialization.""" + processor = FloodGINODataProcessor(device=device) + # device property returns torch.device, so compare string representations + assert str(processor.device) == str(torch.device(device)) + assert processor.target_norm is None + assert processor.inverse_test is True + assert processor.model is None + assert isinstance(processor, nn.Module) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_data_processor_preprocess(sample_dict, device): + """Test preprocessing with batched input.""" + processor = FloodGINODataProcessor(device=device) + result = processor.preprocess(sample_dict) + + # Check required keys exist + assert "input_geom" in result + assert "latent_queries" in result + assert "output_queries" in result + assert "x" in result + assert "y" in result + + # Check geometry has no batch dim (shared) + assert result["input_geom"].dim() == 2 # (n_cells, 2) + assert result["latent_queries"].dim() == 3 # (H, W, 2) + assert result["output_queries"].dim() == 2 # (n_cells, 2) + + # Check x has batch dim + assert result["x"].dim() == 3 # (B, n_cells, features) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_data_processor_postprocess(device): + """Test postprocessing in training and eval modes.""" + mock_norm = MagicMock() + mock_norm.inverse_transform = MagicMock(return_value=torch.ones(2, 100, 3) * 2) + + processor = FloodGINODataProcessor(device=device, target_norm=mock_norm, inverse_test=True) + out = torch.ones(2, 100, 3) + sample = {"y": torch.ones(2, 100, 3)} + + # Training mode: no inverse transform + processor.train() + result_out, _ = processor.postprocess(out, sample) + mock_norm.inverse_transform.assert_not_called() + + # Eval mode: applies inverse transform + processor.eval() + result_out, _ = processor.postprocess(out, sample) + assert torch.allclose(result_out, torch.ones(2, 100, 3) * 2) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_ginowrapper_init(mock_gino_model, device): + """Test GINOWrapper initialization.""" + wrapper = GINOWrapper(mock_gino_model) + # After conversion, gino is wrapped in CustomPhysicsNeMoWrapper + # Check that the inner model is accessible + assert hasattr(wrapper.gino, 'inner_model') or wrapper.gino == mock_gino_model + assert isinstance(wrapper, nn.Module) + assert wrapper.fno_hidden_channels == 64 + assert wrapper.autoregressive is False # Default value + + +@pytest.mark.parametrize("device", _DEVICES) +def test_ginowrapper_forward(mock_gino_model, device): + """Test GINOWrapper forward pass with kwargs filtering and autoregressive mode.""" + wrapper = GINOWrapper(mock_gino_model, autoregressive=True) + + input_geom = torch.rand(1, 100, 2) + latent_queries = torch.rand(1, 8, 8, 2) + output_queries = torch.rand(1, 100, 2) + x = torch.rand(1, 100, 10) + + # Mock the internal GNO components + mock_gino_model.gno_in.return_value = torch.rand(8 * 8, 64) # (n_points, channels) + mock_gino_model.gno_out.return_value = torch.rand(1, 64, 100) # (B, channels, n_out) + def mock_projection(x): + # x is (B, n_out, channels), should return (B, n_out, out_channels) + return torch.zeros(x.shape[0], x.shape[1], 3) # (B, n_out, out_channels) + mock_gino_model.projection.side_effect = mock_projection + + # Test forward with extra kwargs (should be filtered) + result = wrapper( + input_geom=input_geom, + latent_queries=latent_queries, + output_queries=output_queries, + x=x, + y=torch.rand(1, 100, 3), # Should be filtered out + extra_arg="should be ignored", + ) + + # Verify result shape and autoregressive residual + assert isinstance(result, torch.Tensor) + assert result.shape == (1, 100, 3) + expected = x[..., -3:] # Last 3 channels (autoregressive residual) + assert torch.allclose(result, expected, atol=1e-5) + + # Test return_features + out, features = wrapper( + input_geom=input_geom, + latent_queries=latent_queries, + output_queries=output_queries, + x=x, + return_features=True, + ) + assert isinstance(features, torch.Tensor) + assert features.shape == (1, 64, 8, 8) # (B, channels, H, W) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_lploss_wrapper(device): + """Test LpLossWrapper filters kwargs correctly.""" + mock_loss = MagicMock(return_value=torch.tensor(0.5)) + wrapper = LpLossWrapper(mock_loss) + + y_pred = torch.rand(2, 100, 3) + y = torch.rand(2, 100, 3) + + # Extra kwargs should be ignored + wrapper(y_pred, y=y, input_geom=torch.rand(100, 2), extra_arg="ignored") + + # Should only pass y_pred and y + mock_loss.assert_called_once_with(y_pred, y) + + +def _instantiate_model(cls, seed: int = 0, **kwargs): + """Helper to create model with reproducible parameters.""" + model = cls(**kwargs) + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + with torch.no_grad(): + for param in model.parameters(): + param.copy_(torch.randn(param.shape, generator=gen, dtype=param.dtype)) + return model + + + + +@pytest.mark.parametrize("device", _DEVICES) +def test_ginowrapper_from_checkpoint(device, mock_gino_model): + """Test loading GINOWrapper from checkpoint and verify outputs.""" + from pathlib import Path + import physicsnemo + + # Use the module-level MockGINOModelForCheckpoint class for checkpoint testing + # This ensures the class can be properly loaded from checkpoint + gino_model = MockGINOModelForCheckpoint() + + # Create a model and save checkpoint + model_orig = GINOWrapper(gino_model, autoregressive=False).to(device) + checkpoint_path = Path("checkpoint_gino_wrapper.mdlus") + model_orig.save(str(checkpoint_path)) + + # Load from checkpoint - use strict=False to handle potential state dict mismatches + # The nested TestGINOModel should be properly reconstructed via module path or registry + model = physicsnemo.Module.from_checkpoint(str(checkpoint_path), strict=False).to(device) + + # Verify attributes after loading + assert model.autoregressive is False + assert isinstance(model, GINOWrapper) + # Verify the wrapped model was loaded correctly with all layers + # Note: The model structure is preserved even if class type isn't exactly TestGINOModel + # (physicsnemo may load it as a generic Module if the class can't be imported) + assert hasattr(model.gino, 'gno_in') + assert hasattr(model.gino, 'gno_out') + assert hasattr(model.gino, 'projection') + assert hasattr(model.gino, 'latent_embedding') + # Verify the layers have the correct structure + assert isinstance(model.gino.gno_in, nn.Linear) + assert isinstance(model.gino.gno_out, nn.Linear) + assert isinstance(model.gino.projection, nn.Linear) + + # Cleanup + checkpoint_path.unlink(missing_ok=True) + + + diff --git a/test/models/test_flood_forecaster_integration.py b/test/models/test_flood_forecaster_integration.py new file mode 100644 index 0000000000..a51829b098 --- /dev/null +++ b/test/models/test_flood_forecaster_integration.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Integration tests for FloodForecaster training and inference pipeline. + +This module tests end-to-end workflows including: +- Pretraining → Domain Adaptation pipeline +- Checkpoint save/load across stages +- Model state consistency +- Inference pipeline integration +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch +import tempfile +import shutil + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +import physicsnemo + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +# Import modules explicitly to avoid conflicts +import importlib.util + +# First, set up the training package structure +if "training" not in sys.modules: + import types + training_pkg = types.ModuleType("training") + sys.modules["training"] = training_pkg + +# Import trainer module +spec = importlib.util.spec_from_file_location("training.trainer", _examples_dir / "training" / "trainer.py") +trainer_module = importlib.util.module_from_spec(spec) +sys.modules["training.trainer"] = trainer_module +spec.loader.exec_module(trainer_module) + +from data_processing import GINOWrapper, FloodGINODataProcessor +from training.trainer import NeuralOperatorTrainer + + +@pytest.fixture +def simple_gino_model(device): + """Create a simple GINO-like model for testing.""" + class SimpleGINOModel(nn.Module): + def __init__(self): + super().__init__() + self.fno_hidden_channels = 64 + self.out_channels = 3 + self.gno_coord_dim = 2 + self.latent_embedding = nn.Identity() + self.in_coord_dim_reverse_order = [2, 3] # For 2D: permute dims 2,3 (H, W) + self.out_gno_tanh = None # No tanh activation + # Create proper gno_in and gno_out that accept GINO signature + self._gno_in_linear = nn.Linear(2, 64) + self._gno_out_linear = nn.Linear(64, 64) + self.projection = nn.Linear(64, 3) + + def gno_in(self, y, x, f_y=None): + """GNO input block - accepts (y, x, f_y=None).""" + # x is flattened queries (n_points, coord_dim), y is geometry, f_y is optional features + # GINOWrapper expects gno_in to return (n_points, channels) which gets reshaped to (batch_size, H, W, channels) + # Return (n_points, channels) - will be reshaped to (batch_size, H, W, channels) + n_points = x.shape[0] + # Process queries through linear layer: (n_points, coord_dim) -> (n_points, channels) + features = self._gno_in_linear(x) # (n_points, channels) + return features # (n_points, channels) + + def gno_out(self, y, x, f_y): + """GNO output block - accepts (y, x, f_y).""" + # f_y is (B, n_latent, channels), x is output queries (n_out, coord_dim), y is latent queries + # In real GINO, this would query/interpolate features from f_y at locations x + # For testing, we'll process f_y features and return correct shape + batch_size = f_y.shape[0] + n_out = x.shape[0] + # Process features: take mean over latent dimension and expand to output queries + # f_y: (B, n_latent, channels) -> mean -> (B, channels) -> expand -> (B, channels, n_out) + features = f_y.mean(dim=1) # (B, channels) + features = features.unsqueeze(-1).expand(-1, -1, n_out) # (B, channels, n_out) + # Apply linear transformation + # Permute to (B, n_out, channels) for linear, then back to (B, channels, n_out) + features_perm = features.permute(0, 2, 1) # (B, n_out, channels) + out = self._gno_out_linear(features_perm) # (B, n_out, channels) + return out.permute(0, 2, 1) # (B, channels, n_out) + + def forward(self, input_geom, latent_queries, output_queries, x, **kwargs): + # Simple forward pass + batch_size = x.shape[0] if x.dim() > 1 else 1 + if output_queries.dim() == 2: + n_out = output_queries.shape[0] + else: + n_out = output_queries.shape[1] + return torch.rand(batch_size, n_out, 3) + + return SimpleGINOModel().to(device) + + +@pytest.fixture +def sample_train_data(device): + """Create sample training data.""" + batch_size = 4 + n_samples = 16 + n_cells = 50 + + samples = [] + for i in range(n_samples): + samples.append({ + "geometry": torch.rand(n_cells, 2).to(device), + "static": torch.rand(n_cells, 7).to(device), + "boundary": torch.rand(3, n_cells, 1).to(device), + "dynamic": torch.rand(3, n_cells, 3).to(device), + "target": torch.rand(n_cells, 3).to(device), + "query_points": torch.rand(8, 8, 2).to(device), + }) + + return samples + + +@pytest.fixture +def train_loader(sample_train_data): + """Create training dataloader.""" + class DictDataset: + def __init__(self, samples): + self.samples = samples + self.dataset = self # For compatibility + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + dataset = DictDataset(sample_train_data) + return DataLoader(dataset, batch_size=4, shuffle=False) + + +@pytest.fixture +def val_loader(sample_train_data): + """Create validation dataloader.""" + class DictDataset: + def __init__(self, samples): + self.samples = samples[:8] # Smaller validation set + self.dataset = self + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + dataset = DictDataset(sample_train_data) + return DataLoader(dataset, batch_size=4, shuffle=False) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_pretrain_to_adapt_pipeline(simple_gino_model, train_loader, val_loader, device, tmp_path): + """Test full pretraining → domain adaptation pipeline.""" + # Step 1: Pretraining + model = GINOWrapper(simple_gino_model).to(device) + data_processor = FloodGINODataProcessor(device=device).to(device) + data_processor.wrap(model) + + trainer = NeuralOperatorTrainer( + model=model, + n_epochs=2, + device=device, + data_processor=data_processor, + verbose=False, + eval_interval=1, + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) + + from neuralop.losses import LpLoss + training_loss = LpLoss(d=2, p=2) + eval_losses = {"l2": LpLoss(d=2, p=2)} + + pretrain_dir = tmp_path / "pretrain" + pretrain_dir.mkdir() + + # Pretrain + pretrain_metrics = trainer.train( + train_loader=train_loader, + test_loaders={"val": val_loader}, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=str(pretrain_dir), + save_best="val_l2", + ) + + assert isinstance(pretrain_metrics, dict) + assert "val_l2" in pretrain_metrics + + # Step 2: Domain Adaptation (simplified - just verify checkpoint can be loaded) + # Create new model and load from pretrain checkpoint + new_model = GINOWrapper(simple_gino_model).to(device) + new_data_processor = FloodGINODataProcessor(device=device).to(device) + new_data_processor.wrap(new_model) + + new_trainer = NeuralOperatorTrainer( + model=new_model, + n_epochs=1, + device=device, + data_processor=new_data_processor, + verbose=False, + eval_interval=1, + ) + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3) + new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.9) + + adapt_dir = tmp_path / "adapt" + adapt_dir.mkdir() + + # Resume from pretrain checkpoint and train (domain adaptation) + adapt_metrics = new_trainer.train( + train_loader=train_loader, + test_loaders={"val": val_loader}, + optimizer=new_optimizer, + scheduler=new_scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=str(adapt_dir), + resume_from_dir=str(pretrain_dir), + ) + + assert isinstance(adapt_metrics, dict) + # Verify that training resumed (start_epoch > 0 or metrics were computed) + assert new_trainer.start_epoch > 0 or "val_l2" in adapt_metrics + + +@pytest.mark.parametrize("device", _DEVICES) +def test_checkpoint_compatibility_physicsnemo_format(simple_gino_model, device, tmp_path): + """Test checkpoint save/load with PhysicsNeMo format and state consistency.""" + # Create and save model + model1 = GINOWrapper(simple_gino_model, autoregressive=True).to(device) + checkpoint_path = tmp_path / "test_checkpoint.mdlus" + + # Save checkpoint + model1.save(str(checkpoint_path)) + assert checkpoint_path.exists() + + # Load checkpoint using PhysicsNeMo's from_checkpoint + # Note: PhysicsNeMo's from_checkpoint reconstructs the model from saved _args, + # which may result in a different structure than the original + model2 = physicsnemo.Module.from_checkpoint(str(checkpoint_path), strict=False).to(device) + assert isinstance(model2, GINOWrapper) + assert hasattr(model2, 'autoregressive') + assert model2.autoregressive == True + + # Verify state consistency - check that model was loaded correctly + # The loaded model structure may differ when using PhysicsNeMo's from_checkpoint, + # so just verify it's a valid GINOWrapper with the expected structure + assert hasattr(model2, 'gino') + assert isinstance(model2.gino, (physicsnemo.models.Module, torch.nn.Module)) + + # Verify that the model structure is valid + # PhysicsNeMo's from_checkpoint may reconstruct the model with a different + # internal structure, but it should still be a valid GINOWrapper + # We don't verify parameter consistency here because the reconstruction + # process may not preserve the exact same parameter structure + # The important thing is that the checkpoint can be saved and loaded, + # and the loaded model is a valid GINOWrapper instance + + + + +@pytest.mark.parametrize("device", _DEVICES) +def test_trainer_with_ginowrapper(simple_gino_model, train_loader, val_loader, device): + """Test NeuralOperatorTrainer with GINOWrapper model.""" + model = GINOWrapper(simple_gino_model).to(device) + + # Add data processor to handle data preprocessing + data_processor = FloodGINODataProcessor(device=device).to(device) + data_processor.wrap(model) + + trainer = NeuralOperatorTrainer( + model=model, + n_epochs=1, + device=device, + data_processor=data_processor, + verbose=False, + eval_interval=1, + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) + + from neuralop.losses import LpLoss + training_loss = LpLoss(d=2, p=2) + eval_losses = {"l2": LpLoss(d=2, p=2)} + + # Train + metrics = trainer.train( + train_loader=train_loader, + test_loaders={"val": val_loader}, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + ) + + assert isinstance(metrics, dict) + assert "val_l2" in metrics + diff --git a/test/models/test_flood_forecaster_trainer.py b/test/models/test_flood_forecaster_trainer.py new file mode 100644 index 0000000000..698633e6b8 --- /dev/null +++ b/test/models/test_flood_forecaster_trainer.py @@ -0,0 +1,376 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster NeuralOperatorTrainer class. + +This module tests the PhysicsNeMo-style Trainer class that handles neural operator training +with checkpointing, distributed training, and evaluation support. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch +import tempfile +import shutil + +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset + +import physicsnemo + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +# Import modules explicitly to avoid conflicts +import importlib.util + +# First, set up the training package structure +if "training" not in sys.modules: + import types + training_pkg = types.ModuleType("training") + sys.modules["training"] = training_pkg + +# Import trainer module +spec = importlib.util.spec_from_file_location("training.trainer", _examples_dir / "training" / "trainer.py") +trainer_module = importlib.util.module_from_spec(spec) +sys.modules["training.trainer"] = trainer_module +spec.loader.exec_module(trainer_module) + +from training.trainer import NeuralOperatorTrainer, _has_pytorch_submodules, save_model_checkpoint +from data_processing import GINOWrapper + +from . import common + + +@pytest.fixture +def simple_model(device): + """Create a simple test model that accepts kwargs.""" + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, 3) + ) + + def forward(self, x=None, **kwargs): + # Accept x as positional or keyword, ignore other kwargs + if x is None: + # Try to get x from kwargs + x = kwargs.get('x', None) + if x is None: + raise ValueError("x must be provided") + return self.layers(x) + + return SimpleModel().to(device) + + +@pytest.fixture +def mock_data_processor(device): + """Create a mock data processor.""" + processor = MagicMock(spec=nn.Module) + processor.preprocess = MagicMock(side_effect=lambda x: x) + processor.postprocess = MagicMock(side_effect=lambda out, sample: (out, sample)) + processor.train = MagicMock() + processor.eval = MagicMock() + processor.to = MagicMock(return_value=processor) + return processor + + +@pytest.fixture +def sample_dataset(device): + """Create a sample dataset for training.""" + batch_size = 4 + n_samples = 20 + n_features = 10 + n_outputs = 3 + + # Create sample data + x = torch.rand(n_samples, n_features).to(device) + y = torch.rand(n_samples, n_outputs).to(device) + + # Create dataset with dict format + samples = [{"x": x[i], "y": y[i]} for i in range(n_samples)] + return samples + + +@pytest.fixture +def train_loader(sample_dataset, device): + """Create a training dataloader.""" + class DictDataset: + def __init__(self, samples): + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + dataset = DictDataset(sample_dataset) + return DataLoader(dataset, batch_size=4, shuffle=False) + + +@pytest.fixture +def test_loader(sample_dataset, device): + """Create a test dataloader.""" + class DictDataset: + def __init__(self, samples): + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + dataset = DictDataset(sample_dataset[:10]) # Smaller test set + return DataLoader(dataset, batch_size=4, shuffle=False) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_trainer_init(simple_model, mock_data_processor, device): + """Test NeuralOperatorTrainer initialization with various configurations.""" + # Basic init + trainer = NeuralOperatorTrainer( + model=simple_model, + n_epochs=10, + device=device, + verbose=False, + ) + assert trainer.model == simple_model + assert trainer.n_epochs == 10 + assert str(trainer.device) == str(torch.device(device)) + assert trainer.mixed_precision is False + assert trainer.data_processor is None + + # With data processor + trainer2 = NeuralOperatorTrainer( + model=simple_model, + n_epochs=10, + device=device, + data_processor=mock_data_processor, + verbose=False, + ) + assert trainer2.data_processor == mock_data_processor + + # With mixed precision + trainer3 = NeuralOperatorTrainer( + model=simple_model, + n_epochs=10, + device=device, + mixed_precision=True, + verbose=False, + ) + assert trainer3.mixed_precision is True + assert trainer3.scaler is not None + + +@pytest.mark.parametrize("device", _DEVICES) +def test_trainer_train_one_epoch(simple_model, train_loader, test_loader, device): + """Test training for one epoch.""" + trainer = NeuralOperatorTrainer( + model=simple_model, + n_epochs=1, + device=device, + verbose=False, + eval_interval=1, + ) + + optimizer = torch.optim.Adam(simple_model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) + + from neuralop.losses import LpLoss + training_loss = LpLoss(d=2, p=2) + eval_losses = {"l2": LpLoss(d=2, p=2)} + + # Train for one epoch + metrics = trainer.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=None, # Don't save checkpoints in test + ) + + # Check that metrics were returned + assert isinstance(metrics, dict) + assert "train_err" in metrics + assert "val_l2" in metrics + assert trainer.epoch == 0 # After 1 epoch, epoch should be 0 (0-indexed) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_trainer_checkpoint_save_load(simple_model, train_loader, test_loader, device, tmp_path): + """Test checkpoint saving and loading.""" + trainer = NeuralOperatorTrainer( + model=simple_model, + n_epochs=2, + device=device, + verbose=False, + eval_interval=1, + ) + + optimizer = torch.optim.Adam(simple_model.parameters(), lr=1e-3) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) + + from neuralop.losses import LpLoss + training_loss = LpLoss(d=2, p=2) + eval_losses = {"l2": LpLoss(d=2, p=2)} + + save_dir = tmp_path / "checkpoints" + save_dir.mkdir() + + # Train for one epoch and save checkpoint + trainer.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=optimizer, + scheduler=scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=str(save_dir), + save_best="val_l2", + ) + + # Check that checkpoint files were created + checkpoint_files = list(save_dir.glob("checkpoint.*.pt")) + assert len(checkpoint_files) > 0, "Checkpoint file should be created" + + # Create new trainer and resume from checkpoint + new_model = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, 3) + ).to(device) + + new_trainer = NeuralOperatorTrainer( + model=new_model, + n_epochs=2, + device=device, + verbose=False, + eval_interval=1, + ) + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-3) + new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.9) + + # Resume from checkpoint + new_trainer.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=new_optimizer, + scheduler=new_scheduler, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=str(save_dir), + resume_from_dir=str(save_dir), + ) + + # Check that training resumed + assert new_trainer.start_epoch > 0 or new_trainer.epoch >= 0 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_trainer_training_features(simple_model, mock_data_processor, train_loader, test_loader, device, tmp_path): + """Test training with various features: mixed precision, data processor, best model tracking.""" + from neuralop.losses import LpLoss + + # Test mixed precision training + trainer1 = NeuralOperatorTrainer( + model=simple_model, + n_epochs=1, + device=device, + mixed_precision=True, + verbose=False, + eval_interval=1, + ) + optimizer1 = torch.optim.Adam(simple_model.parameters(), lr=1e-3) + scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1, gamma=0.9) + training_loss = LpLoss(d=2, p=2) + eval_losses = {"l2": LpLoss(d=2, p=2)} + + metrics1 = trainer1.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=optimizer1, + scheduler=scheduler1, + training_loss=training_loss, + eval_losses=eval_losses, + ) + assert isinstance(metrics1, dict) + assert trainer1.scaler is not None + + # Test with data processor + trainer2 = NeuralOperatorTrainer( + model=simple_model, + n_epochs=1, + device=device, + data_processor=mock_data_processor, + verbose=False, + eval_interval=1, + ) + optimizer2 = torch.optim.Adam(simple_model.parameters(), lr=1e-3) + scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1, gamma=0.9) + metrics2 = trainer2.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=optimizer2, + scheduler=scheduler2, + training_loss=training_loss, + eval_losses=eval_losses, + ) + assert isinstance(metrics2, dict) + + # Test best model tracking + save_dir = tmp_path / "checkpoints" + save_dir.mkdir() + trainer3 = NeuralOperatorTrainer( + model=simple_model, + n_epochs=2, + device=device, + verbose=False, + eval_interval=1, + ) + optimizer3 = torch.optim.Adam(simple_model.parameters(), lr=1e-3) + scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=1, gamma=0.9) + trainer3.train( + train_loader=train_loader, + test_loaders={"val": test_loader}, + optimizer=optimizer3, + scheduler=scheduler3, + training_loss=training_loss, + eval_losses=eval_losses, + save_dir=str(save_dir), + save_best="val_l2", + ) + assert trainer3.best_metric_value < float("inf") + checkpoint_files = list(save_dir.glob("checkpoint.*.pt")) + assert len(checkpoint_files) > 0 + diff --git a/test/models/test_flood_forecaster_training.py b/test/models/test_flood_forecaster_training.py new file mode 100644 index 0000000000..b949f2fd6e --- /dev/null +++ b/test/models/test_flood_forecaster_training.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster training modules. +""" + +import sys +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +import physicsnemo + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +# Import modules explicitly to avoid conflicts with other utils modules +import importlib.util + +# First, set up the training package structure +if "training" not in sys.modules: + import types + training_pkg = types.ModuleType("training") + sys.modules["training"] = training_pkg + +# Import trainer module FIRST (pretraining depends on it) +spec = importlib.util.spec_from_file_location("training.trainer", _examples_dir / "training" / "trainer.py") +trainer_module = importlib.util.module_from_spec(spec) +sys.modules["training.trainer"] = trainer_module +spec.loader.exec_module(trainer_module) + +# Import pretraining (depends on trainer) +spec = importlib.util.spec_from_file_location("training.pretraining", _examples_dir / "training" / "pretraining.py") +pretraining_module = importlib.util.module_from_spec(spec) +sys.modules["training.pretraining"] = pretraining_module +spec.loader.exec_module(pretraining_module) + +# Import domain_adaptation (depends on pretraining, but not on __init__) +# Use full module path to ensure __module__ attribute is set correctly for checkpoint loading +spec = importlib.util.spec_from_file_location("training.domain_adaptation", _examples_dir / "training" / "domain_adaptation.py") +domain_adaptation_module = importlib.util.module_from_spec(spec) +sys.modules["training.domain_adaptation"] = domain_adaptation_module +spec.loader.exec_module(domain_adaptation_module) + +# Fix __module__ attributes for classes to ensure checkpoint loading works +for name in dir(domain_adaptation_module): + obj = getattr(domain_adaptation_module, name) + if isinstance(obj, type) and issubclass(obj, (torch.nn.Module, physicsnemo.Module)): + obj.__module__ = "training.domain_adaptation" + +# Now import __init__ (it can safely import from domain_adaptation and pretraining) +spec = importlib.util.spec_from_file_location("training", _examples_dir / "training" / "__init__.py") +training_init = importlib.util.module_from_spec(spec) +sys.modules["training"].__dict__.update(training_init.__dict__) +spec.loader.exec_module(training_init) + +from training.domain_adaptation import ( + CNNDomainClassifier, + DomainAdaptationTrainer, + GradientReversal, + GradientReversalFunction, +) +from training.pretraining import create_scheduler + + +@pytest.mark.parametrize("device", _DEVICES) +@pytest.mark.parametrize("scheduler_type", ["StepLR", "CosineAnnealingLR", "ReduceLROnPlateau"]) +def test_create_scheduler(device, scheduler_type): + """Test scheduler creation for different types.""" + model = nn.Linear(10, 10).to(device) + optimizer = torch.optim.Adam(model.parameters()) + + config = MagicMock() + config.training = MagicMock() + if scheduler_type == "StepLR": + config.training.get = lambda key, default=None: { + "scheduler": "StepLR", + "step_size": 10, + "gamma": 0.5, + }.get(key, default) + expected_type = torch.optim.lr_scheduler.StepLR + elif scheduler_type == "CosineAnnealingLR": + config.training.get = lambda key, default=None: { + "scheduler": "CosineAnnealingLR", + "scheduler_T_max": 100, + }.get(key, default) + expected_type = torch.optim.lr_scheduler.CosineAnnealingLR + else: # ReduceLROnPlateau + config.training.get = lambda key, default=None: { + "scheduler": "ReduceLROnPlateau", + "gamma": 0.5, + "scheduler_patience": 5, + }.get(key, default) + expected_type = torch.optim.lr_scheduler.ReduceLROnPlateau + + scheduler = create_scheduler(optimizer, config) + assert isinstance(scheduler, expected_type) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_gradient_reversal(device): + """Test gradient reversal forward and backward passes.""" + grl = GradientReversal(lambda_max=1.0) + x = torch.rand(4, 10, requires_grad=True).to(device).detach().requires_grad_(True) + + # Forward: identity + y = grl(x) + assert torch.allclose(x, y) + + # Backward: negates gradient + loss = y.sum() + loss.backward() + assert x.grad is not None + assert torch.allclose(x.grad, -torch.ones_like(x.grad)) + + +@pytest.fixture +def da_config(): + """Create mock DA config that supports .get() method, 'in' operator, and [] access.""" + # Use a dict-like object that supports both attribute access and dict-like access + class DictLike: + def __init__(self): + self.conv_layers = [ + {"out_channels": 16, "kernel_size": 3, "pool_size": 2}, + {"out_channels": 32, "kernel_size": 3, "pool_size": 2}, + ] + self.fc_dim = 1 + + def get(self, key, default=None): + """Support .get() method for dict-like access.""" + return getattr(self, key, default) + + def __contains__(self, key): + """Support 'in' operator for dict-like access.""" + return hasattr(self, key) + + def __getitem__(self, key): + """Support [] subscript access for dict-like access.""" + return getattr(self, key) + + return DictLike() + + +@pytest.mark.parametrize("device", _DEVICES) +def test_cnn_domain_classifier_init(da_config, device): + """Test CNN domain classifier initialization.""" + classifier = CNNDomainClassifier(in_channels=64, lambda_max=1.0, da_cfg=da_config).to(device) + + assert isinstance(classifier, nn.Module) + assert hasattr(classifier, "grl") + assert isinstance(classifier.grl, GradientReversal) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_cnn_domain_classifier_forward(da_config, device): + """Test CNN domain classifier forward pass.""" + classifier = CNNDomainClassifier(in_channels=64, lambda_max=1.0, da_cfg=da_config).to(device) + + x = torch.rand(4, 64, 16, 16).to(device) # (B, C, H, W) + y = classifier(x) + + assert y.shape == (4, 1) # (B, fc_dim) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_domain_adaptation_trainer_init(device): + """Test DomainAdaptationTrainer initialization.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.to = MagicMock(return_value=mock_model) + mock_classifier = MagicMock(spec=nn.Module) + mock_classifier.to = MagicMock(return_value=mock_classifier) + mock_processor = MagicMock() + mock_processor.to = MagicMock(return_value=mock_processor) + + trainer = DomainAdaptationTrainer( + model=mock_model, + data_processor=mock_processor, + domain_classifier=mock_classifier, + device=device, + ) + + assert trainer.model == mock_model + assert trainer.domain_classifier == mock_classifier + assert trainer.data_processor == mock_processor + assert trainer.device == device + + +@pytest.mark.parametrize("device", _DEVICES) +def test_gradient_reversal_lambda(device): + """Test GradientReversal lambda setting and scaling.""" + grl = GradientReversal(lambda_max=1.0) + grl.set_lambda(0.5) + assert grl.lambda_ == 0.5 + + # Test lambda scales gradient + grl2 = GradientReversal(lambda_max=0.5) + x2 = torch.ones(4, 10, requires_grad=True).to(device).detach().requires_grad_(True) + y2 = grl2(x2) + loss2 = y2.sum() + loss2.backward() + assert x2.grad is not None + assert torch.allclose(x2.grad, -0.5 * torch.ones_like(x2.grad)) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_create_scheduler_unknown_raises(device): + """Test that unknown scheduler raises ValueError.""" + model = nn.Linear(10, 10).to(device) + optimizer = torch.optim.Adam(model.parameters()) + + config = MagicMock() + config.training = MagicMock() + config.training.get = lambda key, default=None: { + "scheduler": "UnknownScheduler", + }.get(key, default) + + with pytest.raises(ValueError, match="Unknown scheduler"): + create_scheduler(optimizer, config) + + + + +def _instantiate_model(cls, seed: int = 0, **kwargs): + """Helper to create model with reproducible parameters.""" + model = cls(**kwargs) + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + with torch.no_grad(): + for param in model.parameters(): + param.copy_(torch.randn(param.shape, generator=gen, dtype=param.dtype)) + return model + + +@pytest.mark.parametrize("device", _DEVICES) +def test_checkpoint_save_load(device): + """Test checkpoint save/load for training components.""" + import physicsnemo + from pathlib import Path + + # Test GradientReversal checkpoint + grl_orig = GradientReversal(lambda_max=1.0).to(device) + grl_path = Path("checkpoint_grl.mdlus") + grl_orig.save(str(grl_path)) + grl_loaded = physicsnemo.Module.from_checkpoint(str(grl_path)).to(device) + assert grl_loaded.lambda_ == 1.0 + assert isinstance(grl_loaded, GradientReversal) + grl_path.unlink(missing_ok=True) + + # Test CNNDomainClassifier checkpoint + da_config = { + "conv_layers": [ + {"out_channels": 16, "kernel_size": 3, "pool_size": 2}, + {"out_channels": 32, "kernel_size": 3, "pool_size": 2}, + ], + "fc_dim": 1 + } + classifier_orig = CNNDomainClassifier(in_channels=64, lambda_max=1.0, da_cfg=da_config).to(device) + classifier_path = Path("checkpoint_classifier.mdlus") + classifier_orig.save(str(classifier_path)) + classifier_loaded = physicsnemo.Module.from_checkpoint(str(classifier_path)).to(device) + assert classifier_loaded.grl.lambda_ == 1.0 + assert isinstance(classifier_loaded, CNNDomainClassifier) + classifier_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/test/utils/test_flood_forecaster_utils.py b/test/utils/test_flood_forecaster_utils.py new file mode 100644 index 0000000000..03d4e3c9f7 --- /dev/null +++ b/test/utils/test_flood_forecaster_utils.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for FloodForecaster utility modules. +""" + +import sys +from pathlib import Path + +import pytest +import torch + +# Add the FloodForecaster example to the path +_examples_dir = Path(__file__).parent.parent.parent / "examples" / "weather" / "flood_modeling" / "flood_forecaster" +if str(_examples_dir) not in sys.path: + sys.path.insert(0, str(_examples_dir)) + +from utils.normalization import ( + collect_all_fields, + stack_and_fit_transform, + transform_with_existing_normalizers, +) + + +# Conditionally include CUDA in device parametrization only if available +_DEVICES = ["cpu"] +if torch.cuda.is_available(): + _DEVICES.append("cuda:0") + + +@pytest.mark.parametrize("device", _DEVICES) +def test_collect_all_fields_with_target(device): + """Test collecting fields with target.""" + # Create mock dataset + mock_dataset = [ + { + "geometry": torch.rand(100, 2).to(device), + "static": torch.rand(100, 7).to(device), + "boundary": torch.rand(3, 100, 1).to(device), + "dynamic": torch.rand(3, 100, 3).to(device), + "target": torch.rand(100, 3).to(device), + } + for _ in range(5) + ] + + geom, static, boundary, dyn, target = collect_all_fields(mock_dataset, expect_target=True) + + assert len(geom) == 5 + assert len(static) == 5 + assert len(boundary) == 5 + assert len(dyn) == 5 + assert len(target) == 5 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_collect_all_fields_without_target(device): + """Test collecting fields without target.""" + mock_dataset = [ + { + "geometry": torch.rand(100, 2).to(device), + "static": torch.rand(100, 7).to(device), + "boundary": torch.rand(3, 100, 1).to(device), + "dynamic": torch.rand(3, 100, 3).to(device), + } + for _ in range(5) + ] + + result = collect_all_fields(mock_dataset, expect_target=False) + + assert len(result) == 5 # geom, static, boundary, dyn, target (empty) + + +@pytest.mark.parametrize("device", _DEVICES) +def test_collect_all_fields_with_cell_area(device): + """Test collecting fields with cell_area.""" + mock_dataset = [ + { + "geometry": torch.rand(100, 2).to(device), + "static": torch.rand(100, 7).to(device), + "boundary": torch.rand(3, 100, 1).to(device), + "dynamic": torch.rand(3, 100, 3).to(device), + "target": torch.rand(100, 3).to(device), + "cell_area": torch.rand(100).to(device), + } + for _ in range(5) + ] + + result = collect_all_fields(mock_dataset, expect_target=True) + + assert len(result) == 6 # Includes cell_area + assert len(result[5]) == 5 # 5 cell_area tensors + + +@pytest.mark.parametrize("device", _DEVICES) +def test_stack_and_fit_transform_creates_normalizers(device): + """Test that stack_and_fit_transform creates normalizers.""" + geom = [torch.rand(100, 2).to(device) for _ in range(5)] + static = [torch.rand(100, 7).to(device) for _ in range(5)] + boundary = [torch.rand(3, 100, 1).to(device) for _ in range(5)] + dyn = [torch.rand(3, 100, 3).to(device) for _ in range(5)] + target = [torch.rand(100, 3).to(device) for _ in range(5)] + + normalizers, big_tensors = stack_and_fit_transform(geom, static, boundary, dyn, target) + + assert "geometry" in normalizers + assert "static" in normalizers + assert "boundary" in normalizers + assert "target" in normalizers + assert "dynamic" in normalizers + + assert big_tensors["geometry"].shape[0] == 5 + assert big_tensors["static"].shape[0] == 5 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_stack_and_fit_transform_uses_existing(device): + """Test using existing normalizers.""" + geom = [torch.rand(100, 2).to(device) for _ in range(5)] + static = [torch.rand(100, 7).to(device) for _ in range(5)] + boundary = [torch.rand(3, 100, 1).to(device) for _ in range(5)] + dyn = [torch.rand(3, 100, 3).to(device) for _ in range(5)] + target = [torch.rand(100, 3).to(device) for _ in range(5)] + + # First pass - fit normalizers + normalizers, _ = stack_and_fit_transform( + geom, static, boundary, dyn, target, fit_normalizers=True + ) + + # Second pass - use existing + new_geom = [torch.rand(100, 2).to(device) for _ in range(3)] + new_static = [torch.rand(100, 7).to(device) for _ in range(3)] + new_boundary = [torch.rand(3, 100, 1).to(device) for _ in range(3)] + new_dyn = [torch.rand(3, 100, 3).to(device) for _ in range(3)] + new_target = [torch.rand(100, 3).to(device) for _ in range(3)] + + _, new_tensors = stack_and_fit_transform( + new_geom, + new_static, + new_boundary, + new_dyn, + new_target, + normalizers=normalizers, + fit_normalizers=False, + ) + + assert new_tensors["geometry"].shape[0] == 3 + + +@pytest.mark.parametrize("device", _DEVICES) +def test_transform_with_existing_normalizers(device): + """Test transform_with_existing_normalizers.""" + # Create and fit normalizers + geom = [torch.rand(100, 2).to(device) for _ in range(5)] + static = [torch.rand(100, 7).to(device) for _ in range(5)] + boundary = [torch.rand(3, 100, 1).to(device) for _ in range(5)] + dyn = [torch.rand(3, 100, 3).to(device) for _ in range(5)] + target = [torch.rand(100, 3).to(device) for _ in range(5)] + + normalizers, _ = stack_and_fit_transform(geom, static, boundary, dyn, target) + + # Transform new data + new_geom = [torch.rand(100, 2).to(device) for _ in range(3)] + new_static = [torch.rand(100, 7).to(device) for _ in range(3)] + new_boundary = [torch.rand(3, 100, 1).to(device) for _ in range(3)] + new_dyn = [torch.rand(3, 100, 3).to(device) for _ in range(3)] + + transformed = transform_with_existing_normalizers( + new_geom, new_static, new_boundary, new_dyn, normalizers + ) + + assert "geometry" in transformed + assert "static" in transformed + assert "boundary" in transformed + assert "dynamic" in transformed + + assert transformed["geometry"].shape[0] == 3 +