Skip to content

Commit e303b8d

Browse files
committed
sweep clustering
1 parent d96429d commit e303b8d

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script for KMeans clustering with real database locations.
4+
Run with: python -m app.services.implementations.k_means_test
5+
"""
6+
7+
import os
8+
import sys
9+
10+
import matplotlib.pyplot as plt
11+
import pandas as pd # Often useful for data handling
12+
import seaborn as sns
13+
14+
from app.utilities.geocoding import geocode
15+
16+
sys.path.insert(0, "/app")
17+
18+
from app.models.admin import Admin
19+
from sqlmodel import Session, create_engine, func, select
20+
21+
# Import all models to register them with SQLModel
22+
from app.models.location import Location
23+
from app.models.location_group import LocationGroup # noqa: F401
24+
from app.models.route import Route # noqa: F401
25+
from app.models.route_group import RouteGroup # noqa: F401
26+
from app.models.route_group_membership import RouteGroupMembership # noqa: F401
27+
from app.models.route_stop import RouteStop # noqa: F401
28+
from app.services.implementations.sweep_clustering import (
29+
SweepClusteringAlgorithm,
30+
)
31+
32+
# Use the same connection string as seed_database.py
33+
DATABASE_URL = "postgresql://postgres:postgres@f4k_db:5432/f4k"
34+
35+
36+
def main() -> None:
37+
engine = create_engine(DATABASE_URL, echo=False)
38+
39+
with Session(engine) as session:
40+
# Fetch locations that have coordinates
41+
statement = (
42+
select(Location)
43+
.where(Location.latitude is not None, Location.longitude is not None)
44+
.order_by(func.random())
45+
.limit(20)
46+
)
47+
48+
locations = list(session.exec(statement).all())
49+
50+
print(f"Fetched {len(locations)} locations from database\n")
51+
52+
if len(locations) < 2:
53+
print("Not enough locations with coordinates to cluster!")
54+
return
55+
56+
# fetch warehouse location
57+
statement = (
58+
select(Admin)
59+
)
60+
61+
admin = session.exec(statement).all()
62+
warehouse_address_string = admin.warehouse_location
63+
64+
warehouse_address_coordinates = geocode(warehouse_address_string)
65+
longitude, latitude = warehouse_address_coordinates["lng"], warehouse_address_coordinates["lat"]
66+
# Count total number of boxes
67+
total_boxes = 0
68+
69+
# Print the locations
70+
print("Locations to cluster:")
71+
print("-" * 60)
72+
for loc in locations:
73+
name = loc.school_name or loc.contact_name
74+
print(f" {name}")
75+
print(f" Address: {loc.address}")
76+
print(f" Coords: ({loc.latitude}, {loc.longitude})")
77+
print(f" Boxes: {loc.num_boxes}")
78+
print()
79+
total_boxes = sum(loc.num_boxes for loc in locations)
80+
81+
print("Total number of boxes: ", total_boxes)
82+
print("Total locations: ", len(locations))
83+
84+
# Run clustering
85+
clustering_algo = SweepClusteringAlgorithm()
86+
num_clusters = 9
87+
max_locations_per_cluster = 10
88+
max_boxes_per_cluster = None
89+
90+
print("Running K-Means clustering:")
91+
print(f" - Number of clusters: {num_clusters}")
92+
print(f" - Max locations per cluster: {max_locations_per_cluster}")
93+
print(f" - Max boxes per cluster: {max_boxes_per_cluster}")
94+
print("-" * 60)
95+
96+
try:
97+
clusters = clustering_algo.cluster_locations(
98+
locations=locations,
99+
num_clusters=num_clusters,
100+
max_locations_per_cluster=max_locations_per_cluster,
101+
max_boxes_per_cluster=max_boxes_per_cluster,
102+
timeout_seconds=30.0,
103+
)
104+
105+
# Print results
106+
print("\nClustering Results:")
107+
print("=" * 60)
108+
109+
df_rows = []
110+
for i, cluster in enumerate(clusters):
111+
print(f"\nCluster {i + 1} ({len(cluster)} locations):")
112+
print("-" * 40)
113+
114+
if not cluster:
115+
print(" (empty cluster)")
116+
continue
117+
118+
total_boxes = 0
119+
for loc in cluster:
120+
name = loc.school_name or loc.contact_name
121+
print(f" • {name}")
122+
print(f" {loc.address}")
123+
print(f" Coords: ({loc.latitude}, {loc.longitude})")
124+
print(f" Boxes: {loc.num_boxes}")
125+
total_boxes += loc.num_boxes
126+
new_row = {
127+
"name": name,
128+
"long": loc.longitude,
129+
"lat": loc.latitude,
130+
"group": i,
131+
}
132+
df_rows.append(new_row)
133+
df = pd.DataFrame(data=df_rows)
134+
sns.scatterplot(data=df, x="long", y="lat", hue="group", palette="Set2")
135+
plt.title(
136+
f"Generated Sweep Clustering classification for {len(locations)} locations with {len(clusters)} clusters"
137+
)
138+
plt.xlabel("Longitude")
139+
plt.ylabel("Latitude")
140+
output_dir = "./app/data"
141+
if not os.path.exists(output_dir):
142+
os.makedirs(output_dir)
143+
filename = os.path.join(output_dir, "sweep_clustering_test.png")
144+
plt.savefig(filename, dpi=300, bbox_inches="tight")
145+
146+
print(f"\n Total boxes in cluster: {total_boxes}")
147+
print("\n" + "=" * 60)
148+
print("Summary:")
149+
print(f" Total clusters: {len(clusters)}")
150+
print(
151+
f" Number of locations in each cluster: {[len(c) for c in clusters]}"
152+
)
153+
print(f" Total locations clustered: {sum(len(c) for c in clusters)}")
154+
155+
except ValueError as e:
156+
print(f"Clustering failed: {e}")
157+
except Exception as e:
158+
print(f"Unexpected error: {e}")
159+
import traceback
160+
161+
traceback.print_exc()
162+
163+
164+
if __name__ == "__main__":
165+
main()
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from app.services.protocols.clustering_algorithm import (
6+
ClusteringAlgorithmProtocol,
7+
)
8+
9+
import math
10+
import time
11+
12+
if TYPE_CHECKING:
13+
from app.models.location import Location
14+
15+
class LocationLatitudeError(Exception):
16+
"""Raised when a location doesn't have a latitude."""
17+
18+
pass
19+
20+
21+
class LocationLongitudeError(Exception):
22+
"""Raised when a location doesn't have a longitude."""
23+
24+
pass
25+
26+
class TimeoutError(Exception):
27+
"""Raised when an operation exceeds its timeout limit."""
28+
29+
pass
30+
31+
class SweepClusteringAlgorithm(ClusteringAlgorithmProtocol):
32+
"""Simple mock clustering algorithm that splits locations into clusters.
33+
34+
This is a pure function with no database interaction. It distributes
35+
locations across clusters while respecting max_locations_per_cluster and
36+
max_boxes_per_cluster constraints.
37+
"""
38+
39+
async def cluster_locations(
40+
self,
41+
locations: list[Location],
42+
num_clusters: int,
43+
warehouse_lat: float,
44+
warehouse_lon: float,
45+
max_locations_per_cluster: int | None = None,
46+
max_boxes_per_cluster: int | None = None,
47+
timeout_seconds: float | None = None,
48+
) -> list[list[Location]]:
49+
"""Split locations into clusters while respecting box constraints.
50+
51+
Args:
52+
locations: List of locations to cluster
53+
num_clusters: Target number of clusters to create
54+
max_locations_per_cluster: Optional maximum number of locations
55+
per cluster. If provided, validates that the clustering is
56+
possible and raises an error if violated.
57+
max_boxes_per_cluster: Optional maximum number of boxes per cluster.
58+
If provided, validates that the clustering is possible and
59+
raises an error if violated.
60+
timeout_seconds: Optional timeout in seconds. Not enforced in this
61+
mock implementation.
62+
63+
Returns:
64+
List of clusters, where each cluster is a list of locations
65+
66+
Raises:
67+
ValueError: If the clustering parameters are invalid or cannot
68+
be satisfied
69+
"""
70+
71+
start_time = time.time()
72+
73+
def check_timeout() -> None:
74+
if timeout_seconds is not None:
75+
elapsed = time.time() - start_time
76+
if elapsed > timeout_seconds:
77+
raise TimeoutError(
78+
f"Route generation exceeded timeout of {timeout_seconds}s "
79+
f"(elapsed: {elapsed:.2f}s)"
80+
)
81+
82+
def calculate_angle_from_warehouse(location: Location) -> float | None:
83+
if location.latitude is None:
84+
raise LocationLatitudeError(
85+
f"Location {location.location_id} is missing latitude."
86+
)
87+
if location.longitude is None:
88+
raise LocationLongitudeError(
89+
f"Location {location.location_id} is missing longitude."
90+
)
91+
lat_difference = location.latitude - warehouse_lat
92+
lon_difference = location.longitude - warehouse_lon
93+
return math.atan2(lat_difference, lon_difference) % math.tau
94+
95+
def calculate_distance_squared(location: Location) -> float | None:
96+
if location.latitude is None:
97+
raise LocationLatitudeError(
98+
f"Location {location.location_id} is missing latitude."
99+
)
100+
if location.longitude is None:
101+
raise LocationLongitudeError(
102+
f"Location {location.location_id} is missing longitude."
103+
)
104+
lat_difference = location.latitude - warehouse_lat
105+
lon_difference = location.longitude - warehouse_lon
106+
return lon_difference**2 + lat_difference**2
107+
if len(locations) == 0:
108+
raise ValueError("locations list cannot be empty")
109+
110+
if num_clusters < 1:
111+
raise ValueError("num_clusters must be at least 1")
112+
113+
# Calculate base cluster size and validate constraints
114+
total_locations = len(locations)
115+
base_cluster_size = total_locations // num_clusters
116+
remainder = total_locations % num_clusters
117+
118+
if base_cluster_size == 0:
119+
raise ValueError(
120+
f"Cannot create {num_clusters} clusters: not enough locations"
121+
)
122+
123+
# The largest cluster will have base_cluster_size + 1 if remainder > 0
124+
max_cluster_size = base_cluster_size + (1 if remainder > 0 else 0)
125+
if max_locations_per_cluster and max_cluster_size > max_locations_per_cluster:
126+
raise ValueError(
127+
f"Cannot create {num_clusters} clusters with max "
128+
f"{max_locations_per_cluster} locations per cluster. "
129+
f"Required cluster size would be up to {max_cluster_size}."
130+
)
131+
132+
# Distribute locations while respecting constraints
133+
clusters: list[list[Location]] = []
134+
current_location_count = 0
135+
current_box_count = 0
136+
current_cluster = []
137+
138+
locations_with_angles = []
139+
for location in locations:
140+
check_timeout()
141+
angle = calculate_angle_from_warehouse(location)
142+
distance_squared = calculate_distance_squared(location)
143+
locations_with_angles.append((location, angle, distance_squared))
144+
145+
sorted_locations = sorted(locations_with_angles, key=lambda location: (location[1], location[2]))
146+
147+
148+
for loc, angle, distance in sorted_locations:
149+
check_timeout()
150+
would_exceed_locations = current_location_count + 1 > max_locations_per_cluster
151+
would_exceed_boxes = current_box_count + loc.num_boxes > max_boxes_per_cluster
152+
153+
if current_cluster and (would_exceed_locations or would_exceed_boxes):
154+
clusters.append(current_cluster)
155+
current_cluster = []
156+
current_location_count = 0
157+
current_box_count = 0
158+
159+
current_cluster.append(loc)
160+
current_location_count += 1
161+
current_box_count += loc.num_boxes
162+
163+
if current_cluster:
164+
clusters.append(current_cluster)
165+
166+
return clusters

0 commit comments

Comments
 (0)