|
12 | 12 | import numpy as np |
13 | 13 | import geopandas as gpd |
14 | 14 |
|
15 | | -from .registry import Registry, EMBEDDINGS_DIR_NAME, tile_to_geotiff_path |
| 15 | +from .registry import ( |
| 16 | + Registry, |
| 17 | + EMBEDDINGS_DIR_NAME, |
| 18 | + tile_to_geotiff_path, |
| 19 | + tile_to_zarr_path, |
| 20 | +) |
16 | 21 |
|
17 | 22 | try: |
18 | 23 | import importlib.metadata |
@@ -63,7 +68,7 @@ class GeoTessera: |
63 | 68 | Core functionality: |
64 | 69 | - Download tiles to local embeddings_dir |
65 | 70 | - Sample embeddings at point locations from local tiles |
66 | | - - Export individual tiles as GeoTIFF files with correct metadata |
| 71 | + - Export individual tiles as GeoTIFF files or zarr archives with correct metadata |
67 | 72 | - Manage registry and data access |
68 | 73 |
|
69 | 74 | Typical workflows: |
@@ -1838,6 +1843,271 @@ def merge_geotiffs_to_mosaic( |
1838 | 1843 |
|
1839 | 1844 | return str(output_path) |
1840 | 1845 |
|
def export_embedding_zarr(
    self,
    lon: float,
    lat: float,
    output_path: Union[str, Path],
    year: int = 2024,
    bands: Optional[List[int]] = None,
) -> str:
    """Export a single embedding tile as a zarr archive in its native projection.

    The tile is fetched with ``self.fetch_embedding``, optionally reduced to
    the requested bands, wrapped in an ``xarray.Dataset`` with pixel-centre
    ``x``/``y`` coordinates and named bands, georeferenced through the
    ``.rio`` accessor and written as a zarr v3 store.

    Args:
        lon: Tile center longitude
        lat: Tile center latitude
        output_path: Output path for the zarr store; missing parent
            directories are created
        year: Year of embeddings to export
        bands: List of band indices to export (None = every band of the
            fetched tile)

    Returns:
        Path to the created zarr store, as a string

    Raises:
        ImportError: If xarray, rioxarray, zarr or dask is not available
        RuntimeError, FileNotFoundError: Presumably propagated from
            ``self.fetch_embedding`` when tile or registry data is
            missing - not visible from this method, confirm there.
    """
    import warnings

    try:
        import xarray as xr
        import rioxarray as rxr  # noqa: F401 - needed for .rio accessor
        import zarr  # noqa: F401 - needed for .to_zarr()
        import dask  # noqa: F401 - needed for chunking
    except ImportError as err:
        # Chain the original error so the missing package stays visible.
        raise ImportError(
            "saving to zarr requires xarray, rioxarray, zarr and dask"
        ) from err

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch single tile with CRS info
    embedding, crs, transform = self.fetch_embedding(lon, lat, year)

    # Select bands. The array is only read below, so no defensive copy is
    # needed (indexing with a list already yields a new array).
    data = embedding if bands is None else embedding[:, :, bands]

    # Derive every dimension from the data instead of assuming 128 bands.
    height, width, band_count = data.shape

    # Pixel-centre coordinates: origin plus (index + 0.5) pixel sizes.
    # Only the c/a and f/e terms of the transform are used, i.e. any
    # rotation/shear terms are ignored.
    x_coords = transform.c + (np.arange(width) + 0.5) * transform.a
    y_coords = transform.f + (np.arange(height) + 0.5) * transform.e

    # Band labels keep the original band index when a subset is exported.
    band_indices = range(band_count) if bands is None else bands
    band_names = np.array(
        [f"Tessera_Band_{band_idx}" for band_idx in band_indices],
        dtype=np.dtypes.StringDType(),
    )

    # Silence UserWarnings only while building/writing the dataset rather
    # than installing a process-wide filter that outlives this call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)

        ds = xr.Dataset(
            {"embedding": (("y", "x", "band"), data)},
            coords={
                "y": ("y", y_coords),
                "x": ("x", x_coords),
                "band": ("band", band_names),
            },
            attrs={
                "TESSERA_DATASET_VERSION": self.dataset_version,
                "TESSERA_YEAR": year,
                "TESSERA_TILE_LAT": f"{lat:.2f}",
                "TESSERA_TILE_LON": f"{lon:.2f}",
                "TESSERA_DESCRIPTION": "GeoTessera satellite embedding tile",
                "GEOTESSERA_VERSION": __version__,
            },
        )
        ds = (
            ds.rio.write_crs(crs)
            .rio.set_spatial_dims(x_dim="x", y_dim="y")
            .rio.write_coordinate_system()
        )
        ds = ds.rio.write_transform(transform)
        ds.to_zarr(output_path, zarr_format=3)

    return str(output_path)
| 1942 | + |
def export_embedding_zarrs(
    self,
    tiles_to_fetch: Iterable[Tuple[int, float, float]],
    output_dir: Union[str, Path],
    bands: Optional[List[int]] = None,
    progress_callback: Optional[callable] = None,
) -> List[str]:
    """Export embedding tiles as individual zarr stores in their native projections.

    The list of tiles to fetch can be obtained by registry.load_blocks_for_region().
    Every tile yielded by ``self.fetch_embeddings`` is written to
    ``output_dir / EMBEDDINGS_DIR_NAME / tile_to_zarr_path(lon, lat, year)``.

    Args:
        tiles_to_fetch: Tiles to fetch as (year, tile_lon, tile_lat) tuples;
            any iterable is accepted, it does not need to support ``len()``
        output_dir: Directory to save the zarr stores (created if missing)
        bands: List of band indices to export (None = every band of each tile)
        progress_callback: Optional callback function(current, total, status)
            for progress tracking. Fetching is reported as 0-50 and
            exporting as 50-100, always out of a total of 100.

    Returns:
        List of paths to created zarr stores (empty if no tile was fetched)

    Raises:
        ImportError: If xarray, rioxarray, zarr or dask is not available
        RuntimeError, FileNotFoundError: Presumably propagated from
            ``self.fetch_embeddings`` when tile or registry data is
            missing - not visible from this method, confirm there.
    """
    import warnings

    try:
        import xarray as xr
        import rioxarray as rxr  # noqa: F401 - needed for .rio accessor
        import zarr  # noqa: F401 - needed for .to_zarr()
        import dask  # noqa: F401 - needed for chunking
    except ImportError as err:
        # Chain the original error so the missing package stays visible.
        raise ImportError(
            "saving to zarr requires xarray, rioxarray, zarr and dask"
        ) from err

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a wrapper callback to handle two-phase progress
    def fetch_progress_callback(current: int, total: int, status: str = None):
        # Phase 1: Fetching tiles (0-50% of total progress).
        # Guard against a zero total so reporting can never divide by zero.
        overall_progress = int((current / total) * 50) if total else 0
        display_status = status or f"Fetching tile {current}/{total}"
        progress_callback(overall_progress, 100, display_status)

    # Fetch tiles with progress tracking
    if progress_callback:
        progress_callback(0, 100, "Loading registry blocks...")

    tiles = list(
        self.fetch_embeddings(
            tiles_to_fetch, fetch_progress_callback if progress_callback else None
        )
    )

    if not tiles:
        self.logger.warning("No tiles found in bounding box")
        return []

    # Count what was actually fetched: tiles_to_fetch may be a one-shot
    # iterable without len(), and not every requested tile need be returned.
    total_tiles = len(tiles)

    if progress_callback:
        progress_callback(
            50, 100, f"Fetched {total_tiles} tiles, starting zarr export..."
        )

    created_files = []

    # Sequential zarr writing
    for idx, (year, tile_lon, tile_lat, embedding, crs, transform) in enumerate(
        tiles
    ):
        # Use centralized path construction from registry
        zarr_rel_path = tile_to_zarr_path(tile_lon, tile_lat, year)
        output_path = output_dir / EMBEDDINGS_DIR_NAME / zarr_rel_path
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Update progress to show we're starting this file
        if progress_callback:
            export_progress = int(50 + (idx / total_tiles) * 50)
            progress_callback(
                export_progress, 100, f"Creating {output_path.name}..."
            )

        # Select bands once. The array is only read below, so no defensive
        # copy is needed (indexing with a list already yields a new array).
        data = embedding if bands is None else embedding[:, :, bands]

        # Derive every dimension from the data instead of assuming 128 bands.
        height, width, band_count = data.shape

        # Pixel-centre coordinates: origin plus (index + 0.5) pixel sizes.
        # Only the c/a and f/e terms of the transform are used, i.e. any
        # rotation/shear terms are ignored.
        x_coords = transform.c + (np.arange(width) + 0.5) * transform.a
        y_coords = transform.f + (np.arange(height) + 0.5) * transform.e

        # Band labels keep the original band index when a subset is exported.
        band_indices = range(band_count) if bands is None else bands
        band_names = np.array(
            [f"Tessera_Band_{band_idx}" for band_idx in band_indices],
            dtype=np.dtypes.StringDType(),
        )

        # Silence UserWarnings only while building/writing this tile rather
        # than installing a process-wide filter that outlives this call.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)

            ds = xr.Dataset(
                {"embedding": (("y", "x", "band"), data)},
                coords={
                    "y": ("y", y_coords),
                    "x": ("x", x_coords),
                    "band": ("band", band_names),
                },
                attrs={
                    "TESSERA_DATASET_VERSION": self.dataset_version,
                    "TESSERA_YEAR": year,
                    "TESSERA_TILE_LAT": f"{tile_lat:.2f}",
                    "TESSERA_TILE_LON": f"{tile_lon:.2f}",
                    "TESSERA_DESCRIPTION": "GeoTessera satellite embedding tile",
                    "GEOTESSERA_VERSION": __version__,
                },
            )
            ds = (
                ds.rio.write_crs(crs)
                .rio.set_spatial_dims(x_dim="x", y_dim="y")
                .rio.write_coordinate_system()
            )
            ds = ds.rio.write_transform(transform)
            ds.to_zarr(output_path, zarr_format=3)

        created_files.append(str(output_path))

        # Update progress for zarr export phase
        if progress_callback:
            # Phase 2: Exporting zarr (50-100% of total progress)
            export_progress = int(50 + ((idx + 1) / total_tiles) * 50)
            progress_callback(
                export_progress,
                100,
                f"Exported {output_path.name} ({idx + 1}/{total_tiles})",
            )

    if progress_callback:
        progress_callback(
            100, 100, f"Completed! Exported {len(created_files)} zarr files"
        )

    self.logger.info(f"Exported {len(created_files)} zarr files to {output_dir}")
    return created_files
| 2110 | + |
1841 | 2111 | def apply_pca_to_embeddings( |
1842 | 2112 | self, |
1843 | 2113 | embeddings: List[Tuple[int, float, float, np.ndarray, object, object]], |
|
0 commit comments