Changes from all commits
84 commits
48748f1
Compile clay model encoder
srmsoumya Jul 15, 2024
97eb19a
Add benchmark & test files for the compiled clay encoder
srmsoumya Jul 18, 2024
eba1867
Revert changes to Encoder, don't change the API
srmsoumya Jul 24, 2024
73171dd
Add embedder to load clay encoder & save in onnx/ep format
srmsoumya Jul 24, 2024
1f2fcc9
Remove files from src, fix utils to run everything on same device
srmsoumya Jul 25, 2024
0a3ce9d
Bump torch==2.3.1 & torchvision==0.18.1, add onnx & onnxscript as dep…
srmsoumya Jul 25, 2024
37503fe
Release a few constraints on env
srmsoumya Jul 25, 2024
db8a3f2
Add notebook to show how to embed using compiled embedders
srmsoumya Jul 25, 2024
c0552bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
1e97506
Clear outputs from the notebook
srmsoumya Jul 25, 2024
0c37c58
Merge branch 'compile' of https://github.com/Clay-foundation/model in…
srmsoumya Jul 25, 2024
1803954
Add torchdata as a pip dependency
srmsoumya Jul 25, 2024
fdf4e80
Randomly pass time & latlon as zeros 20% of time
srmsoumya Aug 6, 2024
519a171
Add modis band info to metadata.yaml
srmsoumya Aug 6, 2024
a9efd2b
Add prefetch factor as an arg to DataModule
srmsoumya Aug 6, 2024
6434473
Change wavelength to millimeter scale for modis
srmsoumya Aug 6, 2024
7021ab4
change batch_first to True
srmsoumya Aug 6, 2024
5770cd6
Add transformer code from vit_pytorch as module
srmsoumya Aug 6, 2024
6bbc902
Add MRL
srmsoumya Aug 6, 2024
94f2b75
SAM as teacher, MRL, split code into modules
srmsoumya Aug 6, 2024
08df662
Remove outdated README, all in docs
srmsoumya Aug 6, 2024
c6df3e6
Fix trainer
srmsoumya Aug 7, 2024
c6151c6
Fix mean, std for s1
srmsoumya Aug 7, 2024
a8cee0f
update config for 1 node run
srmsoumya Aug 7, 2024
4dceba3
Modify Sentinel 1 from raw pixels to dB scale
srmsoumya Aug 8, 2024
fe01673
Cluster template for multi-node training
srmsoumya Aug 8, 2024
940fbab
Fix distributed DataLoader, add new env & slurm script
srmsoumya Aug 12, 2024
f12e221
Fix docs
srmsoumya Aug 12, 2024
ea4136b
Scale down reconstruction loss for MODIS, change alpha to 0.9 for rec/…
srmsoumya Aug 13, 2024
224c547
Use groups in wandb
srmsoumya Aug 21, 2024
79712a7
Add random dropping for channels
srmsoumya Aug 21, 2024
380e814
Add script to check sanity of npz files
srmsoumya Aug 21, 2024
4fe7138
Add script to split npz files of batch 128 to 32
srmsoumya Aug 21, 2024
bb6d13c
Adapt script to large model size
srmsoumya Sep 20, 2024
c638972
Modify classify, segment examples for clay v1.5
srmsoumya Nov 4, 2024
05c167e
Check non MRL loss for the model (#331)
srmsoumya Nov 4, 2024
38a3c27
Document MODIS data sampling
yellowcap Jul 28, 2024
4f464b9
intermediate commit
yellowcap Sep 26, 2024
8ecff64
intermediate
yellowcap Sep 30, 2024
e701052
Update to v1.5 and add logging
yellowcap Oct 1, 2024
d46522a
Fix for v1.5 module input
yellowcap Oct 1, 2024
c7283d7
Update adding sentinel-2
yellowcap Oct 14, 2024
c71bebf
Upgrade stacchip
yellowcap Oct 14, 2024
97404e8
Update adding sentinel-2
yellowcap Oct 14, 2024
308eb80
Update docker file
yellowcap Oct 14, 2024
c9af8e6
Update docker file
yellowcap Oct 14, 2024
8ce1515
Update docker file
yellowcap Oct 14, 2024
300d978
update dockerfile
yellowcap Oct 14, 2024
9b5b7a6
Fix import path
yellowcap Oct 15, 2024
d79f93e
Fix paths again
yellowcap Oct 15, 2024
4d8a1a5
Fix paths in docker file
yellowcap Oct 15, 2024
70030ba
Fix docker file
yellowcap Oct 15, 2024
7beb600
Fix docker file
yellowcap Oct 15, 2024
d442661
Use relative path
yellowcap Oct 15, 2024
dd0d0f9
Make checkpoint path relative
yellowcap Oct 15, 2024
40659da
Make path relative
yellowcap Oct 15, 2024
2b1a567
Make metadata dynamic
yellowcap Oct 15, 2024
3ad580c
Make metadata dynamic fix
yellowcap Oct 15, 2024
c08926e
Log device early
yellowcap Oct 15, 2024
c14f8a1
Use pip instead of conda to avoid re-install of torch
yellowcap Oct 16, 2024
9952c4e
Use pip instead of conda to avoid re-install of torch
yellowcap Oct 16, 2024
f2773ee
Use pip instead of conda to avoid re-install of torch
yellowcap Oct 16, 2024
cdf6e41
Use pip instead of conda to avoid re-install of torch
yellowcap Oct 16, 2024
bdb2ca3
Use pip instead of conda to avoid re-install of torch
yellowcap Oct 16, 2024
9146bdd
Remove patch level embeddings storing
yellowcap Oct 16, 2024
b093ec5
Flatten path to mirror naip-analytic bucket
yellowcap Oct 16, 2024
ab703ce
Add option to limit to state
yellowcap Oct 17, 2024
e902414
Remove stale files
yellowcap Nov 5, 2024
2f33b44
Update model checkpoint name
yellowcap Nov 5, 2024
d85a664
Update model checkpoint and s3 sign strategy
yellowcap Nov 5, 2024
85f0544
Adapt to rio-stac 0.10.0 and pin requirement
yellowcap Nov 5, 2024
89f1be1
Fix datetime bug for files that have date stamps in them
yellowcap Nov 5, 2024
eefae20
Add check for files that are already processed
yellowcap Nov 6, 2024
2496953
Move embeddings to CPU early
yellowcap Nov 6, 2024
097e291
Remove patch embedding extraction
yellowcap Nov 6, 2024
0a6f476
Sentinel-2 2024 run preparation
yellowcap Nov 12, 2024
6c867ee
Fix sentinel paths for output
yellowcap Nov 12, 2024
ac1434b
Pre-download S2 scene, batch pixel load and embedding generation
yellowcap Nov 22, 2024
c270a4c
Use torch tensor for normalization
yellowcap Nov 22, 2024
88ed3c0
Use custom endpoint on demand
yellowcap Nov 22, 2024
ceecb61
Check exists for Sentinel-2 process
yellowcap Nov 24, 2024
55a82a8
Make Dockerfile cacheable
yellowcap Nov 26, 2024
085f371
Update all-sentinel.py
MaceGrim Dec 21, 2024
2a48007
Update utils.py
MaceGrim Dec 21, 2024
61 changes: 61 additions & 0 deletions cluster/ml-cluster.yaml.template
@@ -0,0 +1,61 @@
Region: us-east-2

# DL AMI
Image:
  Os: ubuntu2004
  CustomAmi: <ami-id>

# FSx LUSTRE SHARED STORAGE
SharedStorage:
  - MountDir: /fsx
    Name: fsx
    StorageType: FsxLustre
    FsxLustreSettings:
      FileSystemId: <fsx-id>

# HEAD NODE
HeadNode:
  InstanceType: c5.12xlarge
  Networking:
    SubnetId: <subnet-public-id>
    SecurityGroups:
      - <sg-id> # EFA enabled SG
  Ssh:
    KeyName: <ssh-key>
  LocalStorage:
    RootVolume:
      Size: 200
  Iam:
    S3Access:
      - BucketName: <read-data-mount>
        EnableWriteAccess: false
      - BucketName: <write-data-mount>
        EnableWriteAccess: true


# SCHEDULER
Scheduling:
  Scheduler: slurm
  SlurmQueues:
    - Name: gpu-queue
      ComputeResources:
        - Name: <g-series or p-series>
          Instances:
            - InstanceType: <type>
          MinCount: 0
          MaxCount: 8
          Efa:
            Enabled: true
      Networking:
        SubnetIds:
          - <subnet-private-id>
        SecurityGroups:
          - <sg-id> # EFA enabled SG
        PlacementGroup:
          Enabled: true
      Iam:
        S3Access:
          - BucketName: <read-data-mount>
            EnableWriteAccess: false
          - BucketName: <write-data-mount>
            EnableWriteAccess: true
7 changes: 4 additions & 3 deletions configs/classify_eurosat.yaml
@@ -2,12 +2,12 @@
seed_everything: 42
data:
  metadata_path: configs/metadata.yaml
  batch_size: 256
  batch_size: 128
  num_workers: 8
model:
  num_classes: 10
  ckpt_path: checkpoints/clay-v1-base.ckpt
  lr: 1e-4
  ckpt_path: checkpoints/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt
  lr: 5e-5
  wd: 0.05
  b1: 0.9
  b2: 0.95
@@ -28,6 +28,7 @@ trainer:
      init_args:
        entity: developmentseed
        project: clay-classify
        group: v1.5-test
        log_model: false
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
45 changes: 26 additions & 19 deletions configs/config.yaml
@@ -1,56 +1,63 @@
# lightning.pytorch==2.1.2
seed_everything: 42
seed_everything: 108
data:
  data_dir: data
  size: 224
  data_dir: /fsx
  size: 256
  metadata_path: configs/metadata.yaml
  platforms:
    - landsat-c2l1
    - landsat-c2l2-sr
    - linz
    - modis
    - naip
    - sentinel-1-rtc
    - sentinel-2-l2a
  batch_size: 8
  num_workers: 8
  batch_size: 1
  num_workers: 12
model:
  model_size: base
  model_size: large
  mask_ratio: 0.75
  norm_pix_loss: True
  norm_pix_loss: False
  patch_size: 8
  shuffle: True
  metadata_path: configs/metadata.yaml
  teacher: vit_base_patch16_224.dino
  lr: 1e-5
  teacher: vit_large_patch14_reg4_dinov2.lvd142m
  dolls: [16, 32, 64, 128, 256, 768, 1024]
  doll_weights: [1, 1, 1, 1, 1, 1, 1]
  lr: 5e-6
  wd: 0.05
  b1: 0.9
  b2: 0.95
  embeddings_level: mean
trainer:
  accelerator: auto
  accelerator: gpu
  strategy: ddp
  devices: auto
  num_nodes: 1
  devices: 8
  num_nodes: 48
  precision: bf16-mixed
  log_every_n_steps: 10
  max_epochs: 200
  log_every_n_steps: 1
  max_epochs: 1000
  accumulate_grad_batches: 1
  default_root_dir: s3://clay-model-ckpt/v1.0.0/
  default_root_dir: checkpoints/v1.5.0/
  fast_dev_run: False
  num_sanity_val_steps: 0
  use_distributed_sampler: False
  limit_train_batches: 0.99
  limit_val_batches: 0.99
  logger:
    - class_path: lightning.pytorch.loggers.WandbLogger
      init_args:
        entity: developmentseed
        project: clay
        group: v1.5-nomrl-dinov2
        id: 0uy3in7l
        resume: must
        log_model: false
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: s3://clay-model-ckpt/v1.0.0/
        dirpath: checkpoints/v1.5.0/
        auto_insert_metric_name: False
        filename: mae_v1.0.0_epoch-{epoch:02d}_val-loss-{val/loss:.4f}
        filename: mae_v1.5.0_epoch-{epoch:02d}_val-loss-{val/loss:.4f}
        monitor: val/loss
        mode: min
        save_last: True
@@ -63,4 +70,4 @@ trainer:
    - class_path: src.callbacks_wandb.LogIntermediatePredictions
  plugins:
    - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO
ckpt_path: null
ckpt_path: checkpoints/v1.5.0/last.ckpt
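The new `dolls` and `doll_weights` keys configure Matryoshka-style (MRL) losses over nested prefixes of the embedding. The training code itself is not part of this diff; the sketch below only illustrates how such a weighted nested-embedding loss is commonly computed, and every name in it (`mrl_loss`, `embedding`, `target`) is hypothetical.

```python
import torch
import torch.nn.functional as F

# Values copied from the config above.
dolls = [16, 32, 64, 128, 256, 768, 1024]
doll_weights = [1, 1, 1, 1, 1, 1, 1]


def mrl_loss(embedding: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Weighted sum of losses over nested embedding prefixes (illustrative only)."""
    total = embedding.new_zeros(())
    for dim, weight in zip(dolls, doll_weights):
        # Each "doll" is scored on only the first `dim` embedding dimensions.
        total = total + weight * F.mse_loss(embedding[:, :dim], target[:, :dim])
    return total / sum(doll_weights)


# Toy example: a batch of 4 embeddings of width 1024.
emb = torch.randn(4, 1024)
tgt = torch.randn(4, 1024)
print(mrl_loss(emb, tgt))
```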
47 changes: 43 additions & 4 deletions configs/metadata.yaml
@@ -176,11 +176,50 @@ sentinel-1-rtc:
  gsd: 10
  bands:
    mean:
      vv: 0.123273
      vh: 0.027337
      vv: -12.113
      vh: -18.673
    std:
      vv: 1.492154
      vh: 0.122182
      vv: 8.314
      vh: 8.017
    wavelength:
      vv: 3.5
      vh: 4.0
modis:
  band_order:
    - sur_refl_b01
    - sur_refl_b02
    - sur_refl_b03
    - sur_refl_b04
    - sur_refl_b05
    - sur_refl_b06
    - sur_refl_b07
  rgb_indices:
    - 0
    - 3
    - 2
  gsd: 500
  bands:
    mean:
      sur_refl_b01: 1072.
      sur_refl_b02: 1624.
      sur_refl_b03: 931.
      sur_refl_b04: 1023.
      sur_refl_b05: 1599.
      sur_refl_b06: 1404.
      sur_refl_b07: 1051.
    std:
      sur_refl_b01: 1643.
      sur_refl_b02: 1878.
      sur_refl_b03: 1449.
      sur_refl_b04: 1538.
      sur_refl_b05: 1763.
      sur_refl_b06: 1618.
      sur_refl_b07: 1396.
    wavelength:
      sur_refl_b01: .645
      sur_refl_b02: .858
      sur_refl_b03: .469
      sur_refl_b04: .555
      sur_refl_b05: 1.240
      sur_refl_b06: 1.640
      sur_refl_b07: 2.130
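These per-band `mean`/`std` entries are the normalization statistics applied before patches reach the encoder (for Sentinel-1 they are now expressed in dB, i.e. `10 * log10` of the linear backscatter, rather than raw pixel values). Below is a minimal sketch of normalizing a MODIS chip from this metadata; the chip and variable names are placeholders, not the repository's actual loading code.

```python
import numpy as np
import yaml

with open("configs/metadata.yaml") as f:
    metadata = yaml.safe_load(f)

bands = metadata["modis"]["band_order"]
mean = np.array([metadata["modis"]["bands"]["mean"][b] for b in bands])
std = np.array([metadata["modis"]["bands"]["std"][b] for b in bands])

# `chip` stands in for a (bands, height, width) MODIS surface reflectance array.
chip = np.random.randint(0, 5000, size=(len(bands), 256, 256)).astype("float32")

# Standardize each band with the statistics from metadata.yaml.
normalized = (chip - mean[:, None, None]) / std[:, None, None]
print(normalized.mean(axis=(1, 2)))
```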
9 changes: 5 additions & 4 deletions configs/segment_chesapeake.yaml
@@ -6,17 +6,17 @@ data:
  val_chip_dir: data/cvpr/ny/val/chips/
  val_label_dir: data/cvpr/ny/val/labels/
  metadata_path: configs/metadata.yaml
  batch_size: 40
  batch_size: 16
  num_workers: 8
  platform: naip
model:
  num_classes: 7
  feature_maps:
    - 3
    - 5
    - 7
    - 11
  ckpt_path: checkpoints/clay-v1-base.ckpt
    - 15
    - 23
  ckpt_path: checkpoints/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-05_val-loss-0.1734.ckpt
  lr: 1e-5
  wd: 0.05
  b1: 0.9
@@ -38,6 +38,7 @@ trainer:
      init_args:
        entity: developmentseed
        project: clay-segment
        group: v1.5-test
        log_model: false
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
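The `feature_maps` indices pick which encoder transformer blocks feed the segmentation decoder; with the larger v1.5 encoder the later indices move deeper into the (deeper) block stack. The sketch below only illustrates collecting intermediate block outputs at those indices; the class and the toy encoder are hypothetical, not the repository's segmentation model.

```python
import torch
import torch.nn as nn

feature_maps = [3, 5, 15, 23]  # block indices taken from the config above


class IntermediateCollector(nn.Module):
    """Run a stack of blocks and keep the outputs of selected ones (illustrative)."""

    def __init__(self, blocks: nn.ModuleList, indices: list[int]):
        super().__init__()
        self.blocks = blocks
        self.indices = set(indices)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        collected = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in self.indices:
                collected.append(x)
        return collected


# Toy stand-in for a 24-block encoder operating on patch tokens.
blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(24)])
collector = IntermediateCollector(blocks, feature_maps)
features = collector(torch.randn(2, 16, 64))
print([f.shape for f in features])  # four feature maps for the decoder
```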
18 changes: 18 additions & 0 deletions copy_data.sh
@@ -0,0 +1,18 @@
#!/bin/bash

# Define source and destination directories
src="/fsx"
dest="data/pretrain"

# Create the destination directory if it doesn't exist
mkdir -p "$dest"

# Find all directories in the source directory
find "$src" -type d -print0 | while IFS= read -r -d '' dir; do
    # Create corresponding directory in the destination
    newdir="$dest${dir#$src}"
    mkdir -p "$newdir"

    # Copy the first 100 files from the source directory to the new directory
    find "$dir" -maxdepth 1 -type f -print0 | head -z -n 100 | xargs -0 -I{} cp {} "$newdir"
done
62 changes: 62 additions & 0 deletions docs/release-notes/data_sampling.md
@@ -112,6 +112,67 @@
We selected the latest imagery for each of the available regions
of New Zealand. The list of catalogs is in the linz processor file.

### MODIS sampling strategy

For MODIS we used the [Surface Reflectance 8-Day (500m)](https://planetarycomputer.microsoft.com/dataset/modis-09A1-061)
product. The data is distributed in SIN grid tiles, and we included all SIN
grid tiles that contain no nodata pixels. The selected SIN grid tiles are then
reprojected to EPSG:3857. This introduces some variation in the nominal
resolution, even though the original resolution in the SIN projection is
500 meters. For input to the model, we treated 500 m as a fixed resolution
for all tiles.
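As a minimal illustration of that reprojection step (not the pipeline's actual code, and with a placeholder asset URL), a SIN-grid asset can be warped to EPSG:3857 on the fly with rasterio:

```python
import rasterio
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT

# Placeholder href for a single MODIS surface reflectance asset.
href = "https://example.com/modis/sur_refl_b01.tif"

with rasterio.open(href) as src:
    # Warp the SIN-grid tile to Web Mercator on the fly; the nominal pixel
    # size varies slightly after reprojection, but 500 m is assumed as the
    # fixed input resolution for the model.
    with WarpedVRT(src, crs="EPSG:3857", resampling=Resampling.nearest) as vrt:
        data = vrt.read(1)
        print(vrt.transform, data.shape)
```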

The algorithm used to determine which tiles contain no nodata is shown in the
code block below. It selected 233 SIN grid tiles. For each of these we sampled
the first STAC search result for every month of every year from 2018 through
2023, resulting in 72 (`6 years * 12 months`) separate scenes for each of the
233 SIN grid tiles.

Script for selecting the SIN grid tiles included in the sampling (the STAC API
endpoint, collection id, and SIN grid ranges are assumed values):

```python
from multiprocessing import Pool

import numpy as np
import planetary_computer as pc
import pystac_client
import rasterio

# Assumed values: the Planetary Computer STAC API, the MODIS collection linked
# above, and the full MODIS SIN grid (v: 0-17, h: 0-35).
STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
COLLECTION = "modis-09A1-061"
SIN_VERTICAL_RANGE = range(18)
SIN_HORIZONTAL_RANGE = range(36)

SIN_GRID_TILES = []
for i in SIN_VERTICAL_RANGE:
    for j in SIN_HORIZONTAL_RANGE:
        SIN_GRID_TILES.append((i, j))


def evaluate_nodata(i, j):
    # Fetch a single item for the SIN grid tile (v=i, h=j).
    catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace)
    items = catalog.search(
        collections=[COLLECTION],
        query={
            "modis:vertical-tile": {
                "eq": i,
            },
            "modis:horizontal-tile": {
                "eq": j,
            },
        },
        max_items=1,
    )
    item = list(items.item_collection())[0]

    with rasterio.open(item.assets["sur_refl_b01"].href) as src:
        data = src.read()

    # Count nodata pixels (-28672 is the nodata value of the product).
    nodata = np.sum(data == -28672)

    if nodata == 0:
        print(i, j)
        return i, j


if __name__ == '__main__':
    with Pool(16) as p:
        indexes = p.starmap(evaluate_nodata, SIN_GRID_TILES)
        print("done")
        print(indexes)
```
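The per-month scene selection described above is not included in the script; the following is a minimal sketch of how it could look, reusing the same STAC API and collection and querying a one-month datetime window per tile (names and structure are illustrative):

```python
from datetime import datetime

import planetary_computer as pc
import pystac_client

STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1"
COLLECTION = "modis-09A1-061"


def first_scene_per_month(i, j):
    """Yield the first STAC item for tile (v=i, h=j) for every month 2018-2023."""
    catalog = pystac_client.Client.open(STAC_API, modifier=pc.sign_inplace)
    for year in range(2018, 2024):
        for month in range(1, 13):
            # One-month window; December rolls over into January of next year.
            start = datetime(year, month, 1)
            end = datetime(year + (month == 12), month % 12 + 1, 1)
            search = catalog.search(
                collections=[COLLECTION],
                datetime=f"{start:%Y-%m-%d}/{end:%Y-%m-%d}",
                query={
                    "modis:vertical-tile": {"eq": i},
                    "modis:horizontal-tile": {"eq": j},
                },
                max_items=1,
            )
            items = list(search.item_collection())
            if items:
                yield items[0]
```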

## Data preparation

@@ -136,6 +197,7 @@ Using stacchip, we created a dataset with a size of 33.8 TB of imagery, with abo
| Landsat-c2l1 | 5827333 |
| Landsat-c2l2-sr | 5790651 |
| Sentinel-1-rtc | 16133394 |
| MODIS | 1350864 |

# Older versions

45 changes: 45 additions & 0 deletions embeddings/Dockerfile
@@ -0,0 +1,45 @@
FROM 763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-ec2

WORKDIR /model

RUN git clone -b all-of-naip https://github.com/Clay-foundation/model.git .

RUN aws s3 cp --no-sign-request s3://clay-model-ckpt/v1.5.0-no-mrl-dinov2/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt data/mae_v1.5.0_epoch-07_val-loss-0.1718.ckpt
RUN aws s3 cp --no-sign-request s3://clay-mgrs-samples/naip-manifest.txt.zip data/naip-manifest.txt.zip
RUN aws s3 cp --no-sign-request s3://clay-mgrs-samples/element84-tiles-2023.gz data/element84-tiles-2023.gz

RUN pip install \
    einops~=0.7.0 \
    fiona~=1.9.5 \
    geopandas~=0.14.1 \
    jsonargparse~=4.27.0 \
    lightning~=2.1.0 \
    matplotlib~=3.9.0 \
    planetary-computer~=1.0.0 \
    python-box~=7.1.0 \
    pyarrow~=15.0.2 \
    rasterio~=1.3.10 \
    s3fs~=2024.6.0 \
    boto3~=1.34.122 \
    botocore~=1.34.122 \
    scikit-image~=0.22.0 \
    scikit-learn~=1.4.0 \
    stackstac~=0.5.0 \
    timm~=0.9.16 \
    transformers~=4.35.2 \
    typeshed-client~=2.4.0 \
    vit-pytorch~=1.6.4 \
    zarr~=2.16.1 \
    geoarrow-pyarrow==0.1.2 \
    torchdata==0.7.1 \
    stacchip==0.1.35 \
    wandb==0.17.5 \
    rio_stac~=0.10.0

RUN git pull && git checkout ceecb6138705cb28a5f4d3f61f22b19a2f625edb

# Move file to home directory so that relative imports work
RUN cp embeddings/all-naip.py .
RUN cp embeddings/all-sentinel.py .

ENTRYPOINT ["python"]