1+ import warnings
12from pathlib import Path
23
34import torch
45from torch .export import Dim
56
67from src .model import ClayMAEModule
78
9+ warnings .filterwarnings ("ignore" )
10+
811CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt"
9- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
10- # device = torch.device("cpu")
12+ DEVICE = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
13+ CHIP_SIZE = 256
1114
1215
1316def get_data ():
14- # Load data
15- cube = torch .randn (128 , 3 , 224 , 224 ).to (device )
16- time = torch .randn (128 , 4 ).to (device )
17- latlon = torch .randn (128 , 4 ).to (device )
18- waves = torch .randn (3 ).to (device )
19- gsd = torch .randn (1 ).to (device )
20- return cube , time , latlon , waves , gsd
17+ """
18+ Generate random data tensors for model input.
19+ """
20+ cube = torch .randn (128 , 3 , CHIP_SIZE , CHIP_SIZE ).to (DEVICE )
21+ timestep = torch .randn (128 , 4 ).to (DEVICE )
22+ latlon = torch .randn (128 , 4 ).to (DEVICE )
23+ waves = torch .randn (3 ).to (DEVICE )
24+ gsd = torch .randn (1 ).to (DEVICE )
25+ return cube , timestep , latlon , waves , gsd
2126
2227
2328def load_model ():
24- module = ClayMAEModule .load_from_checkpoint (CHECKPOINT_PATH )
25- encoder = module .model .encoder # Get the encoder
26- encoder = encoder .to (device ) # Move to device
29+ """
30+ Load the model from a checkpoint and prepare it for evaluation.
31+ """
32+ module = ClayMAEModule .load_from_checkpoint (
33+ CHECKPOINT_PATH , shuffle = False , mask_ratio = 0.0
34+ )
35+ encoder = module .model .encoder .eval () # Get the encoder in eval mode
36+ encoder = encoder .to (DEVICE ) # Move to the appropriate device
2737 return encoder
2838
2939
30- def main ():
31- # Load data
32- cube , time , latlon , waves , gsd = get_data ()
33-
34- # Load model
40+ def export_model ():
41+ """
42+ Export the model with dynamic shapes for deployment.
43+ """
44+ cube , timestep , latlon , waves , gsd = get_data ()
3545 encoder = load_model ()
3646
3747 # Define dynamic shapes for model export
38- batch_size = Dim ("batch_size" , min = 2 , max = 128 ) # Define batch size range
39- channel_bands = Dim ("channel_bands" , min = 1 , max = 12 ) # Define channel bands range
48+ batch_size = Dim ("batch_size" , min = 32 , max = 1200 )
49+ channel_bands = Dim ("channel_bands" , min = 1 , max = 10 )
4050
4151 dynamic_shapes = {
4252 "cube" : {0 : batch_size , 1 : channel_bands },
@@ -47,22 +57,17 @@ def main():
4757 }
4858
4959 # Export model
50- exp_compiled_encoder = torch .export .export (
60+ ep = torch .export .export (
5161 mod = encoder ,
52- args = (cube , time , latlon , waves , gsd ),
62+ args = (cube , timestep , latlon , waves , gsd ),
5363 dynamic_shapes = dynamic_shapes ,
54- strict = False ,
64+ strict = True ,
5565 )
5666
57- # tensortrt compiled model
58- # trt_encoder = torch_tensorrt.dynamo.compile(
59- # exp_compiled_encoder, [cube, time, latlon, waves, gsd]
60- # )
61-
62- # Save model
67+ # Save the exported model
6368 Path ("checkpoints/compiled" ).mkdir (parents = True , exist_ok = True )
64- torch .export .save (exp_compiled_encoder , "checkpoints/compiled/encoder.pt" )
69+ torch .export .save (ep , "checkpoints/compiled/encoder.pt" )
6570
6671
6772if __name__ == "__main__" :
68- main ()
73+ export_model ()
0 commit comments