Skip to content

Commit c403c47

Browse files
committed
Add benchmark & test files for the compiled clay encoder
1 parent 2d80fe6 commit c403c47

File tree

4 files changed

+208
-32
lines changed

4 files changed

+208
-32
lines changed

src/benchmark_encoder.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import argparse
2+
import time
3+
import warnings
4+
5+
import torch
6+
7+
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
10+
11+
12+
def get_data():
    """
    Generate random data tensors for model input.

    Returns:
        tuple: ``(cube, timestep, latlon, waves, gsd)`` tensors drawn with
        ``torch.randn`` with shapes (128, 3, 256, 256), (128, 4), (128, 4),
        (3,) and (1,), each moved to ``DEVICE`` after creation.
    """
    # Shapes are listed in the order the tensors are returned.
    shapes = (
        (128, 3, 256, 256),  # cube
        (128, 4),  # timestep
        (128, 4),  # latlon
        (3,),  # waves
        (1,),  # gsd
    )
    return tuple(torch.randn(*shape).to(DEVICE) for shape in shapes)
22+
23+
24+
def load_exported_model(eager=True):
    """
    Load the exported encoder from ``checkpoints/compiled/encoder.pt``.

    Args:
        eager (bool): When true, return the module unwrapped from the exported
            program as-is. When false, wrap that module with
            ``torch.compile`` using the ``inductor`` backend.

    Returns:
        The callable model, either the plain exported module or its
        ``torch.compile`` wrapper.
    """
    print("Loading exported model")
    exported_program = torch.export.load("checkpoints/compiled/encoder.pt")
    module = exported_program.module()
    if not eager:
        module = torch.compile(module, backend="inductor")
    return module
38+
39+
40+
def benchmark_model(model):
    """
    Benchmark the model by running inference on randomly generated data.

    Runs 20 iterations. Each iteration draws a fresh batch from ``get_data``,
    calls the model under ``torch.inference_mode`` and prints the shapes of
    the first four outputs. The total wall-clock time of the whole loop
    (data generation and printing included) is printed at the end.

    Args:
        model: The model to benchmark. It is called as
            ``model(cube, timestep, latlon, waves, gsd)`` and its result is
            indexed at positions 0-3, each of which must expose ``.shape``.
    """
    print("Benchmarking model")
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    for i in range(20):
        cube, timestep, latlon, waves, gsd = get_data()
        with torch.inference_mode():
            out = model(cube, timestep, latlon, waves, gsd)
        print(
            f"Iteration {i}: Output shapes - {out[0].shape}, {out[1].shape}, {out[2].shape}, {out[3].shape}"  # noqa E501
        )
    # CUDA kernels are launched asynchronously; wait for all queued work to
    # finish before reading the clock, otherwise the reported time can
    # under-count the actual GPU execution time.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # NOTE(review): if the model was wrapped by torch.compile, the first call
    # presumably includes compilation time - confirm whether a warm-up
    # iteration should be excluded from the measurement.
    print("Time taken for inference: ", time.perf_counter() - start)
57+
58+
59+
def run(eager=True):
    """
    Load the exported model and benchmark it.

    Args:
        eager (bool): Forwarded to ``load_exported_model``; selects eager mode
            (true) or compiled mode (false).
    """
    print("Running model")
    benchmark_model(load_exported_model(eager=eager))
69+
70+
71+
if __name__ == "__main__":
    # Command-line entry point. ``--eager`` is a store_true switch, so when it
    # is omitted the value is False and the compiled path is benchmarked.
    cli = argparse.ArgumentParser(description="Run benchmark for the exported model.")
    cli.add_argument(
        "--eager",
        action="store_true",
        help="Use eager mode for running the model.",
    )
    options = cli.parse_args()

    run(options.eager)

src/export.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,52 @@
1+
import warnings
12
from pathlib import Path
23

34
import torch
45
from torch.export import Dim
56

67
from src.model import ClayMAEModule
78

9+
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Checkpoint file handed to ClayMAEModule.load_from_checkpoint in load_model.
CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Size of the last two dimensions of the random cube built by get_data.
CHIP_SIZE = 256
1114

1215

1316
def get_data():
    """
    Generate random data tensors for model input.

    Returns:
        tuple: ``(cube, timestep, latlon, waves, gsd)`` tensors drawn with
        ``torch.randn`` with shapes (128, 3, CHIP_SIZE, CHIP_SIZE), (128, 4),
        (128, 4), (3,) and (1,), each moved to ``DEVICE`` after creation.
    """
    # Shapes are listed in the order the tensors are returned.
    shapes = (
        (128, 3, CHIP_SIZE, CHIP_SIZE),  # cube
        (128, 4),  # timestep
        (128, 4),  # latlon
        (3,),  # waves
        (1,),  # gsd
    )
    return tuple(torch.randn(*shape).to(DEVICE) for shape in shapes)
2126

2227

2328
def load_model():
    """
    Load the model from a checkpoint and prepare it for evaluation.

    Returns:
        The ``model.encoder`` attribute of the loaded ``ClayMAEModule``, after
        ``eval()`` has been called on it and it has been moved to ``DEVICE``.
    """
    # Keyword overrides passed alongside the checkpoint path.
    overrides = {"shuffle": False, "mask_ratio": 0.0}
    checkpoint_module = ClayMAEModule.load_from_checkpoint(CHECKPOINT_PATH, **overrides)

    encoder = checkpoint_module.model.encoder.eval()  # encoder in eval mode
    return encoder.to(DEVICE)  # moved to the selected device
2838

2939

30-
def main():
31-
# Load data
32-
cube, time, latlon, waves, gsd = get_data()
33-
34-
# Load model
40+
def export_model():
41+
"""
42+
Export the model with dynamic shapes for deployment.
43+
"""
44+
cube, timestep, latlon, waves, gsd = get_data()
3545
encoder = load_model()
3646

3747
# Define dynamic shapes for model export
38-
batch_size = Dim("batch_size", min=2, max=128) # Define batch size range
39-
channel_bands = Dim("channel_bands", min=1, max=12) # Define channel bands range
48+
batch_size = Dim("batch_size", min=32, max=1200)
49+
channel_bands = Dim("channel_bands", min=1, max=10)
4050

4151
dynamic_shapes = {
4252
"cube": {0: batch_size, 1: channel_bands},
@@ -47,22 +57,17 @@ def main():
4757
}
4858

4959
# Export model
50-
exp_compiled_encoder = torch.export.export(
60+
ep = torch.export.export(
5161
mod=encoder,
52-
args=(cube, time, latlon, waves, gsd),
62+
args=(cube, timestep, latlon, waves, gsd),
5363
dynamic_shapes=dynamic_shapes,
54-
strict=False,
64+
strict=True,
5565
)
5666

57-
# tensortrt compiled model
58-
# trt_encoder = torch_tensorrt.dynamo.compile(
59-
# exp_compiled_encoder, [cube, time, latlon, waves, gsd]
60-
# )
61-
62-
# Save model
67+
# Save the exported model
6368
Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True)
64-
torch.export.save(exp_compiled_encoder, "checkpoints/compiled/encoder.pt")
69+
torch.export.save(ep, "checkpoints/compiled/encoder.pt")
6570

6671

6772
if __name__ == "__main__":
68-
main()
73+
export_model()

src/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def __init__( # noqa: PLR0913
3939
self.dim = dim
4040
self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
4141

42+
# Required to compile & export the model
43+
self.grid_size = 256 // 8
44+
self.num_patches = self.grid_size**2
45+
4246
self.patch_embedding = DynamicEmbedding(
4347
wave_dim=128,
4448
num_latent_tokens=128,
@@ -64,8 +68,9 @@ def add_encodings(self, patches, time, latlon, gsd):
6468
"""Add position encoding to the patches"""
6569
B, L, D = patches.shape
6670

67-
grid_size = int(math.sqrt(L))
68-
self.num_patches = grid_size**2
71+
# grid_size = int(math.sqrt(L))
72+
# self.num_patches = grid_size**2
73+
grid_size = self.grid_size
6974

7075
pos_encoding = (
7176
posemb_sincos_2d_with_gsd(

src/test_encoder.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import torch
2+
3+
from src.datamodule import ClayDataModule
4+
5+
# Load the exported Clay encoder program at import time and unwrap it into the
# callable module that main() invokes.
_exported_encoder = torch.export.load("checkpoints/compiled/encoder.pt")
clay_encoder = _exported_encoder.module()
7+
8+
9+
def load_batch():
    """
    Build the Clay data module and return an iterator over its training batches.

    Returns:
        tuple: ``(iter_dl, metadata)`` - an iterator over the data module's
        training dataloader, and the data module's ``metadata`` attribute.
    """
    # Data module configuration: one sample per batch, a single worker.
    config = {
        "data_dir": "/home/ubuntu/data",
        "size": 256,
        "metadata_path": "configs/metadata.yaml",
        "batch_size": 1,
        "num_workers": 1,
    }
    datamodule = ClayDataModule(**config)

    # Prepare the 'fit' stage before reading metadata or building the loader.
    datamodule.setup(stage="fit")
    metadata = datamodule.metadata

    batches = iter(datamodule.train_dataloader())
    return batches, metadata
28+
29+
30+
def prepare_data(sensor, metadata, device):
    """
    Assemble the encoder inputs for one batch and move them to ``device``.

    Args:
    - sensor (dict): Batch with 'pixels', 'time', 'latlon' and 'platform'
      entries; the first element of 'platform' selects the metadata record.
    - metadata (dict): Per-platform records exposing ``bands.wavelength``
      (a mapping whose values are the wavelengths) and ``gsd``.
    - device (torch.device): The device to which the data should be transferred.

    Returns:
    - tuple: cube, timestep, latlon, waves and gsd tensors on ``device``.
    """
    cube, timestep, latlon = sensor["pixels"], sensor["time"], sensor["latlon"]
    platform_meta = metadata[sensor["platform"][0]]

    # Wavelengths and ground sampling distance come from the platform metadata.
    waves = torch.tensor(list(platform_meta.bands.wavelength.values()))
    gsd = torch.tensor([platform_meta.gsd])

    inputs = (cube, timestep, latlon, waves, gsd)
    return tuple(tensor.to(device) for tensor in inputs)
56+
57+
58+
def main():
    """
    Run the exported Clay encoder on six training batches and print CLS shapes.

    Pulls six batches from the iterator returned by ``load_batch``, feeds each
    one through ``clay_encoder`` under ``torch.no_grad`` and prints the
    batch's platform name together with the shape of its CLS embedding.

    Raises:
        StopIteration: If the dataloader yields fewer than six batches.
    """
    dl, metadata = load_batch()

    # Use CUDA when available and fall back to the CPU otherwise, instead of
    # hard-coding "cuda", so the script does not crash on hosts without a GPU.
    # NOTE(review): the exported program may be specialised to the device it
    # was exported on - confirm it matches the device selected here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fetch six batches up front. Nothing in this function enforces which
    # platform each batch belongs to, so they are not named after platforms;
    # the actual platform is read from each batch when printing.
    samples = [next(dl) for _ in range(6)]

    # Perform inference with the Clay encoder model
    with torch.no_grad():
        for sensor in samples:
            # Build the encoder inputs and move them to the selected device
            batch = prepare_data(sensor, metadata, device)

            # The first output holds the embeddings; the rest are ignored
            patch_embeddings, *_ = clay_encoder(*batch)

            # Extract the class (CLS) embedding at sequence position 0
            cls_embedding = patch_embeddings[:, 0, :]

            # Print the platform and the shape of the CLS embedding
            print(sensor["platform"][0], cls_embedding.shape)
83+
84+
85+
if __name__ == "__main__":
86+
main()

0 commit comments

Comments
 (0)