Skip to content

Commit cfa3a9d

Browse files
export area of interest, track state with sqlite (#3)
* add command to export area of interst, track state in sqlite db * set task id in db * add simple task tracking with sqlite * remove job_name
1 parent e37e18f commit cfa3a9d

7 files changed

Lines changed: 264 additions & 3 deletions

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,7 @@ __marimo__/
208208

209209
# uv
210210
.python-version
211-
uv.lock
211+
uv.lock
212+
213+
# sqlite
214+
*.db

aef_export/cli.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import click
2+
import json
23

3-
from aef_export.embeddings import export_image
4+
from aef_export.embeddings import export_image, export_aoi
45
from aef_export.coverage import export_image_collection
56
from aef_export.settings import get_settings
7+
from aef_export.task_tracking import update_db_state, get_task_summary, BillingTier
68
from aef_export.utils import initialize_ee
79

810

@@ -54,3 +56,56 @@ def image(
5456
initialize_ee(settings.google_cloud_project)
5557
task_id = export_image(image_id, gcs_bucket_name, gcs_key_prefix, quantize)
5658
click.echo(f"Task id: {task_id}")
59+
60+
61+
@app.command()
62+
@click.argument(
63+
"geojson_filepath", type=click.Path(exists=True, file_okay=True, dir_okay=False)
64+
)
65+
@click.argument("bq_dataset_name")
66+
@click.argument("bq_table_name")
67+
@click.argument("gcs_bucket_name")
68+
@click.option("--limit", type=int, required=False, default=None)
69+
def aoi(
70+
geojson_filepath: str,
71+
bq_dataset_name: str,
72+
bq_table_name: str,
73+
gcs_bucket_name: str,
74+
limit: int | None = None,
75+
):
76+
settings = get_settings()
77+
initialize_ee(settings.google_cloud_project)
78+
79+
with open(geojson_filepath) as f:
80+
data = json.load(f)
81+
82+
if data["type"] == "Feature":
83+
polygon = data["geometry"]
84+
elif data["type"] == "Polygon":
85+
polygon = data
86+
else:
87+
raise ValueError(
88+
"Input file must be a geojson polygon, either geometry or feature"
89+
)
90+
91+
export_aoi(polygon, bq_dataset_name, bq_table_name, gcs_bucket_name, limit)
92+
93+
94+
@app.group()
95+
def db():
96+
pass
97+
98+
99+
@db.command()
100+
def update_task_status():
101+
settings = get_settings()
102+
initialize_ee(settings.google_cloud_project)
103+
update_db_state()
104+
105+
106+
@db.command()
107+
@click.option(
108+
"--billing-tier", type=click.Choice(BillingTier), default=BillingTier.tier1
109+
)
110+
def summarize(billing_tier: BillingTier = BillingTier.tier1):
111+
click.echo(json.dumps(get_task_summary(billing_tier)))

aef_export/coverage.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
import ee
21
import uuid
2+
from typing import Generator
3+
4+
import ee
5+
from google.cloud import bigquery
6+
from shapely.geometry import shape
37

48
from aef_export.utils import set_workload_tag
59

@@ -77,3 +81,27 @@ def export_image_collection(
7781
task.start()
7882

7983
return task.id
84+
85+
86+
def query_coverage(
87+
geojson_geometry: dict,
88+
bq_dataset_name: str,
89+
bq_table_name: str,
90+
limit: int | None = None,
91+
) -> Generator[dict, None, None]:
92+
geom = shape(geojson_geometry)
93+
94+
bq_client = bigquery.Client()
95+
query = f"""
96+
SELECT
97+
`system:id` as system_id,
98+
SPLIT(start_date, '-')[OFFSET(0)] as year,
99+
UTM_ZONE as utm_zone
100+
FROM {bq_dataset_name}.{bq_table_name}
101+
WHERE ST_Intersects(ST_GEOGFROMTEXT('{geom.wkt}'), geo)
102+
"""
103+
if limit:
104+
query += f" LIMIT {limit}"
105+
query_job = bq_client.query(query)
106+
for row in query_job.result():
107+
yield dict(row)

aef_export/embeddings.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import ee
22

33
from aef_export.utils import set_workload_tag
4+
from aef_export.sqlite import init_database, Row, insert_row
5+
from aef_export.coverage import query_coverage
46

57

68
def _quantize_embeddings(image: ee.Image) -> ee.Image:
@@ -71,3 +73,40 @@ def export_image(
7173
task.start()
7274

7375
return task.id
76+
77+
78+
def export_aoi(
79+
geojson_geometry: dict,
80+
bq_dataset_name: str,
81+
bq_table_name: str,
82+
gcs_bucket_name: str,
83+
limit: int | None = None,
84+
):
85+
init_database()
86+
87+
rows_to_export = query_coverage(
88+
geojson_geometry, bq_dataset_name, bq_table_name, limit
89+
)
90+
for row in rows_to_export:
91+
# Build the key prefix.
92+
system_id = row["system_id"]
93+
year = row["year"]
94+
utm_zone = row["utm_zone"]
95+
key_prefix = "/".join([year, utm_zone, system_id.split("/")[-1]]) + "/"
96+
97+
# Start the export.
98+
# TODO: Protect against 3000+ tasks in the queue.
99+
task_id = export_image(system_id, gcs_bucket_name, key_prefix, quantize=True)
100+
print("Submitting task: ", task_id)
101+
102+
# Insert record of this row into sqlite
103+
row = Row(
104+
task_id=task_id,
105+
eecu_seconds=None,
106+
runtime_seconds=None,
107+
status="queued",
108+
image_id=system_id,
109+
year=year,
110+
s3_path=f"s3://{gcs_bucket_name}/{key_prefix}",
111+
)
112+
insert_row(row)

aef_export/sqlite.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from dataclasses import dataclass, asdict
2+
import sqlite3
3+
4+
5+
DATABASE_NAME = "sqlite_aef_export.db"
6+
7+
8+
@dataclass
9+
class Row:
10+
task_id: str
11+
eecu_seconds: float | None
12+
runtime_seconds: float | None
13+
status: str
14+
image_id: str
15+
year: int
16+
s3_path: str
17+
18+
19+
def get_connection():
20+
return sqlite3.connect(DATABASE_NAME)
21+
22+
23+
def init_database():
24+
cur = get_connection()
25+
cur.execute("""
26+
CREATE TABLE IF NOT EXISTS exports(
27+
id INTEGER PRIMARY KEY AUTOINCREMENT,
28+
task_id VARCHAR,
29+
eecu_seconds FLOAT,
30+
runtime_seconds FLOAT,
31+
status VARCHAR,
32+
image_id VARCHAR,
33+
year INTEGER,
34+
s3_path VARCHAR
35+
)
36+
""")
37+
38+
39+
def insert_row(row: Row):
40+
d = asdict(row)
41+
columns = ", ".join(d.keys())
42+
placeholders = ", ".join(["?" for _ in d.values()])
43+
query = f"INSERT INTO exports ({columns}) VALUES ({placeholders})"
44+
45+
cur = get_connection()
46+
cur.execute(query, tuple(d.values()))
47+
cur.commit()
48+
49+
50+
def update_row(
51+
task_id: str,
52+
status: str,
53+
eecu_seconds: float | None = None,
54+
runtime_seconds: float | None = None,
55+
):
56+
query = "UPDATE exports SET status = ?, eecu_seconds = ?, runtime_seconds = ? WHERE task_id = ?"
57+
cur = get_connection()
58+
cur.execute(query, (status, eecu_seconds, runtime_seconds, task_id))
59+
cur.commit()
60+
61+
62+
def get_summary():
63+
query = """
64+
SELECT
65+
status,
66+
COUNT(*) as count,
67+
SUM(eecu_seconds) as eecu_seconds,
68+
AVG(runtime_seconds) as avg_runtime_seconds
69+
FROM exports
70+
GROUP BY status;
71+
"""
72+
conn = get_connection()
73+
conn.row_factory = sqlite3.Row
74+
cur = conn.cursor()
75+
cur.execute(query)
76+
resp = cur.fetchall()
77+
return [dict(row) for row in resp]

aef_export/task_tracking.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from datetime import datetime
2+
from enum import Enum
3+
from typing import Generator
4+
5+
import ee
6+
7+
from aef_export.sqlite import update_row, get_summary
8+
9+
10+
class BillingTier(Enum):
11+
tier1 = "tier1"
12+
tier2 = "tier2"
13+
tier3 = "tier3"
14+
15+
16+
def list_tasks() -> Generator[dict, None, None]:
17+
tasks = ee.data.listOperations()
18+
for task in tasks:
19+
if task["metadata"]["type"] != "EXPORT_IMAGE":
20+
continue
21+
yield task
22+
23+
24+
def update_db_state():
25+
for task in list_tasks():
26+
task_state = task["metadata"]["state"]
27+
task_id = task["name"].split("/")[-1]
28+
eecu_seconds = None
29+
duration_seconds = None
30+
if task_state == "SUCCEEDED":
31+
eecu_seconds = task["metadata"]["batchEecuUsageSeconds"]
32+
start_time = datetime.fromisoformat(task["metadata"]["startTime"])
33+
end_time = datetime.fromisoformat(task["metadata"]["endTime"])
34+
duration_seconds = (end_time - start_time).total_seconds()
35+
36+
update_row(task_id, task_state, eecu_seconds, duration_seconds)
37+
38+
39+
def get_task_summary(tier: BillingTier) -> list[dict]:
40+
billing_tiers = {
41+
BillingTier.tier1: 0.40,
42+
BillingTier.tier2: 0.28,
43+
BillingTier.tier3: 0.16,
44+
}
45+
46+
summary = get_summary()
47+
48+
out_rows = []
49+
for row in summary:
50+
if eecu_seconds := row.get("eecu_seconds"):
51+
eecu_hours = eecu_seconds / 3600
52+
row["compute_cost"] = billing_tiers[tier] * eecu_hours
53+
54+
if avg_runtime_seconds := row.pop("avg_runtime_seconds"):
55+
row["avg_runtime_minutes"] = avg_runtime_seconds / 60
56+
out_rows.append(row)
57+
return out_rows

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ build-backend = "setuptools.build_meta"
1616

1717
[dependency-groups]
1818
dev = [
19+
"google-cloud-bigquery>=3.37.0",
1920
"pre-commit>=4.3.0",
2021
"pytest>=8.4.2",
2122
"ruff>=0.12.12",
23+
"shapely>=2.1.1",
2224
]
2325

2426
[project.scripts]

0 commit comments

Comments
 (0)