Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 58 additions & 27 deletions predictor/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from .utils import (
clean_building_mask,
georeference_prediction_tiles,
open_images_keras,
open_images_pillow,
Expand Down Expand Up @@ -121,57 +122,83 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
# print(f"Model returns {num_classes} classes")
target_class = 1
target_preds = preds[..., target_class]
binary_masks = np.where(target_preds > confidence, 1, 0)
binary_masks = np.expand_dims(binary_masks, axis=-1)

for idx, path in enumerate(image_batch):

image_filename = Path(path).stem

# Clean the mask
cleaned_mask = clean_building_mask(
target_preds[idx],
confidence_threshold=confidence,
morph_size=3,
small_object_threshold=50,
)

# Expand dimensions for save_mask
cleaned_mask = np.expand_dims(cleaned_mask, axis=-1)

save_mask(
binary_masks[idx],
cleaned_mask,
str(f"{prediction_path}/{Path(path).stem}.png"),
)


def predict_keras(model, image_paths, prediction_path, confidence):
    """Run a Keras segmentation model over image tiles and save cleaned masks.

    Args:
        model: Loaded Keras model; assumed to output
            (batch, H, W, num_classes) — TODO confirm against training code.
        image_paths: Paths of the input image tiles.
        prediction_path: Directory where one PNG mask per tile is written.
        confidence: Threshold for the target (building) class probability.
    """
    # Ceil-divide so a final partial batch is still processed.
    for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
        image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
        images = open_images_keras(image_batch)
        images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3)
        preds = model.predict(images)
        # Debug output kept commented, consistent with predict_tflite:
        # print(f"Model returns {preds.shape[-1]} classes")

        target_class = 1  # class index 1 = building
        target_preds = preds[..., target_class]

        for idx, path in enumerate(image_batch):
            # Threshold + morphological clean-up removes noise, thin
            # connections between buildings, and tiny components.
            cleaned_mask = clean_building_mask(
                target_preds[idx],
                confidence_threshold=confidence,
                morph_size=3,
                small_object_threshold=50,
            )

            # save_mask expects a trailing channel dimension.
            cleaned_mask = np.expand_dims(cleaned_mask, axis=-1)

            save_mask(
                cleaned_mask,
                str(f"{prediction_path}/{Path(path).stem}.png"),
            )


def predict_yolo(model, image_paths, prediction_path, confidence):
    """Run a YOLO segmentation model over image tiles and save cleaned masks.

    Args:
        model: Loaded YOLO model (ultralytics-style: results carry a
            ``.masks`` attribute when instances were detected).
        image_paths: Paths of the input image tiles.
        prediction_path: Directory where one PNG mask per tile is written.
        confidence: Detection confidence threshold passed to the model and
            reused when cleaning the combined mask.
    """
    for idx in range(0, len(image_paths), BATCH_SIZE):
        batch = image_paths[idx : idx + BATCH_SIZE]
        for i, r in enumerate(
            model.predict(batch, conf=confidence, imgsz=IMAGE_SIZE, verbose=False)
        ):
            out_path = str(f"{prediction_path}/{Path(batch[i]).stem}.png")
            if hasattr(r, "masks") and r.masks is not None:
                # Merge all instance masks into a single raster (per-pixel
                # max over instances) and move it to a numpy array.
                raw_mask = r.masks.data.max(dim=0)[0].detach().cpu().numpy()

                # Morphological clean-up: drop noise and thin connections.
                cleaned_mask = clean_building_mask(
                    raw_mask,
                    confidence_threshold=confidence,
                )

                save_mask(cleaned_mask, out_path)
            else:
                # No detections: write a single all-zero mask so every input
                # tile still produces an output file.
                empty_mask = np.zeros((IMAGE_SIZE, IMAGE_SIZE), dtype=np.float32)
                save_mask(empty_mask, out_path)



def predict_onnx(model_path, image_paths, prediction_path, confidence=0.25):
import cv2
Expand All @@ -186,8 +213,12 @@ def predict_onnx(model_path, image_paths, prediction_path, confidence=0.25):
mask_path = f"{prediction_path}/{Path(image_path).stem}.png"

if len(masks) > 0:
combined_mask = masks.max(axis=0) * 255 # Combine masks and scale to 255
result = Image.fromarray(combined_mask.astype(np.uint8))

combined_mask = masks.max(axis=0)
cleaned_mask = clean_building_mask(
combined_mask,
)
result = Image.fromarray((cleaned_mask * 255).astype(np.uint8))
result.save(mask_path)
else:
preds = np.zeros(
Expand Down Expand Up @@ -236,7 +267,7 @@ def run_prediction(
predict_tflite(model, image_paths, prediction_path, confidence)

elif model_type == "keras":
predict_keras(model, image_batch, confidence)
predict_keras(model, image_paths,prediction_path, confidence)

elif model_type == "yolo":
predict_yolo(model, image_paths, prediction_path, confidence)
Expand Down
61 changes: 61 additions & 0 deletions predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import List

import cv2
import numpy as np
import requests
from geomltoolkits.utils import georeference_tile
Expand Down Expand Up @@ -52,6 +53,19 @@ def save_mask(mask: np.ndarray, filename: str) -> None:
result = Image.fromarray(reshaped_mask.astype(np.uint8))
result.save(filename)

# with rasterio.open(
# filename,
# 'w',
# driver='GTiff',
# height=IMAGE_SIZE,
# width=IMAGE_SIZE,
# count=1,
# dtype=rasterio.float32,
# nodata=0
# ) as dst:
# dst.write(reshaped_mask, 1)



def georeference_prediction_tiles(
prediction_path: str,
Expand Down Expand Up @@ -161,3 +175,50 @@ def download_or_validate_model(model_path: str) -> str:
raise FileNotFoundError(f"Model file not found: {model_path}")

return model_path

def clean_building_mask(target_preds: np.ndarray, confidence_threshold=0.5,
                        morph_size=3, small_object_threshold=50):
    """
    Clean up building masks to remove thin connections and improve precision.

    Args:
        target_preds: Raw prediction scores or an already-binary mask
            (values in the 0-1 range).
        confidence_threshold: Building/background threshold; ignored when the
            input is already binary. Clamped to [0.1, 0.95] before use.
        morph_size: Side length of the elliptical structuring element used by
            the morphological operations.
        small_object_threshold: Minimum connected-component area (pixels) to
            keep; set to 0 to disable the size filter.
    Returns:
        Cleaned binary (0/1) mask as a uint8 array.
    """
    # Treat the input as binary if it contains only 0s and 1s (i.e. a mask
    # that was already thresholded upstream).
    is_binary = np.array_equal(
        target_preds, target_preds.astype(bool).astype(target_preds.dtype)
    )

    if is_binary:
        # Skip thresholding if input is already binary
        binary_mask = target_preds.astype(np.uint8)
    else:
        # Clamp so an out-of-range threshold cannot blank or saturate the mask.
        confidence_threshold = min(max(confidence_threshold, 0.1), 0.95)
        binary_mask = (target_preds > confidence_threshold).astype(np.uint8)

    # Elliptical kernel for all morphological operations below.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_size, morph_size))

    # Opening (erode then dilate) removes thin connections between buildings.
    opened_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

    # Additional erosion further separates adjacent buildings.
    eroded_mask = cv2.erode(opened_mask, kernel, iterations=2)

    # Closing fills small holes inside the remaining buildings.
    filled_mask = cv2.morphologyEx(eroded_mask, cv2.MORPH_CLOSE, kernel)

    # Filter small objects if a threshold is provided.
    if small_object_threshold > 0:
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            filled_mask, connectivity=8
        )
        # Vectorized size filter: keep labels (background label 0 excluded)
        # whose area meets the threshold, instead of looping per label.
        areas = stats[1:, cv2.CC_STAT_AREA]
        keep_labels = np.flatnonzero(areas >= small_object_threshold) + 1
        return np.isin(labels, keep_labels).astype(filled_mask.dtype)

    return filled_mask