Skip to content

Commit 23a6f27

Browse files
committed
Add embedder to load clay encoder & save in onnx/ep format
1 parent fbac3cd commit 23a6f27

File tree

1 file changed

+303
-0
lines changed

1 file changed

+303
-0
lines changed

finetune/embedder/factory.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""Export the Clay model to ONNX and pytorch ExportedProgram format.
2+
3+
This script exports the Clay model to ONNX and pytorch ExportedProgram format
4+
for deployment. The model is exported with dynamic shapes for inference.
5+
6+
How to use:
7+
8+
```bash
9+
python -m finetune.embedder.factory \
10+
--img_size 256 \
11+
--ckpt_path checkpoints/clay-v1-base.ckpt \
12+
--device cuda \
13+
--name clay-v1-encoder.onnx \
14+
--onnx
15+
# exports Clay encoder to ONNX format that can handle chips of size 256x256
16+
# for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ & Sentinel-1.
17+
```
18+
19+
```bash
20+
python -m finetune.embedder.factory \
21+
--img_size 224 \
22+
--ckpt_path checkpoints/clay-v1-base.ckpt \
23+
--device cuda \
24+
--name clay-v1-encoder.pt2 \
25+
--ep
26+
# exports Clay encoder to pytorch ExportedProgram format that can handle chips
27+
# of size 224x224 for different sensors like Sentinel-2, Landsat-8, NAIP, LINZ
28+
# & Sentinel-1.
29+
```
30+
31+
"""
32+
33+
import argparse
34+
import re
35+
import warnings
36+
from pathlib import Path
37+
38+
import torch
39+
from einops import repeat
40+
from torch import nn
41+
from torch.export import Dim
42+
43+
from src.model import Encoder
44+
from src.utils import posemb_sincos_2d_with_gsd
45+
46+
warnings.filterwarnings("ignore", category=UserWarning)
47+
48+
49+
class EmbeddingEncoder(Encoder):
    """Clay Encoder without mask and shuffle.

    Configures the parent ``Encoder`` with ``mask_ratio=0.0`` and
    ``shuffle=False`` and uses a fixed patch grid derived from ``img_size``.
    ``forward`` returns only the class-token embedding of each batch item.
    """

    def __init__(  # noqa: PLR0913
        self,
        img_size,
        patch_size,
        dim,
        depth,
        heads,
        dim_head,
        mlp_ratio,
    ):
        """Build the encoder with a fixed position-encoding grid.

        Args:
            img_size: Chip side length in pixels; together with ``patch_size``
                it fixes the grid used for the position encoding.
            patch_size: Patch side length in pixels; also passed to
                ``Encoder``.
            dim: Passed through to ``Encoder``.
            depth: Passed through to ``Encoder``.
            heads: Passed through to ``Encoder``.
            dim_head: Passed through to ``Encoder``.
            mlp_ratio: Passed through to ``Encoder``.
        """
        super().__init__(
            mask_ratio=0.0,
            shuffle=False,
            patch_size=patch_size,
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            mlp_ratio=mlp_ratio,
        )
        self.img_size = img_size

        # Using fixed grid size for inference
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size**2

    def add_encodings(self, patches, time, latlon, gsd):
        """Add position and time/latlon encodings to the patch embeddings.

        Args:
            patches: Patch embeddings, [B L D].
            time: Time encoding, one row per batch item.
            latlon: Lat/lon encoding, one row per batch item.
            gsd: Ground sample distance forwarded to
                ``posemb_sincos_2d_with_gsd``.

        Returns:
            ``patches`` with the combined encoding added, [B L D].
        """
        B, L, _ = patches.shape

        grid_size = self.grid_size

        # self.dim is not assigned in this class; presumably set by Encoder.
        pos_encoding = (
            posemb_sincos_2d_with_gsd(
                h=grid_size,
                w=grid_size,
                dim=(self.dim - 8),
                gsd=gsd,
            )
            .to(patches.device)
            .detach()
        )  # [L (D - 8)]

        # The position encoding leaves 8 feature slots, so time and latlon
        # stacked side by side must be 8 wide for the sum below to line up.
        time_latlon = torch.hstack((time, latlon)).to(patches.device).detach()  # [B 8]

        pos_encoding = repeat(pos_encoding, "L D -> B L D", B=B)  # [B L (D - 8)]
        time_latlon = repeat(time_latlon, "B D -> B L D", L=L)  # [B L 8]
        pos_metadata_encoding = torch.cat(
            (pos_encoding, time_latlon), dim=-1
        )  # [B L D]

        patches = patches + pos_metadata_encoding  # [B L D] + [B L D] -> [B L D]
        return patches  # [B L D]

    def forward(self, datacube):
        """Encode a datacube into one embedding per batch item.

        Args:
            datacube: Mapping with the keys ``"pixels"``, ``"time"``,
                ``"latlon"``, ``"gsd"`` and ``"waves"``.

        Returns:
            The class-token output of the transformer, [B D].
        """
        cube, time, latlon, gsd, waves = (
            datacube["pixels"],  # [B C H W]
            datacube["time"],  # [B 4] in the export dummy input
            datacube["latlon"],  # [B 4] in the export dummy input
            datacube["gsd"],  # 1
            datacube["waves"],  # [N]
        )
        B = cube.shape[0]

        patches, _ = self.to_patch_embed(
            cube, waves
        )  # [B L D] - patchify & create embeddings per patch

        # Add time & latlon as encoding to patches
        patches = self.add_encodings(
            patches,
            time,
            latlon,
            gsd,
        )  # [B L D] - add position encoding to the embeddings

        # Add class tokens
        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]

        # pass the patches through the transformer
        patches = self.transformer(patches)  # [B (1 + L) D]

        # get the cls token
        embeddings = patches[:, 0, :]  # [B D]

        return embeddings
140+
141+
142+
class Embedder(nn.Module):
143+
def __init__(self, img_size=256, ckpt_path=None, device="cpu"):
144+
super().__init__()
145+
self.clay_encoder = (
146+
EmbeddingEncoder( # Default parameters for the Clay base model
147+
img_size=img_size,
148+
patch_size=8,
149+
dim=768,
150+
depth=12,
151+
heads=12,
152+
dim_head=64,
153+
mlp_ratio=4.0,
154+
).to(device)
155+
)
156+
self.img_size = img_size
157+
self.device = torch.device(device)
158+
self.load_clay_weights(ckpt_path)
159+
160+
def load_clay_weights(self, ckpt_path):
161+
"Load the weights from the Clay model encoder."
162+
ckpt = torch.load(ckpt_path, map_location=self.device)
163+
state_dict = ckpt.get("state_dict")
164+
state_dict = {
165+
re.sub(r"^model\.encoder\.", "", name): param
166+
for name, param in state_dict.items()
167+
if name.startswith("model.encoder")
168+
}
169+
170+
with torch.no_grad():
171+
for name, param in self.clay_encoder.named_parameters():
172+
if name in state_dict and param.size() == state_dict[name].size():
173+
param.data.copy_(state_dict[name]) # Copy the weights
174+
else:
175+
print(f"No matching parameter for {name} with size {param.size()}")
176+
177+
for param in self.clay_encoder.parameters():
178+
param.requires_grad = False
179+
180+
self.clay_encoder.eval()
181+
182+
def forward(self, datacube):
183+
embeddings = self.clay_encoder(datacube)
184+
185+
return embeddings
186+
187+
def fake_datacube(self):
188+
"Generate a fake datacube for model export."
189+
dummy_datacube = {
190+
"pixels": torch.randn(2, 3, self.img_size, self.img_size),
191+
"time": torch.randn(2, 4),
192+
"latlon": torch.randn(2, 4),
193+
"waves": torch.randn(3),
194+
"gsd": torch.randn(1),
195+
}
196+
dummy_datacube = {k: v.to(self.device) for k, v in dummy_datacube.items()}
197+
return dummy_datacube
198+
199+
def export_to_onnx(self, name):
200+
"Save the model to ONNX format."
201+
202+
datacube = self.fake_datacube()
203+
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
204+
205+
# Export the model to ONNX format
206+
onnx_program = torch.onnx.dynamo_export(
207+
self.eval(), datacube, export_options=export_options
208+
)
209+
210+
# Save the exported model
211+
onnx_program.save(f"checkpoints/compiled/{name}")
212+
print(f"Model exported to ONNX format: checkpoints/compiled/{name}")
213+
214+
return onnx_program
215+
216+
def export_to_torchep(self, name):
217+
"Save the model to pytorch ExportedProgram format."
218+
219+
datacube = self.fake_datacube()
220+
221+
# dynamic shapes for model export
222+
batch_size = Dim("batch_size", min=2, max=1000)
223+
channel_bands = Dim("channel_bands", min=1, max=10)
224+
dynamic_shapes = {
225+
"datacube": {
226+
"pixels": {0: batch_size, 1: channel_bands},
227+
"time": {0: batch_size},
228+
"latlon": {0: batch_size},
229+
"waves": {0: channel_bands},
230+
"gsd": {0: None},
231+
}
232+
}
233+
234+
# Export the model to pytorch ExportedProgram format
235+
ep = torch.export.export(
236+
self.eval(),
237+
(datacube,),
238+
dynamic_shapes=dynamic_shapes,
239+
strict=True,
240+
)
241+
242+
# Save the exported model
243+
torch.export.save(ep, f"checkpoints/compiled/{name}")
244+
print(
245+
f"Model exported to pytorch ExportedProgram format: checkpoints/compiled/{name}" # noqa: E501
246+
)
247+
248+
return ep
249+
250+
251+
def build_parser():
    """Create the command-line parser for the export script."""
    parser = argparse.ArgumentParser(description="Export the Clay model.")
    parser.add_argument(
        "--img_size",
        type=int,
        default=256,
        help="Image size for the model",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="checkpoints/clay-v1-base.ckpt",
        help="Path to the Clay model checkpoint",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for the model",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="clay-base.pt",
        help="Name of the exported model",
    )

    # Both formats are written under the same --name, so only one of them can
    # be requested per run; passing both used to silently ignore --ep.
    export_format = parser.add_mutually_exclusive_group()
    export_format.add_argument(
        "--onnx",
        action="store_true",
        help="Export the model to ONNX format",
    )
    export_format.add_argument(
        "--ep",
        action="store_true",
        help="Export the model to pytorch ExportedProgram format",
    )
    return parser


def main(argv=None):
    """Parse the command line and export the Clay encoder.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Check the requested format before loading the checkpoint, so a missing
    # flag is reported immediately instead of after building the model.
    if not (args.onnx or args.ep):
        print("Please specify the format to export the model.")
        parser.print_help()
        return

    Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True)
    embedder = Embedder(
        img_size=args.img_size,
        ckpt_path=args.ckpt_path,
        device=args.device,
    )

    if args.onnx:
        embedder.export_to_onnx(args.name)
    else:
        embedder.export_to_torchep(args.name)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)