Skip to content

Commit a704562

Browse files
committed
replace tile/landmask dict caches with direct MultiIndex lookups
Remove redundant _tile_index and _landmask_index Python dicts that were rebuilt on every load by iterating all rows. Instead, set a pandas MultiIndex on the existing lon_i/lat_i integer columns already stored in the parquet files, giving the same O(1) lookup via .loc[]. Add _lookup_tile() and _lookup_landmask() helpers to centralise the coord-to-grid-int conversion and index access pattern. fixes #175
1 parent f867855 commit a704562

File tree

1 file changed

+119
-102
lines changed

1 file changed

+119
-102
lines changed

geotessera/registry.py

Lines changed: 119 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -804,14 +804,19 @@ def _load_registry(self):
804804
missing = required_columns - set(self._registry_gdf.columns)
805805
raise ValueError(f"Registry is missing required columns: {missing}")
806806

807-
# Build dictionary index for O(1) lookups - avoids non-deterministic
808-
# pandas DataFrame filtering issues on some platforms (Windows CI)
809-
self._tile_index: Dict[Tuple[int, int, int], int] = {}
810-
for idx, row in enumerate(self._registry_gdf.itertuples()):
811-
lon_i = int(coord_to_grid_int(row.lon))
812-
lat_i = int(coord_to_grid_int(row.lat))
813-
key = (int(row.year), lon_i, lat_i)
814-
self._tile_index[key] = idx
807+
# Ensure lon_i/lat_i columns exist (backwards compat with old parquet files)
808+
if "lon_i" not in self._registry_gdf.columns:
809+
self._registry_gdf["lon_i"] = (
810+
self._registry_gdf["lon"] * 100
811+
).round().astype(np.int32)
812+
if "lat_i" not in self._registry_gdf.columns:
813+
self._registry_gdf["lat_i"] = (
814+
self._registry_gdf["lat"] * 100
815+
).round().astype(np.int32)
816+
817+
# Set MultiIndex for O(1) lookups via .loc[(year, lon_i, lat_i)]
818+
self._registry_gdf["year"] = self._registry_gdf["year"].astype(int)
819+
self._registry_gdf = self._registry_gdf.set_index(["year", "lon_i", "lat_i"])
815820

816821
def _load_landmasks_registry(self):
817822
"""Load landmasks Parquet registry from local path or download from remote with If-Modified-Since refresh."""
@@ -887,14 +892,63 @@ def _load_landmasks_registry(self):
887892
)
888893
self._landmasks_df = None
889894
else:
890-
# Build dictionary index for O(1) lookups - avoids non-deterministic
891-
# pandas DataFrame filtering issues on some platforms (Windows CI)
892-
self._landmask_index: Dict[Tuple[int, int], int] = {}
893-
for idx, row in enumerate(self._landmasks_df.itertuples()):
894-
lon_i = int(coord_to_grid_int(row.lon))
895-
lat_i = int(coord_to_grid_int(row.lat))
896-
key = (lon_i, lat_i)
897-
self._landmask_index[key] = idx
895+
# Ensure lon_i/lat_i columns exist (backwards compat with old parquet files)
896+
if "lon_i" not in self._landmasks_df.columns:
897+
self._landmasks_df["lon_i"] = (
898+
self._landmasks_df["lon"] * 100
899+
).round().astype(np.int32)
900+
if "lat_i" not in self._landmasks_df.columns:
901+
self._landmasks_df["lat_i"] = (
902+
self._landmasks_df["lat"] * 100
903+
).round().astype(np.int32)
904+
905+
# Set index for O(1) lookups via .loc[(lon_i, lat_i)]
906+
self._landmasks_df = self._landmasks_df.set_index(["lon_i", "lat_i"])
907+
908+
def _lookup_tile(self, year: int, lon: float, lat: float) -> pd.Series:
909+
"""Look up a tile row by year and coordinates.
910+
911+
Args:
912+
year: Year of the tile
913+
lon: Longitude of the tile center
914+
lat: Latitude of the tile center
915+
916+
Returns:
917+
pd.Series with the tile's registry data
918+
919+
Raises:
920+
ValueError: If tile not found in registry
921+
"""
922+
lon_i = int(coord_to_grid_int(lon))
923+
lat_i = int(coord_to_grid_int(lat))
924+
try:
925+
return self._registry_gdf.loc[(int(year), lon_i, lat_i)]
926+
except KeyError:
927+
raise ValueError(
928+
f"Tile not found in registry: year={year}, lon={lon:.2f}, lat={lat:.2f}"
929+
)
930+
931+
def _lookup_landmask(self, lon: float, lat: float) -> pd.Series:
932+
"""Look up a landmask row by coordinates.
933+
934+
Args:
935+
lon: Longitude of the tile center
936+
lat: Latitude of the tile center
937+
938+
Returns:
939+
pd.Series with the landmask's registry data
940+
941+
Raises:
942+
ValueError: If landmask not found in registry
943+
"""
944+
lon_i = int(coord_to_grid_int(lon))
945+
lat_i = int(coord_to_grid_int(lat))
946+
try:
947+
return self._landmasks_df.loc[(lon_i, lat_i)]
948+
except KeyError:
949+
raise ValueError(
950+
f"Landmask not found in registry: lon={lon:.2f}, lat={lat:.2f}"
951+
)
898952

899953
def iter_tiles_in_region(
900954
self, bounds: Tuple[float, float, float, float], year: int
@@ -939,16 +993,16 @@ def iter_tiles_in_region(
939993
min_lat - expansion : max_lat + expansion,
940994
]
941995

942-
tiles = tiles[tiles["year"] == year]
996+
# Filter by year using index level
997+
tiles = tiles[tiles.index.get_level_values("year") == year]
943998

944-
# Drop duplicates and yield - compute exact grid centers to ensure
945-
# coordinates match what the dictionary index expects
946-
tiles_unique = tiles[["year", "lon", "lat"]].drop_duplicates()
947-
for year_val, lon_val, lat_val in tiles_unique.values:
948-
# Convert to grid indices and back to get exact grid centers
949-
lon_i = coord_to_grid_int(lon_val)
950-
lat_i = coord_to_grid_int(lat_val)
951-
yield (int(year_val), lon_i / 100.0, lat_i / 100.0)
999+
# Yield unique (year, lon_i, lat_i) tuples from the index
1000+
seen = set()
1001+
for idx in tiles.index:
1002+
if idx not in seen:
1003+
seen.add(idx)
1004+
year_val, lon_i, lat_i = idx
1005+
yield (year_val, lon_i / 100.0, lat_i / 100.0)
9521006

9531007
def load_blocks_for_region(
9541008
self, bounds: Tuple[float, float, float, float], year: int
@@ -978,18 +1032,23 @@ def get_available_years(self) -> List[int]:
9781032
Returns:
9791033
List of years with available data, sorted in ascending order.
9801034
"""
981-
return sorted(self._registry_gdf["year"].unique().tolist())
1035+
return sorted(self._registry_gdf.index.get_level_values("year").unique().tolist())
9821036

9831037
def get_tile_counts_by_year(self) -> Dict[int, int]:
9841038
"""Get count of tiles per year using efficient pandas operations.
9851039
9861040
Returns:
9871041
Dictionary mapping year to tile count
9881042
"""
989-
# Use pandas groupby to count unique (lon, lat) coordinates per year
1043+
# Count unique (lon_i, lat_i) index pairs per year level
1044+
idx = self._registry_gdf.index
9901045
counts = (
991-
self._registry_gdf.groupby("year")[["lon", "lat"]]
992-
.apply(lambda x: len(x.drop_duplicates()))
1046+
pd.DataFrame({"year": idx.get_level_values("year"),
1047+
"lon_i": idx.get_level_values("lon_i"),
1048+
"lat_i": idx.get_level_values("lat_i")})
1049+
.drop_duplicates()
1050+
.groupby("year")
1051+
.size()
9931052
.to_dict()
9941053
)
9951054
return {int(year): int(count) for year, count in counts.items()}
@@ -1000,18 +1059,12 @@ def get_available_embeddings(self) -> List[Tuple[int, float, float]]:
10001059
Returns:
10011060
List of (year, lon, lat) tuples for all available embedding tiles
10021061
"""
1003-
unique_tiles = self._registry_gdf[["year", "lon", "lat"]].drop_duplicates()
1004-
1005-
# Compute exact grid center coordinates to ensure round-trip consistency
1006-
lon_i = np.round(unique_tiles["lon"].values * 100).astype(np.int32)
1007-
lat_i = np.round(unique_tiles["lat"].values * 100).astype(np.int32)
1008-
return list(
1009-
zip(
1010-
unique_tiles["year"].astype(int).values,
1011-
lon_i / 100.0,
1012-
lat_i / 100.0,
1013-
)
1014-
)
1062+
# Use unique index tuples directly - already (year, lon_i, lat_i)
1063+
unique_idx = self._registry_gdf.index.unique()
1064+
return [
1065+
(int(year), lon_i / 100.0, lat_i / 100.0)
1066+
for year, lon_i, lat_i in unique_idx
1067+
]
10151068

10161069
def fetch(
10171070
self,
@@ -1057,7 +1110,7 @@ def fetch(
10571110
# Use existing local file
10581111
return str(local_path)
10591112

1060-
# Query hash from GeoDataFrame for verification if year/lon/lat provided
1113+
# Query hash from registry for verification if year/lon/lat provided
10611114
file_hash = None
10621115
if (
10631116
self.verify_hashes
@@ -1066,12 +1119,8 @@ def fetch(
10661119
and lon is not None
10671120
and lat is not None
10681121
):
1069-
lon_i = coord_to_grid_int(lon)
1070-
lat_i = coord_to_grid_int(lat)
1071-
key = (int(year), int(lon_i), int(lat_i))
1072-
if key in self._tile_index:
1073-
idx = self._tile_index[key]
1074-
row = self._registry_gdf.iloc[idx]
1122+
try:
1123+
row = self._lookup_tile(year, lon, lat)
10751124
if is_scales:
10761125
# Use scales_hash column for scales files
10771126
if "scales_hash" in row.index:
@@ -1083,6 +1132,8 @@ def fetch(
10831132
else:
10841133
# Use hash column for embedding files
10851134
file_hash = row["hash"]
1135+
except ValueError:
1136+
pass # Tile not in registry, skip hash verification
10861137

10871138
# Download to embeddings_dir
10881139
# Use as_posix() to ensure forward slashes in URL even on Windows
@@ -1135,20 +1186,19 @@ def fetch_landmask(
11351186
# Use existing local file
11361187
return str(local_path)
11371188

1138-
# Query hash from landmasks index for verification if lon/lat provided
1189+
# Query hash from landmasks registry for verification if lon/lat provided
11391190
file_hash = None
11401191
if (
11411192
self.verify_hashes
11421193
and self._landmasks_df is not None
11431194
and lon is not None
11441195
and lat is not None
11451196
):
1146-
lon_i = coord_to_grid_int(lon)
1147-
lat_i = coord_to_grid_int(lat)
1148-
key = (int(lon_i), int(lat_i))
1149-
if key in self._landmask_index:
1150-
idx = self._landmask_index[key]
1151-
file_hash = self._landmasks_df.iloc[idx]["hash"]
1197+
try:
1198+
row = self._lookup_landmask(lon, lat)
1199+
file_hash = row["hash"]
1200+
except ValueError:
1201+
pass # Landmask not in registry, skip hash verification
11521202

11531203
# Download to embeddings_dir
11541204
url = f"{TESSERA_BASE_URL}/{self.version}/{LANDMASKS_DIR_NAME}/{filename}"
@@ -1174,11 +1224,12 @@ def get_landmask_count(self) -> int:
11741224
Count of unique landmask tiles
11751225
"""
11761226
if self._landmasks_df is not None:
1177-
# Count unique (lon, lat) combinations in landmasks registry
1178-
return len(self._landmasks_df[["lon", "lat"]].drop_duplicates())
1227+
# Count unique (lon_i, lat_i) index pairs in landmasks registry
1228+
return len(self._landmasks_df.index.unique())
11791229

11801230
# Fallback: count unique tiles in embeddings registry
1181-
return len(self._registry_gdf[["lon", "lat"]].drop_duplicates())
1231+
idx = self._registry_gdf.index.droplevel("year")
1232+
return len(idx.unique())
11821233

11831234
@property
11841235
def available_landmasks(self) -> List[Tuple[float, float]]:
@@ -1190,16 +1241,12 @@ def available_landmasks(self) -> List[Tuple[float, float]]:
11901241
"""
11911242
# Use landmasks registry if available
11921243
if self._landmasks_df is not None:
1193-
unique_tiles = self._landmasks_df[["lon", "lat"]].drop_duplicates()
1194-
lon_i = np.round(unique_tiles["lon"].values * 100).astype(np.int32)
1195-
lat_i = np.round(unique_tiles["lat"].values * 100).astype(np.int32)
1196-
return list(zip(lon_i / 100.0, lat_i / 100.0))
1244+
unique_idx = self._landmasks_df.index.unique()
1245+
return [(lon_i / 100.0, lat_i / 100.0) for lon_i, lat_i in unique_idx]
11971246

11981247
# Fallback: assume landmasks are available for all embedding tiles
1199-
unique_tiles = self._registry_gdf[["lon", "lat"]].drop_duplicates()
1200-
lon_i = np.round(unique_tiles["lon"].values * 100).astype(np.int32)
1201-
lat_i = np.round(unique_tiles["lat"].values * 100).astype(np.int32)
1202-
return list(zip(lon_i / 100.0, lat_i / 100.0))
1248+
unique_idx = self._registry_gdf.index.droplevel("year").unique()
1249+
return [(lon_i / 100.0, lat_i / 100.0) for lon_i, lat_i in unique_idx]
12031250

12041251
def get_manifest_info(self) -> Tuple[Optional[str], Optional[str]]:
12051252
"""Get manifest information (git hash and repo URL).
@@ -1232,18 +1279,8 @@ def get_tile_file_size(self, year: int, lon: float, lat: float) -> int:
12321279
"Please update your registry to include file size metadata."
12331280
)
12341281

1235-
# Use dictionary index for O(1) lookup
1236-
lon_i = coord_to_grid_int(lon)
1237-
lat_i = coord_to_grid_int(lat)
1238-
key = (int(year), int(lon_i), int(lat_i))
1239-
1240-
if key not in self._tile_index:
1241-
raise ValueError(
1242-
f"Tile not found in registry: year={year}, lon={lon:.2f}, lat={lat:.2f}"
1243-
)
1244-
1245-
idx = self._tile_index[key]
1246-
return int(self._registry_gdf.iloc[idx]["file_size"])
1282+
row = self._lookup_tile(year, lon, lat)
1283+
return int(row["file_size"])
12471284

12481285
def get_scales_file_size(self, year: int, lon: float, lat: float) -> int:
12491286
"""Get the file size of a scales file from the registry.
@@ -1265,18 +1302,8 @@ def get_scales_file_size(self, year: int, lon: float, lat: float) -> int:
12651302
"Please update your registry to include scales file size metadata."
12661303
)
12671304

1268-
# Use dictionary index for O(1) lookup
1269-
lon_i = coord_to_grid_int(lon)
1270-
lat_i = coord_to_grid_int(lat)
1271-
key = (int(year), int(lon_i), int(lat_i))
1272-
1273-
if key not in self._tile_index:
1274-
raise ValueError(
1275-
f"Tile not found in registry: year={year}, lon={lon:.2f}, lat={lat:.2f}"
1276-
)
1277-
1278-
idx = self._tile_index[key]
1279-
return int(self._registry_gdf.iloc[idx]["scales_size"])
1305+
row = self._lookup_tile(year, lon, lat)
1306+
return int(row["scales_size"])
12801307

12811308
def get_landmask_file_size(self, lon: float, lat: float) -> int:
12821309
"""Get the file size of a landmask tile from the registry.
@@ -1303,18 +1330,8 @@ def get_landmask_file_size(self, lon: float, lat: float) -> int:
13031330
"Please update your landmasks registry to include file size metadata."
13041331
)
13051332

1306-
# Use dictionary index for O(1) lookup
1307-
lon_i = coord_to_grid_int(lon)
1308-
lat_i = coord_to_grid_int(lat)
1309-
key = (int(lon_i), int(lat_i))
1310-
1311-
if key not in self._landmask_index:
1312-
raise ValueError(
1313-
f"Landmask not found in registry: lon={lon:.2f}, lat={lat:.2f}"
1314-
)
1315-
1316-
idx = self._landmask_index[key]
1317-
return int(self._landmasks_df.iloc[idx]["file_size"])
1333+
row = self._lookup_landmask(lon, lat)
1334+
return int(row["file_size"])
13181335

13191336
def calculate_download_requirements(
13201337
self,

0 commit comments

Comments
 (0)