201 modularize wxc #328

Merged: 77 commits merged into main from romeokienzler:201 on Jan 17, 2025
Commits (77)
d5566bc
add ft functionality
romeokienzler Nov 6, 2024
dd043ee
Merge branch 'IBM:main' into 201
romeokienzler Nov 7, 2024
2541009
Merge branch 'IBM:main' into 201
romeokienzler Nov 12, 2024
d7742c3
add test for wxc factory
romeokienzler Nov 13, 2024
87306b0
instantiate wxc model
romeokienzler Nov 19, 2024
ba5d291
add path to weights for test
romeokienzler Nov 19, 2024
f485210
update data path for ccc
romeokienzler Nov 19, 2024
b0a7d8a
fix path
romeokienzler Nov 20, 2024
b3185d9
fix path
romeokienzler Nov 20, 2024
cf91c1e
fix path
romeokienzler Nov 20, 2024
7665db4
fix model reference
romeokienzler Nov 20, 2024
b1615ca
support parameterizable unet pincer
romeokienzler Nov 22, 2024
6e45b97
fix list initialization
romeokienzler Nov 22, 2024
a1cbdfe
refix list init
romeokienzler Nov 22, 2024
4cc452c
create trainable version of unet pincer
romeokienzler Nov 22, 2024
4163cc9
fix import
romeokienzler Nov 25, 2024
b23a1c8
implement training in module
romeokienzler Nov 26, 2024
781b08f
add test for wxc modularization inference and train
romeokienzler Nov 26, 2024
8a3d571
fix model to gpu assignment
romeokienzler Nov 26, 2024
8a57c3e
fix model to gpu assignment
romeokienzler Nov 26, 2024
9ba926b
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
747ac25
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
a76d585
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
e31344a
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
4f43698
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
1dbe908
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
18edfa3
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
9b9ba07
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
b928fb5
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
92b540f
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 26, 2024
9fe322f
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
2254271
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
f69b652
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
5f5223b
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
dbac694
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
0a7de40
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
9d41bfc
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
dfbbea1
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
8ef4c57
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
5282f21
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
9938115
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Nov 27, 2024
0f424aa
optimize config
romeokienzler Nov 28, 2024
d534b80
optimize config
romeokienzler Nov 28, 2024
ac6fdbd
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Dec 16, 2024
431e22c
add ft functionality
romeokienzler Nov 6, 2024
6b55b61
add test for wxc factory
romeokienzler Nov 13, 2024
ef2c1de
instantiate wxc model
romeokienzler Nov 19, 2024
0ad349f
add path to weights for test
romeokienzler Nov 19, 2024
fd24c16
update data path for ccc
romeokienzler Nov 19, 2024
303bd47
fix path
romeokienzler Nov 20, 2024
7728279
fix path
romeokienzler Nov 20, 2024
a30d05d
fix path
romeokienzler Nov 20, 2024
1296317
fix model reference
romeokienzler Nov 20, 2024
e347046
support parameterizable unet pincer
romeokienzler Nov 22, 2024
21056bf
fix list initialization
romeokienzler Nov 22, 2024
9ba6409
refix list init
romeokienzler Nov 22, 2024
61b56d2
create trainable version of unet pincer
romeokienzler Nov 22, 2024
d678c30
fix import
romeokienzler Nov 25, 2024
45e9d8e
implement training in module
romeokienzler Nov 26, 2024
2feaa63
add test for wxc modularization inference and train
romeokienzler Nov 26, 2024
889ce15
fix model to gpu assignment
romeokienzler Nov 26, 2024
e6886ed
optimize config
romeokienzler Nov 28, 2024
63edb29
merge
romeokienzler Dec 16, 2024
97fbeed
Merge branch 'IBM:main' into 201
romeokienzler Jan 6, 2025
669f09c
Merge branch '201' of github.com:romeokienzler/terratorch into 201
romeokienzler Jan 6, 2025
f920c43
cleanup
romeokienzler Jan 8, 2025
64ab98a
Merge branch '201' of github.com:romeokienzler/terratorch into 201
romeokienzler Jan 8, 2025
332eac7
move test
romeokienzler Jan 8, 2025
e5b2894
Merge branch '201' of github.com:romeokienzler/terratorch into 201
romeokienzler Jan 8, 2025
b359e80
Merge branch '201' of github.com:romeokienzler/terratorch into 201
romeokienzler Jan 10, 2025
f35ff17
add weights to test
romeokienzler Jan 13, 2025
e405700
add weights to test
romeokienzler Jan 13, 2025
093b57e
Merge remote-tracking branch 'origin/201' into 201
romeokienzler Jan 13, 2025
70b009f
working modularized version of downscaling eval
romeokienzler Jan 16, 2025
b529ef1
modularization for wxc gravity wave working
romeokienzler Jan 16, 2025
7a0cb34
add era5 datamodule
romeokienzler Jan 16, 2025
06821d8
merge main
romeokienzler Jan 16, 2025
270 changes: 148 additions & 122 deletions examples/notebooks/WxCTutorialGravityWave.ipynb

Large diffs are not rendered by default.

270 changes: 270 additions & 0 deletions integrationtests/test_prithvi_wxc_model_factory.py
@@ -0,0 +1,270 @@
# Copyright contributors to the Terratorch project

import os

import pytest
import torch
import torch.distributed as dist
import yaml
from granitewxc.utils.config import get_config
from huggingface_hub import hf_hub_download
from lightning.pytorch import Trainer

from terratorch.models.wxc_model_factory import WxCModelFactory
from terratorch.tasks.wxc_task import WxCTask
import lightning.pytorch as pl

from terratorch.datamodules.era5 import ERA5DataModule
from typing import Any


def setup_function():
    print("\nSetup function is called")

def teardown_function():
    try:
        os.remove("config.yaml")
    except OSError:
        pass

class StopTrainerCallback(pl.Callback):
    def __init__(self, stop_after_n_batches):
        super().__init__()
        self.stop_after_n_batches = stop_after_n_batches
        self.current_batch = 0

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.current_batch += 1
        if self.current_batch >= self.stop_after_n_batches:
            print("Stopping prediction early...")
            raise StopIteration("Stopped prediction after reaching the specified batch limit.")

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.current_batch += 1
        if self.current_batch >= self.stop_after_n_batches:
            print("Stopping training early...")
            raise StopIteration("Stopped training after reaching the specified batch limit.")

@pytest.mark.parametrize("backbone", ["gravitywave", None, "prithviwxc"])
def test_can_create_wxc_models(backbone):
    if backbone == "gravitywave":
        config_data = {
            "singular_sharded_checkpoint": "./examples/notebooks/magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
        }

        with open("config.yaml", "w") as file:
            yaml.dump(config_data, file, default_flow_style=False)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        if dist.is_initialized():
            dist.destroy_process_group()

        dist.init_process_group(
            backend='gloo',
            init_method='env://',
            rank=0,
            world_size=1
        )

        f = WxCModelFactory()
        f.build_model(backbone, None)

    elif backbone == 'prithviwxc':
        f = WxCModelFactory()
        f.build_model(backbone, aux_decoders=None, backbone_weights='/dccstor/wfm/shared/pretrained/step_400.pt')

    else:
        config = get_config('./examples/confs/granite-wxc-merra2-downscale-config.yaml')
        config.download_path = "/dccstor/wfm/shared/datasets/training/merra-2_v1/"

        config.data.data_path_surface = os.path.join(config.download_path, 'merra-2')
        config.data.data_path_vertical = os.path.join(config.download_path, 'merra-2')
        config.data.climatology_path_surface = os.path.join(config.download_path, 'climatology')
        config.data.climatology_path_vertical = os.path.join(config.download_path, 'climatology')

        config.model.input_scalers_surface_path = os.path.join(config.download_path, 'climatology/musigma_surface.nc')
        config.model.input_scalers_vertical_path = os.path.join(config.download_path, 'climatology/musigma_vertical.nc')
        config.model.output_scalers_surface_path = os.path.join(config.download_path, 'climatology/anomaly_variance_surface.nc')
        config.model.output_scalers_vertical_path = os.path.join(config.download_path, 'climatology/anomaly_variance_vertical.nc')

        f = WxCModelFactory()
        f.build_model(backbone, aux_decoders=None, model_config=config)



def test_wxc_unet_pincer_inference():
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    if dist.is_initialized():
        dist.destroy_process_group()

    dist.init_process_group(
        backend='gloo',
        init_method='env://',
        rank=0,
        world_size=1
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        filename="magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
        local_dir=".",
    )

    # Assumed to mirror the config.yaml download in test_wxc_unet_pincer_train below;
    # the original call here was missing its arguments.
    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        filename="config.yaml",
        local_dir=".",
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        repo_type='dataset',
        filename="wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc",
        local_dir=".",
    )

    model_args = {
        "in_channels": 1280,
        "input_size_time": 1,
        "n_lats_px": 64,
        "n_lons_px": 128,
        "patch_size_px": [2, 2],
        "mask_unit_size_px": [8, 16],
        "mask_ratio_inputs": 0.5,
        "embed_dim": 2560,
        "n_blocks_encoder": 12,
        "n_blocks_decoder": 2,
        "mlp_multiplier": 4,
        "n_heads": 16,
        "dropout": 0.0,
        "drop_path": 0.05,
        "parameter_dropout": 0.0,
        "residual": "none",
        "masking_mode": "both",
        "decoder_shifting": False,
        "positional_encoding": "absolute",
        "checkpoint_encoder": [3, 6, 9, 12, 15, 18, 21, 24],
        "checkpoint_decoder": [1, 3],
        "in_channels_static": 3,
        "input_scalers_mu": torch.tensor([0] * 1280),
        "input_scalers_sigma": torch.tensor([1] * 1280),
        "input_scalers_epsilon": 0,
        "static_input_scalers_mu": torch.tensor([0] * 3),
        "static_input_scalers_sigma": torch.tensor([1] * 3),
        "static_input_scalers_epsilon": 0,
        "output_scalers": torch.tensor([0] * 1280),
        "backbone_weights": "magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
        "backbone": "prithviwxc",
        "aux_decoders": "unetpincer",
    }
    task = WxCTask(WxCModelFactory(), model_args=model_args, mode='eval')

    trainer = Trainer(
        max_epochs=1,
        callbacks=[StopTrainerCallback(stop_after_n_batches=3)],
    )
    dm = ERA5DataModule(train_data_path='.', valid_data_path='.')
    results = trainer.predict(model=task, datamodule=dm, return_predictions=True)

    dist.destroy_process_group()


def test_wxc_unet_pincer_train():
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    if dist.is_initialized():
        dist.destroy_process_group()

    dist.init_process_group(
        backend='gloo',
        init_method='env://',
        rank=0,
        world_size=1
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        filename="magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
        local_dir=".",
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        filename="config.yaml",
        local_dir=".",
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/Gravity_wave_Parameterization",
        repo_type='dataset',
        filename="wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc",
        local_dir=".",
    )

    model_args = {
        "in_channels": 1280,
        "input_size_time": 1,
        "n_lats_px": 64,
        "n_lons_px": 128,
        "patch_size_px": [2, 2],
        "mask_unit_size_px": [8, 16],
        "mask_ratio_inputs": 0.5,
        "embed_dim": 2560,
        "n_blocks_encoder": 12,
        "n_blocks_decoder": 2,
        "mlp_multiplier": 4,
        "n_heads": 16,
        "dropout": 0.0,
        "drop_path": 0.05,
        "parameter_dropout": 0.0,
        "residual": "none",
        "masking_mode": "both",
        "decoder_shifting": False,
        "positional_encoding": "absolute",
        "checkpoint_encoder": [3, 6, 9, 12, 15, 18, 21, 24],
        "checkpoint_decoder": [1, 3],
        "in_channels_static": 3,
        "input_scalers_mu": torch.tensor([0] * 1280),
        "input_scalers_sigma": torch.tensor([1] * 1280),
        "input_scalers_epsilon": 0,
        "static_input_scalers_mu": torch.tensor([0] * 3),
        "static_input_scalers_sigma": torch.tensor([1] * 3),
        "static_input_scalers_epsilon": 0,
        "output_scalers": torch.tensor([0] * 1280),
        "backbone_weights": "magnet-flux-uvtp122-epoch-99-loss-0.1022.pt",
        "backbone": "prithviwxc",
        "aux_decoders": "unetpincer",
        "skip_connection": True,
    }

    task = WxCTask(WxCModelFactory(), model_args=model_args, mode='train')

    trainer = Trainer(
        callbacks=[StopTrainerCallback(stop_after_n_batches=3)],
        max_epochs=1,
    )
    dm = ERA5DataModule(train_data_path='.', valid_data_path='.')
    results = trainer.fit(model=task, datamodule=dm)

    dist.destroy_process_group()
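
Taken together, the two unet-pincer tests exercise the complete modularized flow: WxCModelFactory.build_model assembles the Prithvi-WxC backbone with the unetpincer auxiliary decoder, WxCTask wraps the result as a LightningModule in 'eval' or 'train' mode, and ERA5DataModule feeds it through a standard Lightning Trainer. As a usage note, the file is meant to be run under pytest; below is a minimal sketch of invoking only these integration tests, assuming network access for the hf_hub_download calls and, for the parametrized factory test, that the /dccstor paths referenced above are reachable.

import pytest

# Run only the new WxC integration tests; -x stops at the first failure.
# Assumes the Hugging Face artifacts can be downloaded and that the
# /dccstor/wfm paths used by test_can_create_wxc_models exist.
exit_code = pytest.main(["-x", "integrationtests/test_prithvi_wxc_model_factory.py"])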
