@@ -1792,18 +1792,15 @@ def _(galileo_tif_selector, galileo_tif_custom_path, generate_embeddings_button,
17921792 title = "Loading Galileo model and generating embeddings..."
17931793 ) as _spinner :
17941794 try :
1795- import numpy as _np
17961795 import torch as _torch
17971796 from einops import rearrange as _rearrange
17981797 from sklearn .cluster import KMeans as _KMeans
17991798 from sklearn .decomposition import PCA as _PCA
1800- from tqdm import tqdm as _tqdm
18011799
18021800 from src .data .config import NORMALIZATION_DICT_FILENAME as _NORM_FILENAME
18031801 from src .data .dataset import Dataset as _Dataset
18041802 from src .data .dataset import Normalizer as _Normalizer
18051803 from src .galileo import Encoder as _Encoder
1806- from src .masking import MaskedOutput as _MaskedOutput
18071804 from src .utils import config_dir as _config_dir
18081805
18091806 _DATA_FOLDER = _Path ("data" )
@@ -1823,37 +1820,19 @@ def _(galileo_tif_selector, galileo_tif_custom_path, generate_embeddings_button,
18231820 _model = _Encoder .load_from_folder (_DATA_FOLDER / "models/nano" )
18241821 _model .eval ()
18251822
1826- # --- Generate embeddings ---
1827- _spinner .update (title = "Generating embeddings (this may take a while)..." )
1828- _device = _torch .device ("cpu" )
1829- _output_list = []
1830- _batch_count = 0
1831- for i in _tqdm (
1832- _dataset_output .in_pixel_batches (batch_size = 128 , window_size = 1 )
1833- ):
1834- _batch_count += 1
1835- _masked = _MaskedOutput .from_datasetoutput (i , device = _device )
1836- with _torch .no_grad ():
1837- _model_out = _model (
1838- _masked .space_time_x .float (),
1839- _masked .space_x .float (),
1840- _masked .time_x .float (),
1841- _masked .static_x .float (),
1842- _masked .space_time_mask ,
1843- _masked .space_mask ,
1844- _torch .ones_like (_masked .time_mask ),
1845- _torch .ones_like (_masked .static_mask ),
1846- _masked .months .long (),
1847- patch_size = 1 ,
1848- )
1849- _output_list .append (
1850- _model .average_tokens (* _model_out [:- 1 ]).cpu ().numpy ()
1851- )
1823+ # --- Generate embeddings (memory-efficient via memmap) ---
1824+ _spinner .update (title = "Generating embeddings (memory-efficient memmap)..." )
1825+ from src .inference import make_embeddings as _make_embeddings
18521826
1853- _all = _np .concatenate (_output_list , axis = 0 )
1854- _h_b = _dataset_output .space_time_x .shape [0 ]
1855- _w_b = _dataset_output .space_time_x .shape [1 ]
1856- embeddings_arr = _rearrange (_all , "(h w) d -> h w d" , h = _h_b , w = _w_b )
1827+ _device = _torch .device ("cpu" )
1828+ embeddings_arr = _make_embeddings (
1829+ model = _model ,
1830+ datasetoutput = _dataset_output ,
1831+ window_size = 1 ,
1832+ patch_size = 1 ,
1833+ batch_size = 128 ,
1834+ device = _device ,
1835+ )
18571836 embeddings_flat_arr = _rearrange (embeddings_arr , "h w d -> (h w) d" )
18581837
18591838 # --- K-means ---
0 commit comments