Skip to content

Commit efe1683

Browse files
authored
Merge pull request #94 from mayrajeo/zarr-downloads
Add zarr to supported download formats
2 parents 5e7d8ad + e94c599 commit efe1683

File tree

10 files changed

+1277
-215
lines changed

10 files changed

+1277
-215
lines changed

geotessera/cli.py

Lines changed: 252 additions & 193 deletions
Large diffs are not rendered by default.

geotessera/core.py

Lines changed: 272 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
import numpy as np
1313
import geopandas as gpd
1414

15-
from .registry import Registry, EMBEDDINGS_DIR_NAME, tile_to_geotiff_path
15+
from .registry import (
16+
Registry,
17+
EMBEDDINGS_DIR_NAME,
18+
tile_to_geotiff_path,
19+
tile_to_zarr_path,
20+
)
1621

1722
try:
1823
import importlib.metadata
@@ -63,7 +68,7 @@ class GeoTessera:
6368
Core functionality:
6469
- Download tiles to local embeddings_dir
6570
- Sample embeddings at point locations from local tiles
66-
- Export individual tiles as GeoTIFF files with correct metadata
71+
- Export individual tiles as GeoTIFF files or zarr archives with correct metadata
6772
- Manage registry and data access
6873
6974
Typical workflows:
@@ -1838,6 +1843,271 @@ def merge_geotiffs_to_mosaic(
18381843

18391844
return str(output_path)
18401845

1846+
def export_embedding_zarr(
1847+
self,
1848+
lon: float,
1849+
lat: float,
1850+
output_path: Union[str, Path],
1851+
year: int = 2024,
1852+
bands: Optional[List[int]] = None,
1853+
) -> str:
1854+
"""Export a single embedding tile as a zarr archive with native UTM projection.
1855+
1856+
Args:
1857+
lon: Tile center longitude
1858+
lat: Tile center latitude
1859+
output_path: Output path for zarr file
1860+
year: Year of embeddings to export
1861+
bands: List of band indices to export (None = all 128 bands)
1862+
compress: Compression method for GeoTIFF
1863+
1864+
Returns:
1865+
Path to created zarr file
1866+
1867+
Raises:
1868+
ImportError: If xarray, rioxarray, zarr or dask is not available
1869+
RuntimeError: If landmask tile or embedding data cannot be fetched
1870+
FileNotFoundError: If registry files are missing
1871+
"""
1872+
try:
1873+
import xarray as xr
1874+
import rioxarray as rxr # noqa: F401 - needed for .rio accessor
1875+
import zarr # noqa: F401 - needed for .to_zarr()
1876+
import dask # noqa: F401 - needed for chunking
1877+
1878+
import warnings
1879+
1880+
warnings.filterwarnings("ignore", category=UserWarning)
1881+
except ImportError:
1882+
raise ImportError(
1883+
"saving to zarr requires xarray, rioxarray, zarr and dask"
1884+
)
1885+
1886+
output_path = Path(output_path)
1887+
output_path.parent.mkdir(parents=True, exist_ok=True)
1888+
1889+
# Fetch single tile with CRS info
1890+
embedding, crs, transform = self.fetch_embedding(lon, lat, year)
1891+
1892+
# Select bands
1893+
if bands is not None:
1894+
data = embedding[:, :, bands].copy()
1895+
band_count = len(bands)
1896+
else:
1897+
data = embedding.copy()
1898+
band_count = 128
1899+
1900+
# Get dimensions for GeoTIFF
1901+
height, width = data.shape[:2]
1902+
1903+
ds = xr.Dataset(
1904+
{"embedding": (("y", "x", "band"), data)},
1905+
coords={
1906+
"y": np.arange(height),
1907+
"x": np.arange(width),
1908+
"band": np.arange(band_count),
1909+
},
1910+
attrs={
1911+
"TESSERA_DATASET_VERSION": self.dataset_version,
1912+
"TESSERA_YEAR": year,
1913+
"TESSERA_TILE_LAT": f"{lat:.2f}",
1914+
"TESSERA_TILE_LON": f"{lon:.2f}",
1915+
"TESSERA_DESCRIPTION": "GeoTessera satellite embedding tile",
1916+
"GEOTESSERA_VERSION": __version__,
1917+
},
1918+
)
1919+
1920+
x_coords = [transform.c + (i + 0.5) * transform.a for i in range(width)]
1921+
y_coords = [transform.f + (j + 0.5) * transform.e for j in range(height)]
1922+
1923+
# Add band descriptions
1924+
if bands is not None:
1925+
output_bands = [f"Tessera_Band_{band_idx}" for band_idx in bands]
1926+
else:
1927+
output_bands = [f"Tessera_Band_{j}" for j in range(128)]
1928+
1929+
ds = ds.assign_coords(
1930+
x=("x", x_coords),
1931+
y=("y", y_coords),
1932+
band=("band", np.array(output_bands, dtype=np.dtypes.StringDType())),
1933+
)
1934+
ds = (
1935+
ds.rio.write_crs(crs)
1936+
.rio.set_spatial_dims(x_dim="x", y_dim="y")
1937+
.rio.write_coordinate_system()
1938+
)
1939+
ds = ds.rio.write_transform(transform)
1940+
ds.to_zarr(output_path, zarr_format=3)
1941+
return str(output_path)
1942+
1943+
def export_embedding_zarrs(
1944+
self,
1945+
tiles_to_fetch: Iterable[Tuple[int, float, float]],
1946+
output_dir: Union[str, Path],
1947+
bands: Optional[List[int]] = None,
1948+
progress_callback: Optional[callable] = None,
1949+
) -> List[str]:
1950+
"""Export all embedding tiles in bounding box as individual zarr files with native UTM projections.
1951+
The list of tiles to fetch can be obtained by registry.load_blocks_for_region().
1952+
1953+
Args:
1954+
tiles_to_fetch: List of tiles to fetch as (year, tile_lon, tile_lat) tuples
1955+
output_dir: Directory to save GeoTIFF files
1956+
bands: List of band indices to export (None = all 128 bands)
1957+
compress: Compression method for GeoTIFF
1958+
progress_callback: Optional callback function(current, total) for progress tracking
1959+
1960+
Returns:
1961+
List of paths to created zarr files
1962+
1963+
Raises:
1964+
ImportError: If rasterio is not available
1965+
RuntimeError: If landmask tiles or embedding data cannot be fetched
1966+
FileNotFoundError: If registry files are missing
1967+
"""
1968+
try:
1969+
import xarray as xr
1970+
import rioxarray as rxr # noqa: F401 - needed for .rio accessor
1971+
import zarr # noqa: F401 - needed for .to_zarr()
1972+
import dask # noqa: F401 - needed for chunking
1973+
1974+
import warnings
1975+
1976+
warnings.filterwarnings("ignore", category=UserWarning)
1977+
except ImportError:
1978+
raise ImportError(
1979+
"saving to zarr requires xarray, rioxarray, zarr and dask"
1980+
)
1981+
1982+
output_dir = Path(output_dir)
1983+
output_dir.mkdir(parents=True, exist_ok=True)
1984+
1985+
# Create a wrapper callback to handle two-phase progress
1986+
def fetch_progress_callback(current: int, total: int, status: str = None):
1987+
# Phase 1: Fetching tiles (0-50% of total progress)
1988+
overall_progress = int((current / total) * 50)
1989+
display_status = status or f"Fetching tile {current}/{total}"
1990+
progress_callback(overall_progress, 100, display_status)
1991+
1992+
# Fetch tiles with progress tracking
1993+
if progress_callback:
1994+
progress_callback(0, 100, "Loading registry blocks...")
1995+
1996+
tiles = list(
1997+
self.fetch_embeddings(
1998+
tiles_to_fetch, fetch_progress_callback if progress_callback else None
1999+
)
2000+
)
2001+
if progress_callback:
2002+
total_tiles = len(tiles_to_fetch)
2003+
2004+
if not tiles:
2005+
self.logger.warning("No tiles found in bounding box")
2006+
return []
2007+
2008+
if progress_callback:
2009+
progress_callback(
2010+
50, 100, f"Fetched {total_tiles} tiles, starting GeoTIFF export..."
2011+
)
2012+
2013+
created_files = []
2014+
2015+
# Sequential zarr writing
2016+
for i, (year, tile_lon, tile_lat, embedding, crs, transform) in enumerate(
2017+
tiles
2018+
):
2019+
# Use centralized path construction from registry
2020+
zarr_rel_path = tile_to_zarr_path(tile_lon, tile_lat, year)
2021+
output_path = output_dir / EMBEDDINGS_DIR_NAME / zarr_rel_path
2022+
output_path.parent.mkdir(parents=True, exist_ok=True)
2023+
2024+
# Update progress to show we're starting this file
2025+
if progress_callback:
2026+
export_progress = int(50 + (i / total_tiles) * 50)
2027+
progress_callback(
2028+
export_progress, 100, f"Creating {output_path.name}..."
2029+
)
2030+
2031+
# Select bands
2032+
if bands is not None:
2033+
data = embedding[:, :, bands].copy()
2034+
band_count = len(bands)
2035+
else:
2036+
data = embedding.copy()
2037+
band_count = 128
2038+
2039+
# Get dimensions for GeoTIFF
2040+
height, width = data.shape[:2]
2041+
2042+
# Select bands
2043+
if bands is not None:
2044+
data = embedding[:, :, bands].copy()
2045+
band_count = len(bands)
2046+
else:
2047+
data = embedding.copy()
2048+
band_count = 128
2049+
2050+
# Get dimensions for GeoTIFF
2051+
height, width = data.shape[:2]
2052+
ds = xr.Dataset(
2053+
{"embedding": (("y", "x", "band"), data)},
2054+
coords={
2055+
"y": np.arange(height),
2056+
"x": np.arange(width),
2057+
"band": np.arange(band_count),
2058+
},
2059+
attrs={
2060+
"TESSERA_DATASET_VERSION": self.dataset_version,
2061+
"TESSERA_YEAR": year,
2062+
"TESSERA_TILE_LAT": f"{tile_lat:.2f}",
2063+
"TESSERA_TILE_LON": f"{tile_lon:.2f}",
2064+
"TESSERA_DESCRIPTION": "GeoTessera satellite embedding tile",
2065+
"GEOTESSERA_VERSION": __version__,
2066+
},
2067+
)
2068+
2069+
x_coords = [transform.c + (i + 0.5) * transform.a for i in range(width)]
2070+
y_coords = [transform.f + (j + 0.5) * transform.e for j in range(height)]
2071+
2072+
# Add band descriptions
2073+
if bands is not None:
2074+
output_bands = [f"Tessera_Band_{band_idx}" for band_idx in bands]
2075+
else:
2076+
output_bands = [f"Tessera_Band_{j}" for j in range(128)]
2077+
2078+
ds = ds.assign_coords(
2079+
x=("x", x_coords),
2080+
y=("y", y_coords),
2081+
band=("band", np.array(output_bands, dtype=np.dtypes.StringDType())),
2082+
)
2083+
ds = (
2084+
ds.rio.write_crs(crs)
2085+
.rio.set_spatial_dims(x_dim="x", y_dim="y")
2086+
.rio.write_coordinate_system()
2087+
)
2088+
ds = ds.rio.write_transform(transform)
2089+
2090+
ds.to_zarr(output_path, zarr_format=3)
2091+
created_files.append(str(output_path))
2092+
2093+
# Update progress for zarr export phase
2094+
if progress_callback:
2095+
# Phase 2: Exporting zarr (50-100% of total progress)
2096+
export_progress = int(50 + ((i + 1) / total_tiles) * 50)
2097+
progress_callback(
2098+
export_progress,
2099+
100,
2100+
f"Exported {output_path.name} ({i + 1}/{total_tiles})",
2101+
)
2102+
2103+
if progress_callback:
2104+
progress_callback(
2105+
100, 100, f"Completed! Exported {len(created_files)} zarr files"
2106+
)
2107+
2108+
self.logger.info(f"Exported {len(created_files)} zarr files to {output_dir}")
2109+
return created_files
2110+
18412111
def apply_pca_to_embeddings(
18422112
self,
18432113
embeddings: List[Tuple[int, float, float, np.ndarray, object, object]],

geotessera/registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,21 @@ def tile_to_geotiff_path(lon: float, lat: float, year: int) -> str:
276276
return f"{year}/{grid_name}/{grid_name}_{year}.tiff"
277277

278278

279+
def tile_to_zarr_path(lon: float, lat: float, year: int) -> str:
280+
"""Generate GeoTIFF file path for a tile.
281+
282+
Args:
283+
lon: Tile center longitude
284+
lat: Tile center latitude
285+
year: Year of embeddings
286+
287+
Returns:
288+
str: Relative path like "{year}/grid_{lon}_{lat}/grid_{lon}_{lat}_{year}.zarr"
289+
"""
290+
grid_name = tile_to_grid_name(lon, lat)
291+
return f"{year}/{grid_name}/{grid_name}_{year}.zarr"
292+
293+
279294
def tile_to_landmask_filename(lon: float, lat: float) -> str:
280295
"""Generate landmask filename for a tile.
281296

0 commit comments

Comments
 (0)