Skip to content

Add GraphCast (1 degree model) #256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 46 commits into from
May 22, 2025
Merged

Conversation

rodrigoalmeida94
Copy link
Contributor

@rodrigoalmeida94 rodrigoalmeida94 commented Apr 11, 2025

Earth2Studio Pull Request

Description

Add support for GraphCast model (small). In addition, adds ARCOExtra data (to produce relative humidity and accumulated precipitation on 6h intervals).
Closes #199

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

ToDos:

  • Fix mypy
  • Add explicit device management (torch to jax)

Dev notes

  • This is quite a hacky implementation based loosely on https://github.com/ecmwf-lab/ai-models-graphcast and https://github.com/ankurmahesh/earth2mip-fork/blob/70fec6cbb388ae46591a5827c5e60621b728f602/earth2mip/networks/graphcast.py
  • I tried to use as much of the original graphcast logic: this results in a weird data flow, going from (torch.array, coords) -> xr.Dataset (creating the variables structure + generating forcings) -> running inference with the xr.Dataset in JAX -> convert back to earth2studio (torch.array, coords) convention
  • In order to make use of the graphcast data handling utilities, I had to save the number of forecast steps as class instance variable, which is then used to create the data + iterator to produce forecasts. This means that an extra step is required when running the predictions (e.g. run.deterministic(["2019-02-01"], nsteps, model.set_nsteps(nsteps), data, io), which is definitely not ideal.
  • I also made use of the ARCOLexicon (with the extra derived variables like relative humidity and total_precipitation_6hr) since graphcast expects the "long" variable names.

@NickGeneva NickGeneva added the 1 - On Deck To be worked on next label Apr 11, 2025
@NickGeneva
Copy link
Collaborator

Hi @rodrigoalmeida94

This is awesome, thanks so much for adding this!
We'll start taking a look at this next week, there may be some edits coming in as we review on our side.

@loliverhennigh
Copy link
Collaborator

Hi @rodrigoalmeida94

Finishing up some other work but will start working on this next week. We are planning to add the 0.25 degree model in as well. Ill start pushing changes to your branch starting tomorrow if thats ok.

@rodrigoalmeida94
Copy link
Contributor Author

Hi @rodrigoalmeida94

Finishing up some other work but will start working on this next week. We are planning to add the 0.25 degree model in as well. Ill start pushing changes to your branch starting tomorrow if thats ok.

@loliverhennigh sounds good! I wanted to add the device management (for jax), going to try to add something for that today or tomorrow.

@nbren12
Copy link

nbren12 commented Apr 22, 2025

FYI. in case it's helpful, there is a graphcast runner here that works with the 0.25 deg models: https://github.com/NVIDIA/earth2mip/blob/main/earth2mip/networks/graphcast.py.

I got it to work with a custom tisr implementation in physicsnemo so it can beyond the duration of the input data. It's been used for inferences in a few of our papers.

@loliverhennigh
Copy link
Collaborator

Hey @rodrigoalmeida94 , I have some changes for this PR. Could you give me push access to your fork? If not I will fork off it and make another PR.

@rodrigoalmeida94
Copy link
Contributor Author

@loliverhennigh just did, hope it worked?

@loliverhennigh
Copy link
Collaborator

/blossom-ci

@loliverhennigh
Copy link
Collaborator

/blossom-ci

@rodrigoalmeida94
Copy link
Contributor Author

@loliverhennigh the good news is that yes, I can run all my perturbation stuff with GraphCastMini class, everything works smoothly, which is great 🥳

The bad news is that I was trying to check the predictions against the original implementation as you suggested, and it's not adding up. See notebook here https://github.com/rodrigoalmeida94/earth2studio/blob/check-graphcast/check_graphcast.ipynb

My check here was to make use of the example batch data that the GraphCast repo provides and compute the predictions using the original methods and our implementation. The differences in the predictions are quite large (up to 15K in t2m) so something must be off - I was thinking maybe we are dealing with the lead times somehow wrong, but honestly not really sure how. Maybe I mixed up something in the notebook?

@rodrigoalmeida94
Copy link
Contributor Author

Okay so good news: I tested this again using the WB2ERA5 data source and now the predictions are the same as in the original repository (notebook 1 and 2).

I think the previous notebooks I was dealing with the lead times wrong (because I was using a local copy of ARCO, which was only for 2022).

@loliverhennigh
Copy link
Collaborator

/blossom-ci

@loliverhennigh
Copy link
Collaborator

/blossom-ci

@NickGeneva NickGeneva added the ! - Release PRs or Issues releating to a release label May 20, 2025
@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

NickGeneva commented May 21, 2025

Cross checked the implementation with my own validation script based on @rodrigoalmeida94 notebooks and also looking at the GC repo

import os

os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from earth2studio.models.px import GraphCastSmall
from graphcast import data_utils
import xarray as xr
import dataclasses
from graphcast import rollout
from graphcast import data_utils
import jax
import numpy as np
from earth2studio.data.utils import fetch_data
from earth2studio.utils.coords import map_coords

from earth2studio.data import WB2ERA5
from datetime import datetime

# Set up data and model
model = GraphCastSmall.load_model(GraphCastSmall.load_default_package())
ds = WB2ERA5(cache=True)

x, coords = fetch_data(
    source=ds,
    time=[datetime(2022,1,1)],
    variable=model.input_coords()["variable"],
    lead_time=model.input_coords()["lead_time"],
    device="cuda",
)
x_input, coords_input = map_coords(x, coords, model.input_coords())
iter = model.create_iterator(x_input, coords_input)

n_steps = 6

# Earth2Studio wrapper forward prediction
outputs = []
step = 0
for x, coords in iter:
    print(x.shape)
    outputs.append(xr.DataArray(x.cpu(), coords=coords))
    step += 1
    if step > n_steps:
        break
prediction_e2studio = xr.concat(outputs, dim="lead_time")
prediction_e2studio.to_netcdf("e2s.nc")
del iter

# Graphcast original prediction
batch, target_lead_times = model.from_dataarray_to_dataset(xr.DataArray(x_input.cpu(), coords=coords_input), lead_time=6*n_steps)
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
        batch, target_lead_times=target_lead_times,
        **dataclasses.asdict(model.ckpt.task_config))

generator = rollout.chunked_prediction_generator(model.run_forward,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings
)
prediction_graphcast = [next(generator) for _ in range(n_steps)]
prediction_graphcast = xr.concat(prediction_graphcast, dim="time")
prediction_graphcast.to_netcdf("gc.nc")

print(prediction_e2studio)
print(prediction_graphcast)

# Plot difference between two variables for each time-step
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(3, n_steps, figsize=(18, 8))
for i in range(n_steps):
    e2s_data = prediction_e2studio.sel(variable="u10m").isel(time=0, lead_time=i+1)
    gc_data = prediction_graphcast['10m_u_component_of_wind'].isel(time=i, batch=0)
    
    ax[0, i].imshow(e2s_data, cmap='RdBu_r', vmin=-30, vmax=30)
    ax[1, i].imshow(gc_data, cmap='RdBu_r', vmin=-30, vmax=30)
    ax[2, i].imshow(np.abs(e2s_data - gc_data), cmap='magma', vmin=0, vmax=2)
    ax[0, i].set_title(f'Timestep {i}')
    
ax[0, 0].set_ylabel(f'Earth2Studio')
ax[1, 0].set_ylabel(f'GraphCast')
ax[2, 0].set_ylabel(f'Diff')
plt.tight_layout()
plt.savefig("u10m.png")

plt.close("all")
fig, ax = plt.subplots(3, n_steps, figsize=(18, 8))
for i in range(n_steps):
    e2s_data = prediction_e2studio.sel(variable="z500").isel(time=0, lead_time=i+1)
    gc_data = prediction_graphcast['geopotential'].isel(time=i, batch=0, level=7)
    
    ax[0, i].imshow(e2s_data, cmap='viridis')
    ax[1, i].imshow(gc_data, cmap='viridis')
    ax[2, i].imshow(np.abs(e2s_data - gc_data), cmap='magma', vmin=0, vmax=10)
    ax[0, i].set_title(f'Timestep {i}')

ax[0, 0].set_ylabel(f'Earth2Studio')
ax[1, 0].set_ylabel(f'GraphCast')
ax[2, 0].set_ylabel(f'Diff')
plt.tight_layout()
plt.savefig("z500.png")

I chose one surface variable and also a pressure variable.
This generates the following images, the results are identical for 6 time-steps.

u10m
u10m

z500
z500

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva
Copy link
Collaborator

/blossom-ci

@NickGeneva NickGeneva merged commit 368fba0 into NVIDIA:main May 22, 2025
11 checks passed
@NickGeneva
Copy link
Collaborator

Thank you for the great contribution @rodrigoalmeida94 !

We greatly appreciate it and are already working on also adding the 0.25 degree model.

@rodrigoalmeida94 rodrigoalmeida94 deleted the graphcast branch May 22, 2025 08:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
1 - On Deck To be worked on next ! - Release PRs or Issues releating to a release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

🚀[FEA]: Adding GraphCast
4 participants