Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions API/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@
# Configure root logging once. force=True removes and closes any handlers
# already attached to the root logger (e.g. ones installed by an imported
# library), so no manual handler-stripping or second basicConfig call is needed.
logging.basicConfig(
    level=logging.getLevelName(os.getenv("LOG_LEVEL", "INFO")),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True,
)
logger = logging.getLogger(__name__)

# Configure rate limiter
limiter = Limiter(key_func=get_remote_address)

Expand Down
31 changes: 17 additions & 14 deletions predictor/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from geomltoolkits.utils import merge_rasters, validate_polygon_geometries

from .prediction import run_prediction
from .utils import download_or_validate_model
from .utils import download_or_validate_model, morphological_cleaning


async def predict(
Expand All @@ -20,11 +20,10 @@ async def predict(
confidence=0.5,
area_threshold=3,
tolerance=0.5,
remove_metadata=True,
orthogonalize=True,
bbox=None,
geojson=None,
merge_input_images_to_single_image=False,
debug=False,
get_predictions_as_points=True,
ortho_skew_tolerance_deg=15,
ortho_max_angle_change_deg=15,
Expand All @@ -40,11 +39,10 @@ async def predict(
confidence: Threshold for filtering predictions (0-1)
area_threshold: Minimum polygon area in sqm (default: 3)
tolerance: Simplification tolerance in meters (default: 0.5)
remove_metadata: Whether to delete intermediate files after processing
orthogonalize: Whether to square building corners
bbox: Bounding box for prediction area
geojson: GeoJSON object for prediction area
merge_input_images_to_single_image: Whether to merge source images
debug: Whether to produce merged input images and keep intermediate files
get_predictions_as_points: Whether to generate point representations
ortho_skew_tolerance_deg: Max skew angle for orthogonalization (0-45)
ortho_max_angle_change_deg: Max corner adjustment angle (0-45)
Expand Down Expand Up @@ -79,10 +77,13 @@ async def predict(
crs="3857",
)

if merge_input_images_to_single_image:
merge_rasters(
image_download_path, os.path.join(meta_path, "merged_image_chips.tif")
)
if debug:
try:
merge_rasters(
image_download_path, os.path.join(meta_path, "merged_image_chips.tif")
)
except Exception as e:
print(f"Could not merge input images: {e}")

prediction_path = os.path.join(meta_path, "prediction")
os.makedirs(prediction_path, exist_ok=True)
Expand All @@ -101,16 +102,16 @@ async def predict(
os.makedirs(os.path.dirname(prediction_merged_mask_path), exist_ok=True)

merge_rasters(prediction_path, prediction_merged_mask_path)
prediction_poly_geojson_path = os.path.join(geojson_path, "predictions.geojson")
morphological_cleaning(prediction_merged_mask_path)
gdf = VectorizeMasks(
simplify_tolerance=tolerance,
min_area=area_threshold,
orthogonalize=orthogonalize,
tmp_dir=os.path.join(base_path, "tmp"),
ortho_skew_tolerance_deg=ortho_skew_tolerance_deg,
ortho_max_angle_change_deg=ortho_max_angle_change_deg,
).convert(
prediction_merged_mask_path, os.path.join(geojson_path, "predictions.geojson")
)
).convert(prediction_merged_mask_path, prediction_poly_geojson_path)
print(f"It took {round(time.time() - start)} sec to extract polygons")

if gdf.crs and gdf.crs != "EPSG:4326":
Expand All @@ -121,7 +122,7 @@ async def predict(

gdf["building"], gdf["source"] = "yes", "fAIr"

if remove_metadata:
if not debug:
shutil.rmtree(meta_path)

if get_predictions_as_points:
Expand All @@ -138,7 +139,9 @@ async def predict(

prediction_geojson_data = json.loads(gdf.to_json())
if make_geoms_valid:
prediction_geojson_data = validate_polygon_geometries(prediction_geojson_data, output_path=output_path if output_path else None)
prediction_geojson_data = validate_polygon_geometries(
prediction_geojson_data, output_path=prediction_poly_geojson_path
)

if not output_path:
shutil.rmtree(base_path)
Expand Down
4 changes: 0 additions & 4 deletions predictor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ class PredictionRequest(BaseModel):
default=False,
description="Whether to include predictions as points, this will create output geojson with extra points predictions",
)
remove_metadata: Optional[bool] = Field(
default=True,
description="Whether to remove intermediate metadata files after processing",
)
output_path: Optional[str] = Field(
default=None, description="Path to save the output files"
)
Expand Down
2 changes: 0 additions & 2 deletions predictor/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
cleaned_mask = clean_building_mask(
target_preds[idx],
confidence_threshold=confidence,
morph_size=3,
)

# Expand dimensions for save_mask
Expand All @@ -154,7 +153,6 @@ def predict_keras(model, image_paths, prediction_path, confidence):
cleaned_mask = clean_building_mask(
target_preds[idx],
confidence_threshold=confidence,
morph_size=3,
)

# Expand dimensions for save_mask
Expand Down
41 changes: 24 additions & 17 deletions predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import cv2
import numpy as np
import rasterio
import requests
from geomltoolkits.utils import georeference_tile
from PIL import Image
from skimage.segmentation import clear_border

IMAGE_SIZE = 256

Expand Down Expand Up @@ -179,15 +181,13 @@ def download_or_validate_model(model_path: str) -> str:
def clean_building_mask(
target_preds: np.ndarray,
confidence_threshold=0.5,
morph_size=3,
):
"""
Clean up building masks to remove thin connections and improve precision.

Args:
target_preds: Raw prediction or binary mask (0-1 range)
confidence_threshold: Base threshold for building/non-building (ignored if input is already binary)
morph_size: Size of morphological operation kernel
Returns:
Cleaned binary mask
"""
Expand All @@ -198,20 +198,27 @@ def clean_building_mask(
if is_binary:
print("Input is already binary, skipping confidence thresholding.")
binary_mask = target_preds.astype(np.uint8)
else:
binary_mask = np.where(target_preds > confidence_threshold, 1, 0).astype(
np.uint8
)

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_size, morph_size))

# Apply opening to remove thin connections (erode then dilate)
opened_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

# Final erosion to amplify differences between buildings
eroded_mask = cv2.erode(opened_mask, kernel, iterations=1)
return binary_mask
binary_mask = np.where(target_preds > confidence_threshold, 1, 0).astype(np.uint8)

return binary_mask


def morphological_cleaning(prediction_merged_mask_path):
    """Clean a merged prediction mask raster in place with morphological ops.

    Reads band 1 of the raster, applies a morphological opening to remove
    small speckles and thin connections, drops connected components that
    touch the image border, and overwrites the source file with the result.

    Args:
        prediction_merged_mask_path: Path to a single-band mask raster.
            The file is overwritten with the cleaned mask.

    Returns:
        The same path, for convenient chaining.
    """
    with rasterio.open(prediction_merged_mask_path) as src:
        img = src.read(1)
        profile = src.profile.copy()

    # Opening (erosion followed by dilation) with a 3x3 rectangular kernel,
    # applied twice, removes small noise while preserving larger shapes.
    opening = cv2.morphologyEx(
        img,
        cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
        iterations=2,
    )
    # Remove connected components touching the raster border — presumably
    # truncated/partial objects at tile edges (NOTE(review): confirm intent).
    clean_img = clear_border(opening)

    # Overwrite the input raster with the cleaned mask, keeping its profile.
    with rasterio.open(prediction_merged_mask_path, "w", **profile) as dst:
        dst.write(clean_img, 1)

    return prediction_merged_mask_path
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"opencv-python-headless>=4.10.0.84",
"pillow>=9.1.0",
"requests>=2.32.3",
"scikit-image>=0.24.0",
"tqdm>=4.67.0",
]

Expand All @@ -33,6 +34,11 @@ dev = [
load-test = [
"locust>=2.25.0",
]
notebook = [
"contextily>=1.6.2",
"matplotlib>=3.9.4",
]


[tool.commitizen]
name = "cz_conventional_commits"
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
tms_url=DEFAULT_OAM_TMS_MOSAIC,
zoom_level=20,
orthogonalize=True,
remove_metadata=False,
debug=False,
confidence=0.5,
)
)
Loading