"""Export the Clay model to ONNX and PyTorch ExportedProgram formats.

This script exports the Clay model to ONNX and PyTorch ExportedProgram
format for deployment. The model is exported with dynamic shapes for
inference.

How to use:

```bash
python -m finetune.embedder.factory \
    --img_size 256 \
    --ckpt_path checkpoints/clay-v1-base.ckpt \
    --device cuda \
    --name clay-v1-encoder.onnx \
    --onnx
# exports the Clay encoder to ONNX format that can handle chips of size
# 256x256 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ &
# Sentinel-1.
```
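
To run the exported ONNX model, a minimal sketch using `onnxruntime`
(assumed installed; the dynamo exporter auto-generates the graph input
names, so inspect them before binding inputs):

```python
import onnxruntime as ort

session = ort.InferenceSession("checkpoints/compiled/clay-v1-encoder.onnx")
# List the auto-generated input names & shapes before binding real data
print([(i.name, i.shape) for i in session.get_inputs()])
# embeddings = session.run(None, {name: numpy_array, ...})
```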

```bash
python -m finetune.embedder.factory \
    --img_size 224 \
    --ckpt_path checkpoints/clay-v1-base.ckpt \
    --device cuda \
    --name clay-v1-encoder.pt2 \
    --ep
# exports the Clay encoder to PyTorch ExportedProgram format that can handle
# chips of size 224x224 for different sensors like Sentinel-2, Landsat-8,
# NAIP, LINZ & Sentinel-1.
```
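
To load the saved program back for inference, a minimal sketch using
`torch.export.load` (the datacube shapes follow the dummy datacube used
for export; the batch size below respects the exported minimum of 2):

```python
import torch

ep = torch.export.load("checkpoints/compiled/clay-v1-encoder.pt2")
datacube = {
    "pixels": torch.randn(4, 3, 224, 224),  # [B C H W]
    "time": torch.randn(4, 4),  # [B 4]
    "latlon": torch.randn(4, 4),  # [B 4]
    "waves": torch.randn(3),  # [C]
    "gsd": torch.randn(1),  # [1]
}
embeddings = ep.module()(datacube)  # [B 768] cls-token embeddings
```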

"""

import argparse
import re
import warnings
from pathlib import Path

import torch
from einops import repeat
from torch import nn
from torch.export import Dim

from src.model import Encoder
from src.utils import posemb_sincos_2d_with_gsd

warnings.filterwarnings("ignore", category=UserWarning)


class EmbeddingEncoder(Encoder):
    """Clay Encoder without mask and shuffle."""

    def __init__(  # noqa: PLR0913
        self,
        img_size,
        patch_size,
        dim,
        depth,
        heads,
        dim_head,
        mlp_ratio,
    ):
        super().__init__(
            mask_ratio=0.0,
            shuffle=False,
            patch_size=patch_size,
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            mlp_ratio=mlp_ratio,
        )
        self.img_size = img_size

        # Using fixed grid size for inference
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size**2

    def add_encodings(self, patches, time, latlon, gsd):
        """Add position and metadata encodings to the patches."""
        B, L, D = patches.shape

        grid_size = self.grid_size

        pos_encoding = (
            posemb_sincos_2d_with_gsd(
                h=grid_size,
                w=grid_size,
                dim=(self.dim - 8),
                gsd=gsd,
            )
            .to(patches.device)
            .detach()
        )  # [L (D - 8)]

        time_latlon = torch.hstack((time, latlon)).to(patches.device).detach()  # [B 8]

        pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B)  # [B L (D - 8)]
        time_latlon = repeat(time_latlon, "B D -> B L D", L=L)  # [B L 8]
        pos_metadata_encoding = torch.cat(
            (pos_encoding, time_latlon), dim=-1
        )  # [B L D]

        patches = patches + pos_metadata_encoding  # [B L D] + [B L D] -> [B L D]
        return patches  # [B L D]

    def forward(self, datacube):
        cube, time, latlon, gsd, waves = (
            datacube["pixels"],  # [B C H W]
            datacube["time"],  # [B 4]
            datacube["latlon"],  # [B 4]
            datacube["gsd"],  # [1]
            datacube["waves"],  # [C]
        )
        B, C, H, W = cube.shape

        patches, _ = self.to_patch_embed(
            cube, waves
        )  # [B L D] - patchify & create embeddings per patch

        # Add time, latlon & gsd as encodings to the patches
        patches = self.add_encodings(
            patches,
            time,
            latlon,
            gsd,
        )  # [B L D] - add position encoding to the embeddings

        # Add class tokens
        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]

        # Pass the patches through the transformer
        patches = self.transformer(patches)  # [B (1 + L) D]

        # Return the embedding of the cls token
        embeddings = patches[:, 0, :]  # [B D]

        return embeddings


class Embedder(nn.Module):
    def __init__(self, img_size=256, ckpt_path=None, device="cpu"):
        super().__init__()
        self.clay_encoder = (
            EmbeddingEncoder(  # Default parameters for the Clay base model
                img_size=img_size,
                patch_size=8,
                dim=768,
                depth=12,
                heads=12,
                dim_head=64,
                mlp_ratio=4.0,
            ).to(device)
        )
        self.img_size = img_size
        self.device = torch.device(device)
        self.load_clay_weights(ckpt_path)

    def load_clay_weights(self, ckpt_path):
        "Load the weights from the Clay model encoder."
        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt.get("state_dict")
        state_dict = {
            re.sub(r"^model\.encoder\.", "", name): param
            for name, param in state_dict.items()
            if name.startswith("model.encoder")
        }

        with torch.no_grad():
            for name, param in self.clay_encoder.named_parameters():
                if name in state_dict and param.size() == state_dict[name].size():
                    param.data.copy_(state_dict[name])  # Copy the weights
                else:
                    print(f"No matching parameter for {name} with size {param.size()}")

        for param in self.clay_encoder.parameters():
            param.requires_grad = False

        self.clay_encoder.eval()

    def forward(self, datacube):
        embeddings = self.clay_encoder(datacube)

        return embeddings

    def fake_datacube(self):
        "Generate a fake datacube for model export."
        dummy_datacube = {
            "pixels": torch.randn(2, 3, self.img_size, self.img_size),  # [B C H W]
            "time": torch.randn(2, 4),  # [B 4]
            "latlon": torch.randn(2, 4),  # [B 4]
            "waves": torch.randn(3),  # [C] - one wavelength per band
            "gsd": torch.randn(1),  # [1]
        }
        dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()}
        return dummy_datacube

    def export_to_onnx(self, name):
        "Save the model to ONNX format."

        datacube = self.fake_datacube()
        export_options = torch.onnx.ExportOptions(dynamic_shapes=True)

        # Export the model to ONNX format
        onnx_program = torch.onnx.dynamo_export(
            self.eval(), datacube, export_options=export_options
        )

        # Save the exported model
        onnx_program.save(f"checkpoints/compiled/{name}")
        print(f"Model exported to ONNX format: checkpoints/compiled/{name}")

        return onnx_program

    def export_to_torchep(self, name):
        "Save the model to PyTorch ExportedProgram format."

        datacube = self.fake_datacube()

        # Dynamic shapes for model export: batch size & number of bands may
        # vary at inference time, while the spatial size stays fixed
        batch_size = Dim("batch_size", min=2, max=1000)
        channel_bands = Dim("channel_bands", min=1, max=10)
        dynamic_shapes = {
            "datacube": {
                "pixels": {0: batch_size, 1: channel_bands},
                "time": {0: batch_size},
                "latlon": {0: batch_size},
                "waves": {0: channel_bands},
                "gsd": {0: None},
            }
        }

        # Export the model to PyTorch ExportedProgram format
        ep = torch.export.export(
            self.eval(),
            (datacube,),
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )

        # Save the exported model
        torch.export.save(ep, f"checkpoints/compiled/{name}")
        print(
            f"Model exported to PyTorch ExportedProgram format: checkpoints/compiled/{name}"  # noqa: E501
        )

        return ep


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export the Clay model.")
    parser.add_argument(
        "--img_size",
        type=int,
        default=256,
        help="Image size for the model",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="checkpoints/clay-v1-base.ckpt",
        help="Path to the Clay model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for the model",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="clay-base.pt",
        help="Name of the exported model",
    )
    parser.add_argument(
        "--onnx",
        action="store_true",
        help="Export the model to ONNX format",
    )
    parser.add_argument(
        "--ep",
        action="store_true",
        help="Export the model to PyTorch ExportedProgram format",
    )

    args = parser.parse_args()

    Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True)
    embedder = Embedder(
        img_size=args.img_size,
        ckpt_path=args.ckpt_path,
        device=args.device,
    )

    if args.onnx:
        embedder.export_to_onnx(args.name)
    elif args.ep:
        embedder.export_to_torchep(args.name)
    else:
        print("Please specify the format to export the model.")
        parser.print_help()