Skip to content

Commit 6e6e533

Browse files
authored
Merge pull request #35 from ucam-eo/fetch_embeddings_lazy
Fetch embeddings lazily and add an embeddings count
2 parents 0b5411c + f8d70b4 commit 6e6e533

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

geotessera/core.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from pathlib import Path
9-
from typing import Union, List, Tuple, Optional, Dict
9+
from typing import Union, List, Tuple, Optional, Dict, Generator
1010
import json
1111
import numpy as np
1212

@@ -59,21 +59,36 @@ def version(self) -> str:
5959
"""Get the GeoTessera library version."""
6060
return __version__
6161

62+
def embeddings_count(self, bbox: Tuple[float, float, float, float], year: int = 2024) -> int:
63+
"""Get the total number of embedding tiles within a bounding box.
64+
65+
Args:
66+
bbox: Bounding box as (min_lon, min_lat, max_lon, max_lat)
67+
year: Year of embeddings to consider
68+
69+
Returns:
70+
Total number of tiles in the bounding box
71+
"""
72+
tiles = self.registry.load_blocks_for_region(bbox, year)
73+
return len(tiles)
74+
75+
# Returns a generator: tiles are yielded one at a time rather than collected into a list.
6276
def fetch_embeddings(
6377
self,
6478
bbox: Tuple[float, float, float, float],
6579
year: int = 2024,
6680
progress_callback: Optional[callable] = None,
67-
) -> List[Tuple[float, float, np.ndarray, object, object]]:
68-
"""Fetch all embedding tiles within a bounding box with CRS information.
81+
) -> Generator[Tuple[float, float, np.ndarray, object, object], None, None]:
82+
"""Lazily fetch all embedding tiles within a bounding box with CRS information.
83+
Use as a generator to process tiles one at a time in a memory-efficient manner.
6984
7085
Args:
7186
bbox: Bounding box as (min_lon, min_lat, max_lon, max_lat)
7287
year: Year of embeddings to download
7388
progress_callback: Optional callback function(current, total) for progress tracking
7489
7590
Returns:
76-
List of (tile_lon, tile_lat, embedding_array, crs, transform) tuples where:
91+
Generator of (tile_lon, tile_lat, embedding_array, crs, transform) tuples where:
7792
- tile_lon: Tile center longitude
7893
- tile_lat: Tile center latitude
7994
- embedding_array: shape (H, W, 128) with dequantized values
@@ -84,7 +99,6 @@ def fetch_embeddings(
8499
tiles_to_download = self.registry.load_blocks_for_region(bbox, year)
85100

86101
# Download each tile with progress tracking
87-
results = []
88102
total_tiles = len(tiles_to_download)
89103

90104
for i, (tile_lon, tile_lat) in enumerate(tiles_to_download):
@@ -108,7 +122,8 @@ def tile_progress_callback(
108122
embedding, crs, transform = self.fetch_embedding(
109123
tile_lon, tile_lat, year, tile_progress_callback
110124
)
111-
results.append((tile_lon, tile_lat, embedding, crs, transform))
125+
126+
yield tile_lon, tile_lat, embedding, crs, transform
112127

113128
# Update progress for completed tile
114129
if progress_callback:
@@ -130,7 +145,7 @@ def tile_progress_callback(
130145
)
131146
continue
132147

133-
return results
148+
return None
134149

135150
def fetch_embedding(
136151
self,
@@ -601,7 +616,7 @@ def merge_geotiffs_to_mosaic(
601616
if progress_callback:
602617
progress_callback(i, total_files * 2 + 2, f"Reprojecting file {i+1}/{total_files}...")
603618

604-
result_file, error = self._reproject_geotiff_file(args)
619+
_, error = self._reproject_geotiff_file(args)
605620

606621
if error:
607622
failed_files.append((geotiff_paths[i], error))

0 commit comments

Comments
 (0)