-
Notifications
You must be signed in to change notification settings - Fork 53
Add GraphCast (1 degree model) #256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is awesome, thanks so much for adding this! |
Finishing up some other work but will start working on this next week. We are planning to add the 0.25 degree model in as well. Ill start pushing changes to your branch starting tomorrow if thats ok. |
@loliverhennigh sounds good! I wanted to add the device management (for jax), going to try to add something for that today or tomorrow. |
FYI. in case it's helpful, there is a graphcast runner here that works with the 0.25 deg models: https://github.com/NVIDIA/earth2mip/blob/main/earth2mip/networks/graphcast.py. I got it to work with a custom tisr implementation in physicsnemo so it can beyond the duration of the input data. It's been used for inferences in a few of our papers. |
Hey @rodrigoalmeida94 , I have some changes for this PR. Could you give me push access to your fork? If not I will fork off it and make another PR. |
@loliverhennigh just did, hope it worked? |
/blossom-ci |
/blossom-ci |
@loliverhennigh the good news is that yes, I can run all my perturbation stuff with GraphCastMini class, everything works smoothly, which is great 🥳 The bad news is that I was trying to check the predictions against the original implementation as you suggested, and it's not adding up. See notebook here https://github.com/rodrigoalmeida94/earth2studio/blob/check-graphcast/check_graphcast.ipynb My check here was to make use of the example batch data that the GraphCast repo provides and compute the predictions using the original methods and our implementation. The differences in the predictions are quite large (up to 15K in t2m) so something must be off - I was thinking maybe we are dealing with the lead times somehow wrong, but honestly not really sure how. Maybe I mixed up something in the notebook? |
Okay so good news: I tested this again using the I think the previous notebooks I was dealing with the lead times wrong (because I was using a local copy of ARCO, which was only for 2022). |
/blossom-ci |
/blossom-ci |
/blossom-ci |
/blossom-ci |
Cross checked the implementation with my own validation script based on @rodrigoalmeida94 notebooks and also looking at the GC repo import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from earth2studio.models.px import GraphCastSmall
from graphcast import data_utils
import xarray as xr
import dataclasses
from graphcast import rollout
from graphcast import data_utils
import jax
import numpy as np
from earth2studio.data.utils import fetch_data
from earth2studio.utils.coords import map_coords
from earth2studio.data import WB2ERA5
from datetime import datetime
# Set up data and model
model = GraphCastSmall.load_model(GraphCastSmall.load_default_package())
ds = WB2ERA5(cache=True)
x, coords = fetch_data(
source=ds,
time=[datetime(2022,1,1)],
variable=model.input_coords()["variable"],
lead_time=model.input_coords()["lead_time"],
device="cuda",
)
x_input, coords_input = map_coords(x, coords, model.input_coords())
iter = model.create_iterator(x_input, coords_input)
n_steps = 6
# Earth2Studio wrapper forward prediction
outputs = []
step = 0
for x, coords in iter:
print(x.shape)
outputs.append(xr.DataArray(x.cpu(), coords=coords))
step += 1
if step > n_steps:
break
prediction_e2studio = xr.concat(outputs, dim="lead_time")
prediction_e2studio.to_netcdf("e2s.nc")
del iter
# Graphcast original prediction
batch, target_lead_times = model.from_dataarray_to_dataset(xr.DataArray(x_input.cpu(), coords=coords_input), lead_time=6*n_steps)
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
batch, target_lead_times=target_lead_times,
**dataclasses.asdict(model.ckpt.task_config))
generator = rollout.chunked_prediction_generator(model.run_forward,
rng=jax.random.PRNGKey(0),
inputs=eval_inputs,
targets_template=eval_targets * np.nan,
forcings=eval_forcings
)
prediction_graphcast = [next(generator) for _ in range(n_steps)]
prediction_graphcast = xr.concat(prediction_graphcast, dim="time")
prediction_graphcast.to_netcdf("gc.nc")
print(prediction_e2studio)
print(prediction_graphcast)
# Plot difference between two variables for each time-step
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots(3, n_steps, figsize=(18, 8))
for i in range(n_steps):
e2s_data = prediction_e2studio.sel(variable="u10m").isel(time=0, lead_time=i+1)
gc_data = prediction_graphcast['10m_u_component_of_wind'].isel(time=i, batch=0)
ax[0, i].imshow(e2s_data, cmap='RdBu_r', vmin=-30, vmax=30)
ax[1, i].imshow(gc_data, cmap='RdBu_r', vmin=-30, vmax=30)
ax[2, i].imshow(np.abs(e2s_data - gc_data), cmap='magma', vmin=0, vmax=2)
ax[0, i].set_title(f'Timestep {i}')
ax[0, 0].set_ylabel(f'Earth2Studio')
ax[1, 0].set_ylabel(f'GraphCast')
ax[2, 0].set_ylabel(f'Diff')
plt.tight_layout()
plt.savefig("u10m.png")
plt.close("all")
fig, ax = plt.subplots(3, n_steps, figsize=(18, 8))
for i in range(n_steps):
e2s_data = prediction_e2studio.sel(variable="z500").isel(time=0, lead_time=i+1)
gc_data = prediction_graphcast['geopotential'].isel(time=i, batch=0, level=7)
ax[0, i].imshow(e2s_data, cmap='viridis')
ax[1, i].imshow(gc_data, cmap='viridis')
ax[2, i].imshow(np.abs(e2s_data - gc_data), cmap='magma', vmin=0, vmax=10)
ax[0, i].set_title(f'Timestep {i}')
ax[0, 0].set_ylabel(f'Earth2Studio')
ax[1, 0].set_ylabel(f'GraphCast')
ax[2, 0].set_ylabel(f'Diff')
plt.tight_layout()
plt.savefig("z500.png") I chose one surface variable and also a pressure variable. |
/blossom-ci |
/blossom-ci |
/blossom-ci |
/blossom-ci |
Thank you for the great contribution @rodrigoalmeida94 ! We greatly appreciate it and are already working on also adding the 0.25 degree model. |
Earth2Studio Pull Request
Description
Add support for GraphCast model (small). In addition, adds
ARCOExtra
data (to produce relative humidity and accumulated precipitation on 6h intervals).Closes #199
Checklist
ToDos:
Dev notes
earth2studio
(torch.array, coords) conventiongraphcast
data handling utilities, I had to save the number of forecast steps as class instance variable, which is then used to create the data + iterator to produce forecasts. This means that an extra step is required when running the predictions (e.g.run.deterministic(["2019-02-01"], nsteps, model.set_nsteps(nsteps), data, io)
, which is definitely not ideal.graphcast
expects the "long" variable names.