Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions aef_export/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import click
import json

import geopandas as gpd

from aef_export.embeddings import export_image, export_aoi
from aef_export.coverage import export_image_collection
from aef_export.settings import get_settings
from aef_export.task_tracking import update_db_state, get_task_summary, BillingTier
from aef_export.utils import initialize_ee
from aef_export.ftw.export_labels import export_labels_for_year


@click.group()
Expand Down Expand Up @@ -109,3 +112,33 @@ def update_task_status():
)
def summarize(billing_tier: BillingTier = BillingTier.tier1):
click.echo(json.dumps(get_task_summary(billing_tier)))


@app.group()
def ftw():
    # Container group only: it does no work itself, subcommands are attached
    # to it through ``@ftw.command()``.
    return None


@ftw.command()
@click.argument("infile")
@click.argument("outfile")
@click.option(
    "--year",
    type=int,
    required=True,
    help="Year of the annual embeddings to export.",
)
@click.option(
    "--bucket-name",
    type=str,
    required=True,
    help="GCS bucket that receives the exported arrays.",
)
@click.option(
    "--num-threads",
    type=int,
    required=False,
    default=None,
    help="Maximum number of worker threads (default: chosen by the thread pool).",
)
@click.option(
    "--width",
    type=int,
    required=False,
    default=None,
    help="Chip width in pixels (default: derived from each geometry's bounds).",
)
@click.option(
    "--height",
    type=int,
    required=False,
    default=None,
    help="Chip height in pixels (default: derived from each geometry's bounds).",
)
def process_labels(
    infile: str,
    outfile: str,
    year: int,
    bucket_name: str,
    num_threads: int | None = None,
    width: int | None = None,
    height: int | None = None,
):
    """Export an embedding chip for every geometry in INFILE.

    INFILE is read with geopandas as a parquet file; the export step reads
    its "geometry", "country" and "aoi_id" columns. The resulting table,
    including the "gcs_path" of each uploaded chip, is written to OUTFILE as
    parquet.
    """
    # Earth Engine must be initialised before any export request is made.
    settings = get_settings()
    initialize_ee(settings.google_cloud_project)

    gdf = gpd.read_parquet(infile)
    # --num-threads is forwarded as the thread pool's max_workers.
    out_df = export_labels_for_year(gdf, year, bucket_name, num_threads, width, height)
    out_df.to_parquet(outfile)
Empty file added aef_export/ftw/__init__.py
Empty file.
165 changes: 165 additions & 0 deletions aef_export/ftw/export_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import ee
import io
from functools import lru_cache
import concurrent.futures

import affine
import numpy as np
from pyproj import Transformer
from shapely.geometry import Polygon
import utm
import geopandas as gpd
from google.cloud import storage

from aef_export.embeddings import _quantize_embeddings


@lru_cache()
def get_gcs_client():
    """Return the GCS client, constructing it on the first call only.

    The function takes no arguments, so the ``lru_cache`` wrapper hands the
    same ``storage.Client`` instance back on every subsequent call.
    """
    client = storage.Client()
    return client


@lru_cache()
def _get_transformer(out_epsg: str, in_epsg: str = "EPSG:4326") -> Transformer:
    """Build (and memoise) the pyproj transformer from ``in_epsg`` to ``out_epsg``.

    ``always_xy=True`` makes the transformer take and return coordinates in
    x/y order regardless of the axis order the CRS declares.
    """
    return Transformer.from_crs(
        crs_from=in_epsg,
        crs_to=out_epsg,
        always_xy=True,
    )


def _utm_zone_from_latlon(lat: float, lon: float) -> str:
    """Return the EPSG code of the WGS 84 / UTM zone containing a point.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.

    Returns:
        ``"EPSG:326NN"`` for the northern hemisphere or ``"EPSG:327NN"`` for
        the southern hemisphere, where ``NN`` is the two-digit zone number.
    """
    _, _, zone_number, _ = utm.from_latlon(lat, lon)

    # Negative latitudes use the 327xx (south) series, everything else 326xx.
    prefix = "327" if lat < 0 else "326"

    # The zone number must be zero-padded to two digits: zone 5 north is
    # EPSG:32605, whereas an unpadded "EPSG:3265" is not a UTM code at all.
    return f"EPSG:{prefix}{zone_number:02d}"


def _fetch_array(
    year: int,
    geom: Polygon,
    spatial_res_meters: int = 10,
    width: int | None = None,
    height: int | None = None,
) -> np.ndarray:
    """Fetch the annual satellite-embedding pixels covering ``geom``.

    The first image of ``GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL`` that
    intersects ``geom`` within ``year`` is passed through
    ``_quantize_embeddings`` and requested from Earth Engine on a grid in the
    UTM zone of the geometry's centroid.

    Args:
        year: Calendar year; images dated in [year-01-01, (year+1)-01-01)
            are considered.
        geom: Area of interest; its coordinates are treated as EPSG:4326
            lon/lat (the default input CRS of ``_get_transformer``).
        spatial_res_meters: Pixel size used to derive ``width``/``height``
            when they are not supplied.
        width: Grid width in pixels. Derived from the projected bounds when
            falsy.
        height: Grid height in pixels. Derived from the projected bounds when
            falsy.

    Returns:
        The structured array returned by ``ee.data.computePixels`` viewed as
        a plain array with one extra trailing axis holding the named fields
        (presumably ``(height, width, bands)`` - confirm against the API).
    """
    # Find the right EE image
    images = (
        ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
        .filterBounds(ee.Geometry(geom.__geo_interface__))
        .filterDate(ee.Date(f"{year}-01-01"), ee.Date(f"{year + 1}-01-01"))
    )
    image = _quantize_embeddings(images.first())

    # Project the lon/lat bounds into the UTM zone of the centroid so the
    # grid can be expressed in metres.
    centroid = geom.centroid
    utm_zone = _utm_zone_from_latlon(centroid.y, centroid.x)
    transformer = _get_transformer(utm_zone)
    xmin, ymin, xmax, ymax = transformer.transform_bounds(*geom.bounds)

    # Clamp derived dimensions to at least one pixel: a geometry smaller than
    # the pixel size would otherwise truncate to 0 and make the scale
    # computation below divide by zero.
    if not width:
        width = max(1, int((xmax - xmin) / spatial_res_meters))

    if not height:
        height = max(1, int((ymax - ymin) / spatial_res_meters))

    # Origin at the top-left corner; the y scale is negative so that rows run
    # from ymax down to ymin.
    transform = affine.Affine.translation(xmin, ymax) * affine.Affine.scale(
        (xmax - xmin) / width, (ymin - ymax) / height
    )

    request = {
        "expression": image,
        "fileFormat": "NUMPY_NDARRAY",
        "grid": {
            "dimensions": {"width": width, "height": height},
            "affineTransform": {
                "scaleX": transform.a,
                "shearX": transform.b,
                "translateX": transform.c,
                "shearY": transform.d,
                "scaleY": transform.e,
                "translateY": transform.f,
            },
            "crsCode": utm_zone,
        },
    }

    # Send the export request
    arr = ee.data.computePixels(request)

    # Reinterpret the structured dtype as a plain array of the first field's
    # dtype, with one entry per named field along a new trailing axis.
    restructured_arr = arr.view((arr.dtype[0], len(arr.dtype.names)))
    return restructured_arr


def _upload_numpy_array_to_gcs(
    bucket_name: str, destination_blob_name: str, numpy_array: np.ndarray
) -> str:
    """Store a NumPy array in a GCS bucket in ``.npy`` format.

    Args:
        bucket_name (str): Name of the destination GCS bucket.
        destination_blob_name (str): Object name inside the bucket
            (e.g., 'my_array.npy').
        numpy_array (np.ndarray): Array to serialize and upload.

    Returns:
        str: The ``gs://`` URI of the uploaded object.
    """
    target_blob = get_gcs_client().bucket(bucket_name).blob(destination_blob_name)

    # Stage the ``np.save`` output in memory, then rewind so the upload reads
    # the payload from its first byte.
    payload = io.BytesIO()
    np.save(payload, numpy_array)
    payload.seek(0)

    target_blob.upload_from_file(payload, content_type="application/octet-stream")

    dest_path = f"gs://{bucket_name}/{destination_blob_name}"
    print(f"Uploaded to {dest_path}")
    return dest_path


def _fetch_and_upload_embeddings(
    geom: Polygon,
    country: str,
    aoi_id: str,
    year: int,
    bucket_name: str,
    width: int | None = None,
    height: int | None = None,
) -> str:
    """Fetch the embedding chip for one geometry and upload it to GCS.

    The chip is stored under ``chips/<country>/<year>/<aoi_id>.npy`` in
    ``bucket_name``.

    Returns:
        The ``gs://`` path reported by ``_upload_numpy_array_to_gcs``.
    """
    chip = _fetch_array(year, geom, width=width, height=height)
    blob_key = f"chips/{country}/{year}/{aoi_id}.npy"
    gcs_uri = _upload_numpy_array_to_gcs(bucket_name, blob_key, chip)
    return gcs_uri


def export_labels_for_year(
    gdf: gpd.GeoDataFrame,
    year: int,
    bucket_name: str,
    max_workers: int | None = None,
    width: int | None = None,
    height: int | None = None,
) -> gpd.GeoDataFrame:
    """Export one embedding chip per row of ``gdf`` using a thread pool.

    Each row's ``geometry``, ``country`` and ``aoi_id`` are handed to
    ``_fetch_and_upload_embeddings``. The resulting GCS path is written to
    the row's ``gcs_path`` cell. A row whose export raises is reported on
    stdout and keeps its previous ``gcs_path`` value (null if the column was
    just created).

    Args:
        gdf: Rows to export. Modified in place and also returned.
        year: Year of the embeddings to export.
        bucket_name: GCS bucket that receives the arrays.
        max_workers: Passed to ``ThreadPoolExecutor``; ``None`` lets the
            executor choose.
        width: Optional chip width in pixels, forwarded to the fetch step.
        height: Optional chip height in pixels, forwarded to the fetch step.

    Returns:
        ``gdf`` with a ``gcs_path`` column.
    """
    # Create the column up front so the output schema is the same whether or
    # not any export succeeds; existing values are left alone.
    if "gcs_path" not in gdf.columns:
        gdf["gcs_path"] = None

    with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
        # Map each future back to the index label of the row it belongs to.
        tasks = {
            executor.submit(
                _fetch_and_upload_embeddings,
                row.geometry,
                row.country,
                row.aoi_id,
                year,
                bucket_name,
                width,
                height,
            ): row.Index
            for row in gdf.itertuples()
        }

        for future in concurrent.futures.as_completed(tasks):
            row_index = tasks[future]
            try:
                gcs_path = future.result()
            except Exception as exc:
                # Best effort: one failed chip must not abort the whole
                # batch, but the message has to say which row failed.
                print(f"Export failed for row {row_index!r}: {exc}")
                continue

            gdf.loc[row_index, "gcs_path"] = gcs_path

    return gdf
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"affine>=2.4.0",
"click>=8.1.8",
"earthengine-api>=1.6.6",
"fastparquet>=2024.11.0",
"geopandas>=1.1.1",
"google-cloud-storage>=3.3.1",
"pandas>=2.3.3",
"pyarrow>=21.0.0",
"pydantic-settings>=2.10.1",
"pyproj>=3.7.2",
"utm>=0.8.1",
]

[build-system]
Expand Down