Skip to content

Commit 5fe1481

Browse files
committed
Update API to support predict start coordinates instead SHAPE_Length
1 parent 827d902 commit 5fe1481

File tree

1 file changed

+44
-9
lines changed
  • Use_Cases/Weather-Aware Routing/Duy Pham

1 file changed

+44
-9
lines changed
Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import pandas as pd
2+
from shapely import wkt
3+
from shapely.geometry import Point
4+
import geopandas as gpd
5+
import joblib
16
from fastapi import FastAPI
27
from pydantic import BaseModel
3-
import joblib
4-
import pandas as pd
58

69
model = joblib.load("ev_model.pkl")
710

@@ -14,9 +17,30 @@
1417
"total_prcp",
1518
]
1619

17-
class StationFeatures(BaseModel):
20+
traffic_df = pd.read_csv("Traffic data.csv")
21+
22+
23+
traffic_df["geometry"] = traffic_df["geometry"].apply(wkt.loads)
24+
25+
traffic_gdf = gpd.GeoDataFrame(traffic_df, geometry="geometry", crs="EPSG:4326")
26+
27+
def get_shape_length_from_coords(lat: float, lon: float) -> float:
28+
"""
29+
Given a lat/lon point, find the nearest road segment in traffic_gdf
30+
and return its SHAPE_Length.
31+
"""
32+
pt = Point(lon, lat)
33+
34+
distances = traffic_gdf.distance(pt)
35+
nearest_idx = distances.idxmin()
36+
37+
shape_length = traffic_gdf.loc[nearest_idx, "SHAPE_Length"]
38+
return float(shape_length)
39+
40+
class CoordRequest(BaseModel):
1841
Year: int
19-
SHAPE_Length: float
42+
start_lat: float
43+
start_lon: float
2044
dist_to_nearest_ev_m: float
2145
ev_within_500m: int
2246
avg_temp: float
@@ -28,11 +52,22 @@ class StationFeatures(BaseModel):
2852
def root():
2953
return {"message": "EV model API is running"}
3054

31-
@app.post("/predict")
32-
def predict(features: StationFeatures):
33-
data = pd.DataFrame([[getattr(features, f) for f in FEATURES]],
34-
columns=FEATURES)
55+
@app.post("/predict_from_coords")
56+
def predict_from_coords(req: CoordRequest):
57+
shape_length = get_shape_length_from_coords(req.start_lat, req.start_lon)
58+
59+
data = pd.DataFrame([[
60+
req.Year,
61+
shape_length,
62+
req.dist_to_nearest_ev_m,
63+
req.ev_within_500m,
64+
req.avg_temp,
65+
req.total_prcp,
66+
]], columns=FEATURES)
3567

3668
pred = model.predict(data)[0]
3769

38-
return {"prediction": float(pred)}
70+
return {
71+
"prediction": float(pred),
72+
"used_SHAPE_Length": shape_length
73+
}

0 commit comments

Comments
 (0)