Skip to content

Commit 1ea5cf9

Browse files
committed
clean and speed up nwa
1 parent 69f1e93 commit 1ea5cf9

File tree

10 files changed

+19
-50
lines changed

10 files changed

+19
-50
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "ssb-sgis"
3-
version = "1.3.7"
3+
version = "1.3.8"
44
description = "GIS functions used at Statistics Norway."
55
authors = ["Morten Letnes <morten.letnes@ssb.no>"]
66
license = "MIT"

src/sgis/geopandas_tools/runners.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,6 @@ def run(
239239
left, right = results
240240
return left, right
241241
return results
242-
left = np.concatenate([x[0] for x in results])
243-
right = np.concatenate([x[1] for x in results])
244-
return left, right
245242
elif (
246243
(self.n_jobs or 1) > 1
247244
and len(arr2) / self.n_jobs > 10_000
@@ -264,9 +261,6 @@ def run(
264261
left, right = results
265262
return left, right
266263
return results
267-
left = np.concatenate([x[0] for x in results])
268-
right = np.concatenate([x[1] for x in results])
269-
return left, right
270264

271265
return _strtree_query(arr1, arr2, method=method, **kwargs)
272266

src/sgis/networkanalysis/_od_cost_matrix.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@ def _od_cost_matrix(
2626
# calculating all-to-all distances is much faster than looping rowwise,
2727
# so filtering to rowwise afterwards instead
2828
if rowwise:
29-
rowwise_df = DataFrame(
30-
{
31-
"origin": origins.index,
32-
"destination": destinations.index,
33-
}
29+
keys = pd.MultiIndex.from_arrays(
30+
[origins.index, destinations.index],
31+
names=["origin", "destination"],
3432
)
35-
results = rowwise_df.merge(results, on=["origin", "destination"], how="left")
33+
results = results.set_index(["origin", "destination"]).loc[keys].reset_index()
3634

3735
results["geom_ori"] = results["origin"].map(origins.geometry)
3836
results["geom_des"] = results["destination"].map(destinations.geometry)

src/sgis/networkanalysis/_points.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616

1717

1818
class Points:
19-
def __init__(
20-
self,
21-
points: GeoDataFrame,
22-
) -> None:
19+
def __init__(self, points: GeoDataFrame) -> None:
2320
self.gdf = points.copy()
2421

2522
def _make_temp_idx(self, start: int) -> None:

src/sgis/networkanalysis/closing_network_holes.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def close_network_holes(
7979
gdf: GeoDataFrame,
8080
max_distance: int | float,
8181
max_angle: int,
82+
*,
8283
hole_col: str | None = "hole",
8384
) -> GeoDataFrame:
8485
"""Fills network gaps with straigt lines.
@@ -282,11 +283,13 @@ def _close_holes_all_lines(
282283
) -> GeoSeries:
283284
k = min(len(nodes), 50)
284285

286+
n_dict = nodes.set_index("wkt")["n"]
287+
285288
# make points for the deadends and the other endpoint of the deadend lines
286-
deadends_target = lines.loc[lines["n_target"] == 1].rename(
289+
deadends_target = lines.loc[lines["target_wkt"].map(n_dict) == 1].rename(
287290
columns={"target_wkt": "wkt", "source_wkt": "wkt_other_end"}
288291
)
289-
deadends_source = lines.loc[lines["n_source"] == 1].rename(
292+
deadends_source = lines.loc[lines["source_wkt"].map(n_dict) == 1].rename(
290293
columns={"source_wkt": "wkt", "target_wkt": "wkt_other_end"}
291294
)
292295
deadends = pd.concat([deadends_source, deadends_target], ignore_index=True)
@@ -349,12 +352,6 @@ def get_angle_difference(angle1, angle2):
349352
to_idx = indices[condition]
350353
to_wkt = nodes.iloc[to_idx]["wkt"]
351354

352-
# all_angles = all_angles + [
353-
# diff
354-
# for f, diff in zip(from_wkt, angles_difference[condition], strict=True)
355-
# if f not in new_sources
356-
# ]
357-
358355
# now add the wkts to the lists of new sources and targets. If the source
359356
# is already added, the new wks will not be added again
360357
new_targets = new_targets + [

src/sgis/networkanalysis/finding_isolated_networks.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ def get_connected_components(gdf: GeoDataFrame) -> GeoDataFrame:
5757

5858
gdf["connected"] = gdf.source.map(largest_component_dict).fillna(0)
5959

60-
gdf = gdf.drop(
61-
["source_wkt", "target_wkt", "source", "target", "n_source", "n_target"], axis=1
62-
)
60+
gdf = gdf.drop(["source_wkt", "target_wkt", "source", "target"], axis=1)
6361

6462
return gdf
6563

@@ -120,8 +118,6 @@ def get_component_size(gdf: GeoDataFrame) -> GeoDataFrame:
120118
gdf["component_index"] = gdf["source"].map(mapper["component_index"])
121119
gdf["component_size"] = gdf["source"].map(mapper["component_size"])
122120

123-
gdf = gdf.drop(
124-
["source_wkt", "target_wkt", "source", "target", "n_source", "n_target"], axis=1
125-
)
121+
gdf = gdf.drop(["source_wkt", "target_wkt", "source", "target"], axis=1)
126122

127123
return gdf

src/sgis/networkanalysis/network.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __init__(self, gdf: GeoDataFrame) -> None:
4141
self.gdf = self._prepare_network(gdf)
4242

4343
self._make_node_ids()
44-
self._percent_bidirectional = self._check_percent_bidirectional()
4544

4645
def _make_node_ids(self) -> None:
4746
"""Gives the lines node ids and return lines (edges) and nodes.

src/sgis/networkanalysis/networkanalysis.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,7 +1413,6 @@ def _get_edges_and_weights(
14131413
"""
14141414
if self.rules.split_lines:
14151415
self._split_lines()
1416-
# self.network._make_node_ids()
14171416
self.origins._make_temp_idx(
14181417
start=max(self.network.nodes.node_id.astype(int)) + 1
14191418
)
@@ -1428,6 +1427,7 @@ def _get_edges_and_weights(
14281427

14291428
self.network.gdf["src_tgt_wt"] = self.network._create_edge_ids(edges, weights)
14301429

1430+
# add edges between origins+destinations to the network nodes
14311431
edges_start, weights_start = self.origins._get_edges_and_weights(
14321432
nodes=self.network.nodes,
14331433
rules=self.rules,
@@ -1587,7 +1587,7 @@ def _graph_is_up_to_date(self) -> bool:
15871587
for points in ["origins", "destinations"]:
15881588
if self[points] is None:
15891589
continue
1590-
if points not in self.wkts:
1590+
if not hasattr(self, points) or self[points] is None:
15911591
return False
15921592
if self._points_have_changed(self[points].gdf, what=points):
15931593
return False
@@ -1617,8 +1617,6 @@ def _update_wkts(self) -> None:
16171617
"""
16181618
self.wkts = {}
16191619

1620-
self.wkts["network"] = self.network.gdf.geometry.to_wkt().values
1621-
16221620
if not hasattr(self, "origins"):
16231621
return
16241622

src/sgis/networkanalysis/nodes.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,14 @@ def make_node_ids(
4747
gdf = make_edge_coords_cols(gdf)
4848
geomcol1, geomcol2, geomcol_final = "source_coords", "target_coords", "coords"
4949

50-
# remove identical lines in opposite directions
50+
# remove identical lines in opposite directions in order to get n==1 for deadends
5151
gdf["meters_"] = gdf.length.astype(str)
52-
5352
sources = gdf[[geomcol1, geomcol2, "meters_"]].rename(
5453
columns={geomcol1: geomcol_final, geomcol2: "temp"}
5554
)
5655
targets = gdf[[geomcol1, geomcol2, "meters_"]].rename(
5756
columns={geomcol2: geomcol_final, geomcol1: "temp"}
5857
)
59-
6058
nodes = (
6159
pd.concat([sources, targets], axis=0, ignore_index=True)
6260
.drop_duplicates([geomcol_final, "temp", "meters_"])
@@ -66,18 +64,12 @@ def make_node_ids(
6664
gdf = gdf.drop("meters_", axis=1)
6765

6866
nodes["n"] = nodes.assign(n=1).groupby(geomcol_final)["n"].transform("sum")
69-
7067
nodes = nodes.drop_duplicates(subset=[geomcol_final]).reset_index(drop=True)
71-
7268
nodes["node_id"] = nodes.index
7369
nodes["node_id"] = nodes["node_id"].astype(str)
7470

7571
gdf = _map_node_ids_from_wkt(gdf, nodes, wkt=wkt)
7672

77-
n_dict = {geom: n for geom, n in zip(nodes[geomcol_final], nodes["n"], strict=True)}
78-
gdf["n_source"] = gdf[geomcol1].map(n_dict)
79-
gdf["n_target"] = gdf[geomcol2].map(n_dict)
80-
8173
if wkt:
8274
nodes["geometry"] = gpd.GeoSeries.from_wkt(nodes[geomcol_final], crs=gdf.crs)
8375
else:

tests/test_network_analysis.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ def not_test_get_k_routes(nwa, p):
273273

274274
def not_test_direction(roads_oslo):
275275
"""Check that a route that should go in separate tunnels, goes in correct tunnels."""
276+
m = 5
277+
276278
vippetangen = sg.to_gdf([10.741527, 59.9040595], crs=4326).to_crs(roads_oslo.crs)
277279
ryen = sg.to_gdf([10.8047522, 59.8949826], crs=4326).to_crs(roads_oslo.crs)
278280

@@ -283,25 +285,21 @@ def not_test_direction(roads_oslo):
283285
tunnel_tofrom = sg.to_gdf([10.7724645, 59.899908], crs=4326).to_crs(roads_oslo.crs)
284286

285287
clipped = sg.clean_clip(roads_oslo, tunnel_fromto.buffer(2000))
288+
286289
connected_roads = sg.get_connected_components(clipped).query("connected == 1")
287290
directed_roads = sg.make_directed_network_norway(connected_roads, dropnegative=True)
288291
rules = sg.NetworkAnalysisRules(directed=True, weight="minutes")
289292
nwa = sg.NetworkAnalysis(directed_roads, rules=rules)
290-
291293
route_fromto = nwa.get_route(vippetangen, ryen)
292294
route_tofrom = nwa.get_route(ryen, vippetangen)
293-
294-
m = 5
295295
should_be_within = route_fromto.sjoin_nearest(tunnel_fromto, distance_col="dist")
296296
assert should_be_within["dist"].max() < m, should_be_within["dist"]
297297
should_be_within = route_tofrom.sjoin_nearest(tunnel_tofrom, distance_col="dist")
298298
assert should_be_within["dist"].max() < m, should_be_within["dist"]
299-
300299
should_not_be_within = route_fromto.sjoin_nearest(
301300
tunnel_tofrom, distance_col="dist"
302301
)
303302
assert should_not_be_within["dist"].max() > m, should_not_be_within["dist"]
304-
305303
should_not_be_within = route_tofrom.sjoin_nearest(
306304
tunnel_fromto, distance_col="dist"
307305
)

0 commit comments

Comments
 (0)