Skip to content

Commit 5f4aee5

Browse files
committed
enable new embeddings for the marimo copernicus
1 parent 318ab3c commit 5f4aee5

1 file changed

Lines changed: 12 additions & 33 deletions

File tree

copernicus_marimo.py

Lines changed: 12 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1792,18 +1792,15 @@ def _(galileo_tif_selector, galileo_tif_custom_path, generate_embeddings_button,
17921792
title="Loading Galileo model and generating embeddings..."
17931793
) as _spinner:
17941794
try:
1795-
import numpy as _np
17961795
import torch as _torch
17971796
from einops import rearrange as _rearrange
17981797
from sklearn.cluster import KMeans as _KMeans
17991798
from sklearn.decomposition import PCA as _PCA
1800-
from tqdm import tqdm as _tqdm
18011799

18021800
from src.data.config import NORMALIZATION_DICT_FILENAME as _NORM_FILENAME
18031801
from src.data.dataset import Dataset as _Dataset
18041802
from src.data.dataset import Normalizer as _Normalizer
18051803
from src.galileo import Encoder as _Encoder
1806-
from src.masking import MaskedOutput as _MaskedOutput
18071804
from src.utils import config_dir as _config_dir
18081805

18091806
_DATA_FOLDER = _Path("data")
@@ -1823,37 +1820,19 @@ def _(galileo_tif_selector, galileo_tif_custom_path, generate_embeddings_button,
18231820
_model = _Encoder.load_from_folder(_DATA_FOLDER / "models/nano")
18241821
_model.eval()
18251822

1826-
# --- Generate embeddings ---
1827-
_spinner.update(title="Generating embeddings (this may take a while)...")
1828-
_device = _torch.device("cpu")
1829-
_output_list = []
1830-
_batch_count = 0
1831-
for i in _tqdm(
1832-
_dataset_output.in_pixel_batches(batch_size=128, window_size=1)
1833-
):
1834-
_batch_count += 1
1835-
_masked = _MaskedOutput.from_datasetoutput(i, device=_device)
1836-
with _torch.no_grad():
1837-
_model_out = _model(
1838-
_masked.space_time_x.float(),
1839-
_masked.space_x.float(),
1840-
_masked.time_x.float(),
1841-
_masked.static_x.float(),
1842-
_masked.space_time_mask,
1843-
_masked.space_mask,
1844-
_torch.ones_like(_masked.time_mask),
1845-
_torch.ones_like(_masked.static_mask),
1846-
_masked.months.long(),
1847-
patch_size=1,
1848-
)
1849-
_output_list.append(
1850-
_model.average_tokens(*_model_out[:-1]).cpu().numpy()
1851-
)
1823+
# --- Generate embeddings (memory-efficient via memmap) ---
1824+
_spinner.update(title="Generating embeddings (memory-efficient memmap)...")
1825+
from src.inference import make_embeddings as _make_embeddings
18521826

1853-
_all = _np.concatenate(_output_list, axis=0)
1854-
_h_b = _dataset_output.space_time_x.shape[0]
1855-
_w_b = _dataset_output.space_time_x.shape[1]
1856-
embeddings_arr = _rearrange(_all, "(h w) d -> h w d", h=_h_b, w=_w_b)
1827+
_device = _torch.device("cpu")
1828+
embeddings_arr = _make_embeddings(
1829+
model=_model,
1830+
datasetoutput=_dataset_output,
1831+
window_size=1,
1832+
patch_size=1,
1833+
batch_size=128,
1834+
device=_device,
1835+
)
18571836
embeddings_flat_arr = _rearrange(embeddings_arr, "h w d -> (h w) d")
18581837

18591838
# --- K-means ---

0 commit comments

Comments (0)