Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions backend/python/app/services/implementations/sweep_algorithm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
"""
Test script for Sweep clustering with real database locations.

Run from backend/python (or from repo root with PYTHONPATH=backend/python):
python -m app.services.implementations.sweep_algorithm_test

Or run this file directly (from any directory):
python backend/python/app/services/implementations/sweep_algorithm_test.py
"""

import os
import sys

# Ensure the backend root is on path so "app" is importable (works when run as script or -m)
_backend_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _backend_root not in sys.path:
sys.path.insert(0, _backend_root)

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sqlmodel import Session, create_engine, func, select

from app.models.location import Location
from app.models.location_group import LocationGroup # noqa: F401
from app.models.route import Route # noqa: F401
from app.models.route_group import RouteGroup # noqa: F401
from app.models.route_group_membership import RouteGroupMembership # noqa: F401
from app.models.route_stop import RouteStop # noqa: F401
from app.models.system_settings import SystemSettings
from app.services.implementations.sweep_clustering import (
SweepClusteringAlgorithm,
)
from app.utilities.geocoding import geocode

# Use the same connection string as seed_database.py
DATABASE_URL = "postgresql://postgres:postgres@f4k_db:5432/f4k"

# Configure number of locations pulled from csv for testing
LOCATIONS_COUNT = 18

# Sweep clustering uses these in the loop; both must be set (use high box limit to avoid splitting by boxes).
NUM_CLUSTERS = 10
MAX_LOCATIONS_PER_CLUSTER = 10
MAX_BOXES_PER_CLUSTER = 9999


async def main() -> None:
engine = create_engine(DATABASE_URL, echo=False)

with Session(engine) as session:
# Fetch locations that have coordinates
statement = (
select(Location)
.where(Location.latitude is not None, Location.longitude is not None)
.order_by(func.random())
.limit(LOCATIONS_COUNT)
)

locations = list(session.exec(statement).all())

print(f"Fetched {len(locations)} locations from database\n")

if len(locations) < 2:
print("Not enough locations with coordinates to cluster!")
return

# Warehouse coordinates: from SystemSettings (lat/lon or geocode address) or centroid fallback
warehouse_lat: float
warehouse_lon: float
system_settings = session.exec(select(SystemSettings).limit(1)).first()
if (
system_settings
and system_settings.warehouse_latitude is not None
and system_settings.warehouse_longitude is not None
):
warehouse_lat = system_settings.warehouse_latitude
warehouse_lon = system_settings.warehouse_longitude
print(f"Using warehouse from system settings: ({warehouse_lat}, {warehouse_lon})\n")
elif system_settings and system_settings.warehouse_location:
coords = await geocode(system_settings.warehouse_location)
if coords is not None:
warehouse_lat = coords["lat"]
warehouse_lon = coords["lng"]
print(f"Geocoded warehouse: ({warehouse_lat}, {warehouse_lon})\n")
else:
warehouse_lat = sum(loc.latitude for loc in locations) / len(locations)
warehouse_lon = sum(loc.longitude for loc in locations) / len(locations)
print(f"Geocode failed; using centroid: ({warehouse_lat}, {warehouse_lon})\n")
else:
warehouse_lat = sum(loc.latitude for loc in locations) / len(locations)
warehouse_lon = sum(loc.longitude for loc in locations) / len(locations)
print(f"No warehouse in system settings; using centroid: ({warehouse_lat}, {warehouse_lon})\n")

total_boxes = sum(loc.num_boxes for loc in locations)

print("Locations to cluster:")
print("-" * 60)
for loc in locations:
name = loc.school_name or loc.contact_name
print(f" {name}")
print(f" Address: {loc.address}")
print(f" Coords: ({loc.latitude}, {loc.longitude})")
print(f" Boxes: {loc.num_boxes}")
print()

print("Total number of boxes: ", total_boxes)
print("Total locations: ", len(locations))

clustering_algo = SweepClusteringAlgorithm()

print("Running Sweep clustering:")
print(f" - Number of clusters: {NUM_CLUSTERS}")
print(f" - Max locations per cluster: {MAX_LOCATIONS_PER_CLUSTER}")
print(f" - Max boxes per cluster: {MAX_BOXES_PER_CLUSTER}")
print("-" * 60)

try:
clusters = await clustering_algo.cluster_locations(
locations=locations,
num_clusters=NUM_CLUSTERS,
warehouse_lat=warehouse_lat,
warehouse_lon=warehouse_lon,
max_locations_per_cluster=MAX_LOCATIONS_PER_CLUSTER,
max_boxes_per_cluster=MAX_BOXES_PER_CLUSTER,
timeout_seconds=30.0,
)

print("\nClustering Results:")
print("=" * 60)

df_rows = []
for i, cluster in enumerate(clusters):
print(f"\nCluster {i + 1} ({len(cluster)} locations):")
print("-" * 40)

if not cluster:
print(" (empty cluster)")
continue

cluster_boxes = 0
for loc in cluster:
name = loc.school_name or loc.contact_name
print(f" • {name}")
print(f" {loc.address}")
print(f" Coords: ({loc.latitude}, {loc.longitude})")
print(f" Boxes: {loc.num_boxes}")
cluster_boxes += loc.num_boxes
df_rows.append(
{"name": name, "long": loc.longitude, "lat": loc.latitude, "group": i}
)

print(f"\n Total boxes in cluster: {cluster_boxes}")

if df_rows:
df = pd.DataFrame(data=df_rows)
sns.scatterplot(data=df, x="long", y="lat", hue="group", palette="Set2")
plt.title(
f"Generated Sweep clustering for {len(locations)} locations with {len(clusters)} clusters"
)
plt.xlabel("Longitude")
plt.ylabel("Latitude")
output_dir = "./app/data"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
plt.savefig(
os.path.join(output_dir, "sweep_clustering_test.png"),
dpi=300,
bbox_inches="tight",
)

print("\n" + "=" * 60)
print("Summary:")
print(f" Total clusters: {len(clusters)}")
print(f" Number of locations in each cluster: {[len(c) for c in clusters]}")
print(f" Total locations clustered: {sum(len(c) for c in clusters)}")

except ValueError as e:
print(f"Clustering failed: {e}")
except Exception as e:
print(f"Unexpected error: {e}")
import traceback

traceback.print_exc()


if __name__ == "__main__":
import asyncio

asyncio.run(main())
166 changes: 166 additions & 0 deletions backend/python/app/services/implementations/sweep_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from app.services.protocols.clustering_algorithm import (
ClusteringAlgorithmProtocol,
)

import math
import time

if TYPE_CHECKING:
from app.models.location import Location

class LocationLatitudeError(Exception):
"""Raised when a location doesn't have a latitude."""

pass


class LocationLongitudeError(Exception):
"""Raised when a location doesn't have a longitude."""

pass

class TimeoutError(Exception):
"""Raised when an operation exceeds its timeout limit."""

pass

class SweepClusteringAlgorithm(ClusteringAlgorithmProtocol):
"""Simple mock clustering algorithm that splits locations into clusters.

This is a pure function with no database interaction. It distributes
locations across clusters while respecting max_locations_per_cluster and
max_boxes_per_cluster constraints.
"""

async def cluster_locations(
self,
locations: list[Location],
num_clusters: int,
warehouse_lat: float,
warehouse_lon: float,
max_locations_per_cluster: int | None = None,
max_boxes_per_cluster: int | None = None,
timeout_seconds: float | None = None,
) -> list[list[Location]]:
"""Split locations into clusters while respecting box constraints.

Args:
locations: List of locations to cluster
num_clusters: Target number of clusters to create
max_locations_per_cluster: Optional maximum number of locations
per cluster. If provided, validates that the clustering is
possible and raises an error if violated.
max_boxes_per_cluster: Optional maximum number of boxes per cluster.
If provided, validates that the clustering is possible and
raises an error if violated.
timeout_seconds: Optional timeout in seconds. Not enforced in this
mock implementation.

Returns:
List of clusters, where each cluster is a list of locations

Raises:
ValueError: If the clustering parameters are invalid or cannot
be satisfied
"""

start_time = time.time()

def check_timeout() -> None:
if timeout_seconds is not None:
elapsed = time.time() - start_time
if elapsed > timeout_seconds:
raise TimeoutError(
f"Route generation exceeded timeout of {timeout_seconds}s "
f"(elapsed: {elapsed:.2f}s)"
)

def calculate_angle_from_warehouse(location: Location) -> float | None:
if location.latitude is None:
raise LocationLatitudeError(
f"Location {location.location_id} is missing latitude."
)
if location.longitude is None:
raise LocationLongitudeError(
f"Location {location.location_id} is missing longitude."
)
lat_difference = location.latitude - warehouse_lat
lon_difference = location.longitude - warehouse_lon
return math.atan2(lat_difference, lon_difference) % math.tau

def calculate_distance_squared(location: Location) -> float | None:
if location.latitude is None:
raise LocationLatitudeError(
f"Location {location.location_id} is missing latitude."
)
if location.longitude is None:
raise LocationLongitudeError(
f"Location {location.location_id} is missing longitude."
)
lat_difference = location.latitude - warehouse_lat
lon_difference = location.longitude - warehouse_lon
return lon_difference**2 + lat_difference**2
if len(locations) == 0:
raise ValueError("locations list cannot be empty")

if num_clusters < 1:
raise ValueError("num_clusters must be at least 1")

# Calculate base cluster size and validate constraints
total_locations = len(locations)
base_cluster_size = total_locations // num_clusters
remainder = total_locations % num_clusters

if base_cluster_size == 0:
raise ValueError(
f"Cannot create {num_clusters} clusters: not enough locations"
)

# The largest cluster will have base_cluster_size + 1 if remainder > 0
max_cluster_size = base_cluster_size + (1 if remainder > 0 else 0)
if max_locations_per_cluster and max_cluster_size > max_locations_per_cluster:
raise ValueError(
f"Cannot create {num_clusters} clusters with max "
f"{max_locations_per_cluster} locations per cluster. "
f"Required cluster size would be up to {max_cluster_size}."
)

# Distribute locations while respecting constraints
clusters: list[list[Location]] = []
current_location_count = 0
current_box_count = 0
current_cluster = []

locations_with_angles = []
for location in locations:
check_timeout()
angle = calculate_angle_from_warehouse(location)
distance_squared = calculate_distance_squared(location)
locations_with_angles.append((location, angle, distance_squared))

sorted_locations = sorted(locations_with_angles, key=lambda location: (location[1], location[2]))


for loc, angle, distance in sorted_locations:
check_timeout()
would_exceed_locations = current_location_count + 1 > max_locations_per_cluster
would_exceed_boxes = current_box_count + loc.num_boxes > max_boxes_per_cluster

if current_cluster and (would_exceed_locations or would_exceed_boxes):
clusters.append(current_cluster)
current_cluster = []
current_location_count = 0
current_box_count = 0

current_cluster.append(loc)
current_location_count += 1
current_box_count += loc.num_boxes

if current_cluster:
clusters.append(current_cluster)

return clusters
Loading