f5a53ef
Add HealDA to PhysicsNeMo subtree
aayushg55 Jan 26, 2026
380434a
Add healda integration to physicsnemo
aayushg55 Jan 26, 2026
dbf153a
delete file
aayushg55 Jan 26, 2026
fc97759
Fix license headers
aayushg55 Jan 27, 2026
1321d06
more linting
aayushg55 Jan 27, 2026
ef4111c
more linting
aayushg55 Jan 27, 2026
3ac1ad9
more linting
aayushg55 Jan 27, 2026
87e0716
arxiv link
aayushg55 Jan 27, 2026
3ec5ee8
readme cleanup
aayushg55 Jan 27, 2026
70d610f
skip cuda tests when unavailable
aayushg55 Jan 27, 2026
e505a28
Merge commit '66aa539ef6df4e7f9388aecee9fecfff46f8676f' into pnm-inte…
aayushg55 Jan 27, 2026
53b1144
Cleanup to ETL and new FCN3 forecasting example
aayushg55 Jan 28, 2026
4928fc9
update readme
aayushg55 Jan 28, 2026
13f377d
Fixed training and inference pipelines
aayushg55 Jan 28, 2026
7bc5ae0
Merge commit 'bfe710511f88ce6c78e8b1ca6900ee073535ec46' into pnm-inte…
aayushg55 Jan 28, 2026
ea0b4f3
Remove acc scoring and other unneeded files
aayushg55 Jan 28, 2026
c412c8a
HealDA/DiT architecture refactor
aayushg55 Feb 3, 2026
fc57101
Merge branch 'main' of github.com:NVIDIA/physicsnemo into healda
aayushg55 Feb 3, 2026
e6178a9
HealDA working with timm backend, debugging TE errors
aayushg55 Feb 3, 2026
0ad266c
fixed te issue - qkv_layout difference
aayushg55 Feb 3, 2026
d5aa7a0
Merge branch 'main' of github.com:NVIDIA/physicsnemo into healda
aayushg55 Feb 3, 2026
c11ac9c
cleanup documentation
aayushg55 Feb 4, 2026
2ff6109
Merge branch 'main' of github.com:NVIDIA/physicsnemo into healda
aayushg55 Mar 16, 2026
62d0c38
delete old tests
aayushg55 Mar 16, 2026
19bd5fa
update healda readme with ckpt/e2studio links
aayushg55 Mar 16, 2026
7a296a7
update readme
aayushg55 Mar 16, 2026
1755f50
update license headers
aayushg55 Mar 16, 2026
d3394ec
update inference
aayushg55 Mar 16, 2026
b27151c
update base
aayushg55 Mar 16, 2026
e54de49
update readme
aayushg55 Mar 16, 2026
da9f14d
cleanup model config
aayushg55 Mar 16, 2026
34ba868
cleanup dataset
aayushg55 Mar 16, 2026
96cdef8
update license
aayushg55 Mar 16, 2026
cc1ad80
update transform and loading
aayushg55 Mar 16, 2026
d400fcf
update model setup to use new pnm model
aayushg55 Mar 16, 2026
fd741ec
use pnm checkpointing
aayushg55 Mar 17, 2026
24bb714
Merge branch 'main' of github.com:NVIDIA/physicsnemo into healda
aayushg55 Mar 17, 2026
7e227d9
fix license headers
aayushg55 Mar 17, 2026
71eff55
fix batch keys
aayushg55 Mar 17, 2026
e43eeec
ensure conv gets 1 platform
aayushg55 Mar 17, 2026
2d70390
update model config
aayushg55 Mar 17, 2026
aa63f5a
reduce logging
aayushg55 Mar 17, 2026
Binary file added docs/img/healda.png
203 changes: 203 additions & 0 deletions examples/weather/healda/README.md
@@ -0,0 +1,203 @@
<!-- markdownlint-disable -->
# HealDA: Highlighting the importance of initial errors in end-to-end AI weather forecasts

<p align="center">
<img src="../../../docs/img/healda.png" width="800"/>
</p>

[📄 arXiv](https://arxiv.org/abs/2601.17636) · 📦 Checkpoints (coming soon)

---

## Problem Overview

Machine-learning (ML) weather models now rival leading numerical weather prediction (NWP) systems in medium-range skill. However, almost all still rely on NWP data assimilation (DA) to provide initial conditions, tying them to expensive infrastructure and limiting the practical speed and accuracy gains of ML.

**HealDA** is a global ML-based data assimilation system that maps satellite and conventional observations (microwave sounders, aircraft, radiosondes, surface stations) to a 1° atmospheric state on the HEALPix grid. HealDA analyses can initialize off-the-shelf ML forecast models (e.g., FourCastNet3, Aurora, FengWu) without fine-tuning, enabling end-to-end ML weather forecasting with less than one day loss of skill compared to ERA5 initialization.

---

## Installation

### Using uv

```bash
# 1. From PhysicsNeMo root directory
cd /path/to/physicsnemo

# 2. Create .venv and install PNM
uv sync

# 3. Activate the virtual environment
source .venv/bin/activate

# 4. Install earth2grid
uv pip install setuptools hatchling
uv pip install --no-build-isolation \
"earth2grid @ https://github.com/NVlabs/earth2grid/archive/main.tar.gz"

# 5. Install healda dependencies
uv pip install -r examples/weather/healda/requirements.txt
```

### Using pip

```bash
# 1. Install PhysicsNeMo
pip install nvidia-physicsnemo

# 2. Install earth2grid
pip install setuptools hatchling
pip install --no-build-isolation https://github.com/NVlabs/earth2grid/archive/main.tar.gz

# 3. Install healda dependencies (run from examples/weather/healda/)
pip install -r requirements.txt
```

> **Warning:** Include `--no-build-isolation` when installing earth2grid to avoid building against the wrong PyTorch version.

---

## Configuration

Create a `.env` file in the `examples/weather/healda/` directory with the following:

```bash
# Project paths
PROJECT_ROOT=/path/to/project

# Raw observation data (NC4 files downloaded from NOAA S3)
UFS_RAW_OBS_DIR=/path/to/raw_obs

# Processed observation data (parquet from ETL)
UFS_OBS_PATH=/path/to/processed_obs
# UFS_OBS_PROFILE=

# ERA5 HEALPix zarr (training targets)
V6_ERA5_ZARR=/path/to/era5_hpx.zarr
# V6_ERA5_ZARR_PROFILE=

# Land fraction mask
UFS_LAND_DATA_ZARR=/path/to/land_frac.zarr
# UFS_LAND_DATA_PROFILE=
```

> **Note:** The `*_PROFILE` variables configure [rclone](https://rclone.org/) S3 profiles for cloud storage access. Leave empty for local paths.
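
These variables are read at import time by `config/environment.py`: each is a plain environment lookup with an empty-string fallback, and derived paths default relative to `PROJECT_ROOT`. The sketch below mirrors that pattern in a standalone function (the function name and dict layout are illustrative, not part of the module):

```python
import os

def load_settings(env=os.environ):
    """Resolve HealDA-style settings from an environment mapping.

    A sketch mirroring config/environment.py: plain lookups with
    empty-string defaults, derived paths relative to PROJECT_ROOT.
    """
    project_root = env.get("PROJECT_ROOT", "")
    return {
        "PROJECT_ROOT": project_root,
        "UFS_OBS_PATH": env.get("UFS_OBS_PATH", ""),
        "V6_ERA5_ZARR": env.get("V6_ERA5_ZARR", ""),
        # Derived path: defaults to <PROJECT_ROOT>/datasets unless overridden.
        "DATA_ROOT": env.get("DATA_ROOT", os.path.join(project_root, "datasets")),
    }

# Passing a dict instead of os.environ makes the resolution easy to test.
settings = load_settings({"PROJECT_ROOT": "/path/to/project"})
```

Leaving a variable out of `.env` simply yields an empty string, so missing paths fail later at data-loading time rather than at import.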

---

## Data Preparation

HealDA requires preprocessed observation data and ERA5 target fields. We source observational data from the [NOAA Unified Forecast System (UFS) GEFSv13 Replay dataset](https://psl.noaa.gov/data/ufs_replay/) (NOAA, 2024).

See [`datasets/etl/`](datasets/etl/) for the ETL scripts that convert the raw observation data into parquet format.
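
The exact parquet schema is defined by the ETL scripts; as a purely illustrative stand-in (these column names are assumptions, not the real schema), a processed observation table might look like this in pandas:

```python
import pandas as pd

# Toy stand-in for the processed observation table. The real schema comes
# from datasets/etl/; the columns here (time, lat, lon, instrument, value)
# are illustrative assumptions only.
obs = pd.DataFrame(
    {
        "time": pd.to_datetime(["2022-01-01 00:00"] * 3),
        "lat": [10.0, -45.5, 71.2],
        "lon": [100.0, 310.0, 25.0],
        "instrument": ["amsua", "radiosonde", "amsua"],
        "value": [250.1, 263.4, 248.9],
    }
)

# A typical ETL-style sanity check: count observations per instrument type.
counts = obs.groupby("instrument").size()
```

A columnar format like parquet keeps per-instrument filtering and time-window selection cheap, which matters when assembling observation windows for each analysis time.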

---

## Training

```bash
python train.py --name era5-v2-dense-noInfill-10M-fusion512-lrObs1e-4
```

This uses the paper configuration defined in `train.py`. See `python train.py --help` for options.

> **Resource Requirements:** Training takes approximately **8.3 days on 1 H100 node** (8 GPUs total) with batch size 1 per GPU.

---

## Inference

### Step 1: Generate DA Analysis (Initial Conditions)

The following produces analyses for all of 2022. See `inference_helpers.py` to configure inference. Inference requires only ~20 GB of GPU memory and can produce an analysis in under one second on a single H100.

```bash
python inference.py \
/path/to/checkpoint.pt \
--output_path /path/to/da_output.zarr \
--context_start -21 \
--context_end 3 \
--time_frequency 6h \
--num_samples -1 \
--batch_gpu 1
```
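
One plausible reading of `--context_start -21 --context_end 3` is an observation window of hour offsets around each analysis time, sampled at the 6-hour cadence. The sketch below computes those timestamps under that assumption (the actual window logic lives in `inference.py` and may differ):

```python
from datetime import datetime, timedelta

def context_window(analysis_time, start_h=-21, end_h=3, step_h=6):
    """Observation-window times relative to one analysis time.

    Assumes --context_start/--context_end are hour offsets sampled at the
    6h cadence -- an illustrative reading of the CLI flags, not taken from
    inference.py itself.
    """
    hours = range(start_h, end_h + 1, step_h)
    return [analysis_time + timedelta(hours=h) for h in hours]

# Window around the 2022-01-01 00z analysis.
times = context_window(datetime(2022, 1, 1, 0))
```

At a 6-hour cadence, all of 2022 (`--num_samples -1`, `--time_frequency 6h`) corresponds to 365 × 4 = 1460 analysis times.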

### Step 2: Forecast from HealDA initial conditions

#### Installing FCN3 dependencies

FCN3 requires `earth2studio`. We recommend installing torch-harmonics with CUDA extensions for the best performance:

```bash
# Using uv
export FORCE_CUDA_EXTENSION=1
uv pip install torch-harmonics==0.8.0 --no-build-isolation
uv pip install earth2studio[fcn3]

# Or using pip
export FORCE_CUDA_EXTENSION=1
pip install torch-harmonics==0.8.0 --no-build-isolation
pip install earth2studio[fcn3]
```

> **Note:** See the [Earth2Studio docs](https://nvidia.github.io/earth2studio/userguide/about/install.html) for more information on installing forecast models beyond FCN3.

#### Running forecasts

Use the DA output to initialize the FCN3 forecast model and create a 10-day forecast (40 6-hour steps):

```bash
python scripts/forecast.py \
--init_path /path/to/da_output.zarr \
--out_dir /path/to/forecast_output \
--model FCN3 \
--num_steps 40 \
--num_ensemble 1 \
--num_times 1
```

> **Note:** The forecast script:
> - Regrids HealDA analysis (HPX64) → 0.25° lat-lon for FCN3 input
> - Regrids FCN3 output (0.25° lat-lon) → HPX64 NEST format for storage

> **ERA5-initialized forecasts:** To create forecasts from ERA5 instead of DA output, run `inference.py` with the `--use_analysis` flag to create an ERA5 zarr in the same format, then pass that as `--init_path`.
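
To give a sense of the two grids in the regrid round-trip: HEALPix has 12 × nside² pixels (a standard formula), so HPX64 carries 49,152 pixels, while the usual 0.25° equiangular grid used by models like FCN3 is 721 × 1440 points:

```python
# Grid sizes involved in the HPX64 <-> 0.25-degree round-trip.
# 12 * nside**2 is the standard HEALPix pixel count; 721 x 1440 is the
# conventional 0.25-degree equiangular lat-lon grid.
def healpix_npix(nside: int) -> int:
    return 12 * nside * nside

hpx64_pixels = healpix_npix(64)   # HealDA analysis grid (HPX64, NEST ordering)
latlon_points = 721 * 1440        # 0.25-degree lat-lon grid for FCN3 input

# The lat-lon grid oversamples the HEALPix state by roughly 21x.
ratio = latlon_points / hpx64_pixels
```

The upsampling on the way into FCN3 adds no information beyond the HPX64 analysis; regridding the output back to HPX64 NEST keeps storage compact and comparable to the analysis.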


### Step 3: Score Forecasts

Score forecasts against a reference dataset (also on HPX64 grid):

```bash
python scripts/score_forecast.py \
--forecast_path /path/to/forecast.zarr \
--reference_path /path/to/era5.zarr \
--output_path /path/to/scores.nc
```
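
The metric implementation in `scripts/score_forecast.py` is not shown here, but the CRPS plotted above has a standard ensemble estimator, E|X − y| − ½·E|X − X′|, sketched below in numpy as a generic reference (not the project's own code):

```python
import numpy as np

def crps_ensemble(forecasts: np.ndarray, obs: float) -> float:
    """Standard (unadjusted) ensemble CRPS estimator:
    E|X - y| - 0.5 * E|X - X'|.
    A generic reference formula, not taken from scripts/score_forecast.py.
    """
    forecasts = np.asarray(forecasts, dtype=float)
    skill = np.mean(np.abs(forecasts - obs))                       # E|X - y|
    spread = np.mean(np.abs(forecasts[:, None] - forecasts[None, :]))  # E|X - X'|
    return skill - 0.5 * spread

# With a single member (num_ensemble 1), the spread term vanishes and
# CRPS reduces to plain absolute error.
single = crps_ensemble(np.array([1.5]), 1.0)
```

This degeneracy is worth remembering when comparing `--num_ensemble 1` runs: their CRPS curves are just MAE curves.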

To plot the metrics:

```bash
python scripts/plot_panel.py \
--stats /path/to/scores.nc \
--labels "HealDA-initialized FCN3" \
--metric crps \
--output_path /path/to/plots/crps_comparison.pdf
```

See `python inference.py --help` and `python scripts/forecast.py --help` for full options.

---

## Citation

```bibtex
@misc{gupta2026healdahighlightingimportanceinitial,
title={HealDA: Highlighting the importance of initial errors in end-to-end AI weather forecasts},
author={Aayush Gupta and Akshay Subramaniam and Michael S. Pritchard and Karthik Kashinath and Sergey Frolov and Kelsey Lieberman and Christopher Miller and Nicholas Silverman and Noah D. Brenowitz},
year={2026},
eprint={2601.17636},
archivePrefix={arXiv},
primaryClass={physics.ao-ph},
url={https://arxiv.org/abs/2601.17636},
}
```
64 changes: 64 additions & 0 deletions examples/weather/healda/config/environment.py
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import dotenv

dotenv.load_dotenv(dotenv.find_dotenv(usecwd=True))

non_config = dir()

CACHE_DIR = os.path.expanduser("~/.cache/healda")

###############
# ERA5 inputs #
###############
V6_ERA5_ZARR = os.getenv("V6_ERA5_ZARR", "")
V6_ERA5_ZARR_PROFILE = os.getenv("V6_ERA5_ZARR_PROFILE", "")

########
# UFS #
########
UFS_HPX6_ZARR = os.getenv("UFS_HPX6_ZARR", "")
UFS_LAND_DATA_ZARR = os.getenv("UFS_LAND_DATA_ZARR", "")
UFS_LAND_DATA_PROFILE = os.getenv("UFS_LAND_DATA_PROFILE", "")
UFS_ZARR_PROFILE = os.getenv("UFS_ZARR_PROFILE", "")
UFS_OBS_PATH = os.getenv("UFS_OBS_PATH", "")
UFS_OBS_PROFILE = os.getenv("UFS_OBS_PROFILE", "")
# project file
PROJECT_ROOT = os.getenv("PROJECT_ROOT", "")
DATA_ROOT = os.getenv("DATA_ROOT", os.path.join(PROJECT_ROOT, "datasets"))
CHECKPOINT_ROOT = os.getenv(
    "CHECKPOINT_ROOT", os.path.join(PROJECT_ROOT, "training-runs")
)


_config_vars = dict(vars())


def print_config():
    print("Environment settings:")
    print("-" * 80)
    for v in _config_vars:
        if v == "non_config":
            continue

        if v in non_config:
            continue

        value = _config_vars[v]
        print(f"{v}={value}")
    print("-" * 80)
Empty file.
75 changes: 75 additions & 0 deletions examples/weather/healda/config/training/loop.py
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class TrainingLoopBase:
    """Base training config"""

    run_dir: str = "."  # Output directory.
    seed: int = 0  # Global random seed.
    batch_size: int = 512  # Total batch size for one training iteration.
    batch_gpu: Optional[int] = None  # Limit batch size per GPU, None = no limit.
    enable_ema: bool = False
    ema_halflife_kimg: int = (
        500  # Half-life of the exponential moving average (EMA) of model weights.
    )
    ema_rampup_ratio: float = 0.05  # EMA ramp-up coefficient, None = no rampup.
    lr_rampup_img: int = 10_000  # Learning rate ramp-up duration.
    flat_imgs: int = 1_500_000 - 10_000
    decay_imgs: int = 1_500_000
    lr_min: float = 1e-6
    lr: float = 1e-4

    loss_reduction: str = "v1"
    """
    Controls how the [b c t x] shaped loss is reduced, where 'b' is the
    batch dimension, 'c' the channel dimension, 't' time, and 'x' space.

    Options:
    - v1 (default) - sum over c x, mean over b c
    - mean - mean over all dimensions
    """

    loss_scaling: float = 1.0  # Loss scaling factor for reducing FP16 under/overflows.
    gradient_clip_max_norm: Optional[float] = None
    total_ticks: int = 10
    print_steps: int = 50
    steps_per_tick: int = 1024
    snapshot_ticks: int | None = (
        50  # How often to save network snapshots, None = disable.
    )
    state_dump_ticks: int | None = (
        500  # How often to dump training state, None = disable.
    )

    test_with_single_batch: bool = False
    """Only load a single batch of data for testing and profiling purposes"""

    # Mixed precision and performance options
    cudnn_benchmark: bool = True  # Enable torch.backends.cudnn.benchmark?
    tf32: bool = True
    bf16: bool = True
    compile_optimizer: bool = False  # If true, wrap the optimizer with torch.compile.

    # wandb
    wandb_id: str | None = None  # Will be read from checkpoint if not provided.

    # logging
    log_parameter_norm: bool = False
    log_parameter_grad_norm: bool = False
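
One plausible reading of the schedule fields above (`lr_rampup_img`, `flat_imgs`, `decay_imgs`, `lr_min`) is linear warm-up, a flat phase, then linear decay to `lr_min`. The actual trainer logic lives in `train.py` and may differ; this sketch just makes the assumed shape concrete:

```python
def lr_at(nimg: int, lr: float = 1e-4, lr_min: float = 1e-6,
          rampup: int = 10_000, flat: int = 1_490_000,
          decay: int = 1_500_000) -> float:
    """One plausible schedule implied by TrainingLoopBase's fields:
    linear warm-up over `rampup` images, flat until `flat`, then linear
    decay to `lr_min` by `decay`. An assumption, not train.py's code.
    """
    if nimg < rampup:
        return lr * nimg / rampup          # warm-up
    if nimg < flat:
        return lr                          # flat phase
    if nimg < decay:
        frac = (nimg - flat) / (decay - flat)
        return lr + frac * (lr_min - lr)   # linear decay
    return lr_min

peak = lr_at(500_000)
```

Note that the default `flat_imgs = 1_500_000 - 10_000` keeps the flat phase exactly `lr_rampup_img` images shorter than `decay_imgs`, so warm-up plus flat plus decay tile the full image budget.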