requirements.txt (1 addition, 0 deletions)

@@ -3,6 +3,7 @@ fastapi>=0.115
google-cloud-bigtable>=2.18
google-cloud-pubsub>=2.18
interrogate>=1.7
prometheus-client>=0.21.1
pydantic>=2.8
pytest>=8.2
python-dotenv>=1.0
rslp/landsat_vessels/api_main.py (35 additions, 8 deletions)

@@ -6,13 +6,17 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any

import prometheus_client
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Response
from prometheus_client import make_asgi_app, multiprocess
from pydantic import BaseModel

from rslp.landsat_vessels.predict_pipeline import FormattedPrediction, predict_pipeline
from rslp.landsat_vessels.prom_metrics import time_operation, TimerOperations
from rslp.log_utils import get_logger
from rslp.utils.mp import init_mp

@@ -179,14 +183,15 @@ async def get_detections(info: LandsatRequest, response: Response) -> LandsatRes
)
try:
logger.info("Processing request with input data.")
json_data = predict_pipeline(
scene_id=info.scene_id,
scene_zip_path=info.scene_zip_path,
image_files=info.image_files,
json_path=info.json_path,
scratch_path=info.scratch_path,
crop_path=info.crop_path,
)
with time_operation(TimerOperations.TotalInferenceTime):
json_data = predict_pipeline(
scene_id=info.scene_id,
scene_zip_path=info.scene_zip_path,
image_files=info.image_files,
json_path=info.json_path,
scratch_path=info.scratch_path,
crop_path=info.crop_path,
)
return LandsatResponse(
status=StatusEnum.SUCCESS,
predictions=json_data,
@@ -207,6 +212,28 @@ async def get_detections(info: LandsatRequest, response: Response) -> LandsatRes
error_message=f"Unexpected error in prediction pipeline: {e}",
)

# Set up the Prometheus metrics endpoint.
def setup_prom_metrics() -> Any:
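    """Build the ASGI app that serves Prometheus metrics (multiprocess-aware)."""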
    multi_proc_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
if not multi_proc_dir:
# If we're not using multiproc, then just use the default registry
return make_asgi_app()

    # Otherwise, set up Prometheus multiprocess mode.
if os.path.isdir(multi_proc_dir):
for multi_proc_file in os.scandir(multi_proc_dir):
os.remove(multi_proc_file.path)
else:
os.makedirs(multi_proc_dir)

    # Create the multiprocess collector and mount it as an app for FastAPI.
registry = prometheus_client.CollectorRegistry()
multiprocess.MultiProcessCollector(registry, path=multi_proc_dir)
return make_asgi_app(registry=registry)


app.mount("/metrics", setup_prom_metrics())


if __name__ == "__main__":
uvicorn.run(
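
Note: make_asgi_app returns an ASGI app that serves the registry in the Prometheus text format; in multiprocess mode each worker writes per-process sample files under PROMETHEUS_MULTIPROC_DIR, and the MultiProcessCollector aggregates them at scrape time. A minimal standalone sketch of the same pattern (the demo_app name, directory, and port are illustrative, not part of this PR):

import os

import prometheus_client
import uvicorn
from fastapi import FastAPI
from prometheus_client import make_asgi_app, multiprocess

# Must be set before any metrics are created so prometheus_client backs metric
# values with files that all worker processes share.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/prom_multiproc")
os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True)

demo_app = FastAPI()

# One registry that aggregates the per-worker sample files at scrape time.
registry = prometheus_client.CollectorRegistry()
multiprocess.MultiProcessCollector(registry, path=os.environ["PROMETHEUS_MULTIPROC_DIR"])
demo_app.mount("/metrics", make_asgi_app(registry=registry))

if __name__ == "__main__":
    uvicorn.run(demo_app, host="0.0.0.0", port=8000)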
rslp/landsat_vessels/predict_pipeline.py (101 additions, 82 deletions)

@@ -34,6 +34,7 @@
LOCAL_FILES_DATASET_CONFIG,
OUTPUT_LAYER_NAME,
)
from rslp.landsat_vessels.prom_metrics import time_operation, TimerOperations
from rslp.log_utils import get_logger
from rslp.utils.filter import NearInfraFilter
from rslp.utils.rslearn import (
@@ -126,14 +127,16 @@ def get_vessel_detections(
ignore_errors=False, apply_windows_args=apply_windows_args
),
)
materialize_dataset(ds_path, materialize_pipeline_args)
with time_operation(TimerOperations.MaterializeDataset):
materialize_dataset(ds_path, materialize_pipeline_args)

# Sanity check that the layer is completed.
if not window.is_layer_completed(LANDSAT_LAYER_NAME):
raise ValueError("landsat layer did not get materialized")

# Run object detector.
run_model_predict(DETECT_MODEL_CONFIG, ds_path)
with time_operation(TimerOperations.RunModelPredict):
run_model_predict(DETECT_MODEL_CONFIG, ds_path)

# Read the detections.
layer_dir = window.get_layer_dir(OUTPUT_LAYER_NAME)
@@ -438,29 +441,52 @@ def predict_pipeline(

# Determine which of the arguments to use and setup dataset and get SceneData
# appropriately.
scene_data = setup_dataset(
ds_path,
scene_id=scene_id,
scene_zip_path=scene_zip_path,
image_files=image_files,
window_path=window_path,
)
with time_operation(TimerOperations.SetupDataset):
scene_data = setup_dataset(
ds_path,
scene_id=scene_id,
scene_zip_path=scene_zip_path,
image_files=image_files,
window_path=window_path,
)

# Run pipeline.
print("run detector")
detections = get_vessel_detections(ds_path, scene_data)
print("run classifier")
detections = run_classifier(ds_path, detections=detections, scene_data=scene_data)
with time_operation(TimerOperations.GetVesselDetections):
detections = get_vessel_detections(ds_path, scene_data)
with time_operation(TimerOperations.RunClassifier):
detections = run_classifier(ds_path, detections=detections, scene_data=scene_data)

with time_operation(TimerOperations.BuildPredictionsAndCrops):
json_data = _build_predictions_and_crops(detections, crop_path)

if json_path:
json_upath = UPath(json_path)
with json_upath.open("w") as f:
json.dump(json_data, f)

if geojson_path:
geojson_features = [d.to_feature() for d in detections]
geojson_upath = UPath(geojson_path)
with geojson_upath.open("w") as f:
json.dump(
{
"type": "FeatureCollection",
"properties": {},
"features": geojson_features,
},
f,
)

return json_data

def _build_predictions_and_crops(
    detections: list[VesselDetection], crop_path: str | None
) -> list[FormattedPrediction]:
    """Format detections as JSON records, writing crop PNGs if crop_path is set."""
if crop_path:
crop_upath = UPath(crop_path)
crop_upath.mkdir(parents=True, exist_ok=True)

json_data = []
geojson_features = []
near_infra_filter = NearInfraFilter(infra_distance_threshold=INFRA_THRESHOLD_KM)
raster_format = GeotiffRasterFormat()
infra_detections = 0
for idx, detection in enumerate(detections):
# Apply near infra filter (True -> filter out, False -> keep)
@@ -469,58 +495,12 @@ def predict_pipeline(
infra_detections += 1
continue

# Load crops from the window directory for writing output PNGs.
# We create two PNGs:
# - b8.png: just has B8 (panchromatic band).
# - rgb.png: true color with pan-sharpening. The RGB is from B4, B3, and B2
# respectively while B8 is used for pan-sharpening.
images = {}
crop_window: Window = detection.metadata["crop_window"]
if crop_window is None:
raise ValueError("Crop window is None")
for band in ["B2", "B3", "B4", "B8"]:
raster_dir = crop_window.get_raster_dir(LANDSAT_LAYER_NAME, [band])

# Different bands are in different resolutions so get the bounds from the
# raster since anyway we just want to read the whole raster.
band_bounds = raster_format.get_raster_bounds(raster_dir)
image = raster_format.decode_raster(raster_dir, band_bounds)
if image.shape[0] != 1:
raise ValueError(
f"expected single-band image for {band} but got {image.shape[0]} bands"
)

images[band] = image[0, :, :]

# Apply simple pan-sharpening for the RGB.
# This is just linearly scaling RGB bands to add up to B8, which is captured at
# a higher resolution.
for band in ["B2", "B3", "B4"]:
sharp = images[band].astype(np.int32)
sharp = sharp.repeat(repeats=2, axis=0).repeat(repeats=2, axis=1)
images[band + "_sharp"] = sharp
total = np.clip(
(images["B2_sharp"] + images["B3_sharp"] + images["B4_sharp"]) // 3, 1, 255
)
for band in ["B2", "B3", "B4"]:
images[band + "_sharp"] = np.clip(
images[band + "_sharp"] * images["B8"] // total, 0, 255
).astype(np.uint8)
rgb = np.stack(
[images["B4_sharp"], images["B3_sharp"], images["B2_sharp"]], axis=2
)

if crop_path:
rgb_fname = crop_upath / f"{idx}_rgb.png"
with rgb_fname.open("wb") as f:
Image.fromarray(rgb).save(f, format="PNG")

b8_fname = crop_upath / f"{idx}_b8.png"
with b8_fname.open("wb") as f:
Image.fromarray(images["B8"]).save(f, format="PNG")
crops = _write_detection_crop(detection, crop_upath, idx)
rgb_fname = crops.rgb_fname
b8_fname = crops.b8_fname
else:
rgb_fname = ""
b8_fname = ""
rgb_fname, b8_fname = "", ""

json_data.append(
FormattedPrediction(
@@ -531,23 +511,62 @@ def predict_pipeline(
b8_fname=str(b8_fname),
),
)
geojson_features.append(detection.to_feature())

if json_path:
json_upath = UPath(json_path)
with json_upath.open("w") as f:
json.dump(json_data, f)
return json_data

if geojson_path:
geojson_upath = UPath(geojson_path)
with geojson_upath.open("w") as f:
json.dump(
{
"type": "FeatureCollection",
"properties": {},
"features": geojson_features,
},
f,
@dataclass
class DetectionCrop:
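    """Paths of the RGB and B8 crop PNGs written for one detection."""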
rgb_fname: UPath
b8_fname: UPath

def _write_detection_crop(
    detection: VesselDetection, crop_upath: UPath, idx: int
) -> DetectionCrop:
    """Load crops from the window directory and write output PNGs.

    We create two PNGs:
    - b8.png: just has B8 (panchromatic band).
    - rgb.png: true color with pan-sharpening. The RGB is from B4, B3, and B2
      respectively while B8 is used for pan-sharpening.
    """
images = {}
crop_window: Window = detection.metadata["crop_window"]
if crop_window is None:
raise ValueError("Crop window is None")
for band in ["B2", "B3", "B4", "B8"]:
raster_dir = crop_window.get_raster_dir(LANDSAT_LAYER_NAME, [band])

        # Different bands come at different resolutions, so take the bounds from
        # the raster itself; we want to read the whole raster anyway.
raster_format = GeotiffRasterFormat()
band_bounds = raster_format.get_raster_bounds(raster_dir)
image = raster_format.decode_raster(raster_dir, band_bounds)
if image.shape[0] != 1:
raise ValueError(
f"expected single-band image for {band} but got {image.shape[0]} bands"
)

return json_data
images[band] = image[0, :, :]

    # Apply simple pan-sharpening for the RGB: linearly scale each RGB band so
    # that the bands' mean matches B8, which is captured at a higher resolution.
for band in ["B2", "B3", "B4"]:
sharp = images[band].astype(np.int32)
sharp = sharp.repeat(repeats=2, axis=0).repeat(repeats=2, axis=1)
images[band + "_sharp"] = sharp
total = np.clip(
(images["B2_sharp"] + images["B3_sharp"] + images["B4_sharp"]) // 3, 1, 255
)
for band in ["B2", "B3", "B4"]:
images[band + "_sharp"] = np.clip(
images[band + "_sharp"] * images["B8"] // total, 0, 255
).astype(np.uint8)
rgb = np.stack(
[images["B4_sharp"], images["B3_sharp"], images["B2_sharp"]], axis=2
)

rgb_fname = crop_upath / f"{idx}_rgb.png"
with rgb_fname.open("wb") as f:
Image.fromarray(rgb).save(f, format="PNG")

b8_fname = crop_upath / f"{idx}_b8.png"
with b8_fname.open("wb") as f:
Image.fromarray(images["B8"]).save(f, format="PNG")

return DetectionCrop(rgb_fname=rgb_fname, b8_fname=b8_fname)
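
Note: the pan-sharpening in _write_detection_crop scales each 2x-upsampled RGB band by the ratio of B8 to the mean of the upsampled RGB bands, so local brightness tracks the higher-resolution panchromatic band. A toy check of the arithmetic with made-up band values (not taken from the PR):

import numpy as np

# One 30 m RGB pixel and the four 15 m B8 pixels that cover it.
bands = {
    "B2": np.array([[60]], dtype=np.int32),
    "B3": np.array([[90]], dtype=np.int32),
    "B4": np.array([[120]], dtype=np.int32),
}
b8 = np.array([[80, 100], [90, 110]], dtype=np.int32)

# Upsample each RGB band 2x to match B8's resolution.
up = {name: band.repeat(2, axis=0).repeat(2, axis=1) for name, band in bands.items()}

# Mean of the upsampled bands, clipped away from zero to avoid dividing by it.
total = np.clip((up["B2"] + up["B3"] + up["B4"]) // 3, 1, 255)

# Scale each band by B8 / mean so the sharpened bands' mean equals B8.
sharp = {
    name: np.clip(band * b8 // total, 0, 255).astype(np.uint8)
    for name, band in up.items()
}

print(sharp["B4"][0, 0])  # 120 * 80 // 90 = 106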
rslp/landsat_vessels/prom_metrics.py (new file, 27 additions)

@@ -0,0 +1,27 @@
"""
Definitions for prometheus metrics that are captured during inference to report on performance of the landsat
detection service.
"""

from enum import StrEnum

from prometheus_client import Histogram
from prometheus_client.context_managers import Timer

request_timer = Histogram(
"landsat_rslearn_timer", "Timers for inference requests", ["operation"]
)


class TimerOperations(StrEnum):
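    """Inference stages that are timed and reported to Prometheus."""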
TotalInferenceTime = "TotalInferenceTime"
SetupDataset = "SetupDataset"
MaterializeDataset = "MaterializeDataset"
RunModelPredict = "RunModelPredict"
GetVesselDetections = "GetVesselDetections"
RunClassifier = "RunClassifier"
BuildPredictionsAndCrops = "BuildPredictionsAndCrops"


def time_operation(operation: TimerOperations) -> Timer:
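    """Return a Timer that records the duration of the given operation."""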
return request_timer.labels(operation=operation.value).time()
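
Note: Histogram.labels(...).time() returns a prometheus_client Timer, which works as a context manager (or decorator) and observes the elapsed seconds when it exits; that is what the with time_operation(...) blocks above rely on. A hypothetical caller, for illustration only:

import time

from rslp.landsat_vessels.prom_metrics import TimerOperations, time_operation

def run_inference_step() -> None:
    # Elapsed wall-clock seconds land in the landsat_rslearn_timer histogram
    # under the label operation="RunModelPredict".
    with time_operation(TimerOperations.RunModelPredict):
        time.sleep(0.1)  # stand-in for the real model call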