Skip to content

Commit bf6ef11

Browse files
committed
fix: fix infinite loops
1 parent 398b033 commit bf6ef11

File tree

2 files changed

+88
-17
lines changed

2 files changed

+88
-17
lines changed

src/ssspx/frontier.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,15 @@ def _consume_block_prefix(
148148
"""Greedily take up to ``want`` keys from the head of ``blocks``."""
149149
got = 0
150150
idx_block = 0
151-
while got < want and idx_block < len(blocks):
151+
iterations = 0
152+
max_iterations = len(blocks) * 100 # Safety limit to prevent infinite loops
153+
154+
while got < want and idx_block < len(blocks) and iterations < max_iterations:
155+
iterations += 1
152156
block = blocks[idx_block]
153157
keep: List[Tuple[Vertex, Float]] = []
158+
processed_any = False
159+
154160
for k, v in block:
155161
bestv = self._best.get(k)
156162
if bestv is None or v != bestv:
@@ -162,13 +168,21 @@ def _consume_block_prefix(
162168
chosen[k] = v
163169
pulled_keys.add(k)
164170
got += 1
171+
processed_any = True
165172
else:
166173
keep.append((k, v))
174+
167175
blocks[idx_block] = keep
168176
if not blocks[idx_block]:
169177
blocks.pop(idx_block)
178+
# Don't increment idx_block since we removed an element
170179
else:
171180
idx_block += 1
181+
182+
# If we didn't process anything useful, break to avoid infinite loop
183+
if not processed_any and not keep:
184+
break
185+
172186
return got
173187

174188
def pull(self) -> Tuple[Set[Vertex], Float]:
@@ -258,11 +272,16 @@ def batch_prepend(self, pairs: Iterable[Tuple[Vertex, Float]]) -> None:
258272
def pull(self) -> Tuple[Set[Vertex], Float]:
259273
"""Return up to ``M`` keys with the smallest values."""
260274
s: Set[Vertex] = set()
261-
while self._heap and len(s) < self.M:
275+
iterations = 0
276+
max_iterations = len(self._heap) * 2 # Safety limit
277+
278+
while self._heap and len(s) < self.M and iterations < max_iterations:
279+
iterations += 1
262280
val, key = heapq.heappop(self._heap)
263281
if self._best.get(key) != val:
264282
continue # stale
265283
s.add(key)
284+
266285
if not s:
267286
return set(), self.B
268287
x = self._heap[0][0] if self._heap else self.B

src/ssspx/solver.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
"pulls": 0,
100100
"findpivots_rounds": 0,
101101
"basecase_pops": 0,
102+
"iterations_protected": 0, # Track safety limit hits
102103
}
103104

104105
# Optional transform for outdegree
@@ -133,19 +134,22 @@ def __init__(
133134
self.complete[s] = True
134135
self.root[s] = s
135136

136-
# Parameters (k, t, levels)
137+
# Parameters (k, t, levels) with safety bounds
137138
n = max(2, self.G.n)
138139
if self.cfg.k_t_auto:
139140
log2n = math.log2(n)
140141
k = max(1, int(round(log2n ** (1.0 / 3.0))))
141142
t = max(1, int(round(log2n ** (2.0 / 3.0))))
142143
k = max(1, min(k, t))
144+
# Cap to reasonable values to prevent runaway algorithms
145+
k = min(k, 100)
146+
t = min(t, 20)
143147
else:
144-
k = max(1, self.cfg.k)
145-
t = max(1, self.cfg.t)
148+
k = max(1, min(self.cfg.k, 100)) # Safety cap
149+
t = max(1, min(self.cfg.t, 20)) # Safety cap
146150
self.k: int = k
147151
self.t: int = t
148-
self.L: int = max(1, math.ceil(math.log2(n) / t))
152+
self.L: int = max(1, min(math.ceil(math.log2(n) / t), 10)) # Cap levels
149153

150154
# Best-clone cache for each original vertex after solve()
151155
self._best_clone_for_orig: Optional[List[int]] = None
@@ -211,22 +215,32 @@ def _base_case(self, B: Float, S: Set[Vertex]) -> Tuple[Float, Set[Vertex]]:
211215
seen: Set[Vertex] = set()
212216
heap: List[Tuple[Float, Vertex]] = [(self.dhat[x], x)]
213217
in_heap: Set[Vertex] = {x}
218+
219+
# Safety limits to prevent infinite loops
220+
iterations = 0
221+
max_iterations = min(self.k * 1000, self.G.n * 10)
214222

215-
while heap and len(U0) < self.k + 1:
223+
while heap and len(U0) < self.k + 1 and iterations < max_iterations:
216224
self.counters["basecase_pops"] += 1
225+
iterations += 1
226+
217227
du, u = heapq.heappop(heap)
218228
in_heap.discard(u)
219229
if du != self.dhat[u] or u in seen:
220230
continue
221231
seen.add(u)
222232
self.complete[u] = True
223233
U0.append(u)
234+
224235
for v, w in self.G.adj[u]:
225236
if self._relax(u, v, w) and self.dhat[u] + w < B:
226237
if v not in in_heap:
227238
heapq.heappush(heap, (self.dhat[v], v))
228239
in_heap.add(v)
229240

241+
if iterations >= max_iterations:
242+
self.counters["iterations_protected"] += 1
243+
230244
if len(U0) <= self.k:
231245
return (B, set(U0))
232246
Bprime = max(self.dhat[v] for v in U0)
@@ -248,20 +262,38 @@ def _find_pivots(self, B: Float, S: Set[Vertex]) -> Tuple[Set[Vertex], Set[Verte
248262
"""
249263
W: Set[Vertex] = set(S)
250264
current: Set[Vertex] = set(S)
251-
for _ in range(1, self.k + 1):
265+
266+
# Safety limits for findpivots
267+
iterations = 0
268+
max_iterations = min(self.k * len(S) * 100, self.G.n * 10)
269+
270+
for round_num in range(1, self.k + 1):
271+
if iterations >= max_iterations:
272+
self.counters["iterations_protected"] += 1
273+
break
274+
252275
self.counters["findpivots_rounds"] += 1
253276
nxt: Set[Vertex] = set()
277+
254278
for u in current:
279+
iterations += 1
280+
if iterations >= max_iterations:
281+
break
282+
255283
for v, w in self.G.adj[u]:
256284
if self._relax(u, v, w) and (self.dhat[u] + w < B):
257285
nxt.add(v)
286+
258287
if not nxt:
259288
break
260289
W |= nxt
261-
if len(W) > self.k * max(1, len(S)):
290+
291+
# Early termination if W gets too large
292+
if len(W) > self.k * max(1, len(S)) * 5: # More generous limit
262293
return set(S), W
263294
current = nxt
264295

296+
# Build pivot tree with safety limits
265297
children: Dict[Vertex, List[Vertex]] = {u: [] for u in W}
266298
for v in W:
267299
p = self.pred[v]
@@ -273,7 +305,11 @@ def _find_pivots(self, B: Float, S: Set[Vertex]) -> Tuple[Set[Vertex], Set[Verte
273305
size = 0
274306
stack = [u]
275307
seen: Set[Vertex] = set()
276-
while stack:
308+
iterations = 0
309+
max_tree_iterations = min(self.k * 10, len(W))
310+
311+
while stack and iterations < max_tree_iterations:
312+
iterations += 1
277313
a = stack.pop()
278314
if a in seen:
279315
continue
@@ -288,14 +324,19 @@ def _find_pivots(self, B: Float, S: Set[Vertex]) -> Tuple[Set[Vertex], Set[Verte
288324
# ---------- BMSSP -----------------------------------------------------
289325

290326
def _make_frontier(self, level: int, B: Float) -> FrontierProtocol:
291-
M = max(1, 2 ** ((level - 1) * self.t))
327+
# Cap the frontier size to prevent excessive memory usage
328+
M = max(1, min(2 ** ((level - 1) * self.t), 10000))
292329
if self.cfg.frontier == "heap":
293330
return HeapFrontier(M=M, B=B)
294331
if self.cfg.frontier == "block":
295332
return BlockFrontier(M=M, B=B)
296333
raise ConfigError(f"unknown frontier '{self.cfg.frontier}'")
297334

298-
def _bmssp(self, level: int, B: Float, S: Set[Vertex]) -> Tuple[Float, Set[Vertex]]:
335+
def _bmssp(self, level: int, B: Float, S: Set[Vertex], depth: int = 0) -> Tuple[Float, Set[Vertex]]:
336+
# Prevent excessive recursion
337+
if depth > 50 or level > 20:
338+
return self._base_case(B, S)
339+
299340
if level == 0:
300341
return self._base_case(B, S)
301342

@@ -305,10 +346,14 @@ def _bmssp(self, level: int, B: Float, S: Set[Vertex]) -> Tuple[Float, Set[Verte
305346
D.insert(x, self.dhat[x])
306347

307348
U_accum: Set[Vertex] = set()
308-
cap = self.k * max(1, 2 ** (level * self.t))
349+
cap = min(self.k * max(1, 2 ** (level * self.t)), self.G.n) # Cap to graph size
350+
pull_iterations = 0
351+
max_pull_iterations = min(cap * 10, 1000) # Safety limit on pulls
309352

310-
while len(U_accum) < cap:
353+
while len(U_accum) < cap and pull_iterations < max_pull_iterations:
311354
self.counters["pulls"] += 1
355+
pull_iterations += 1
356+
312357
S_i, B_i = D.pull()
313358
if not S_i:
314359
Bprime = B
@@ -317,7 +362,7 @@ def _bmssp(self, level: int, B: Float, S: Set[Vertex]) -> Tuple[Float, Set[Verte
317362
self.complete[u] = True
318363
return Bprime, U_accum
319364

320-
B_i_prime, U_i = self._bmssp(level - 1, B_i, S_i)
365+
B_i_prime, U_i = self._bmssp(level - 1, B_i, S_i, depth + 1)
321366
for u in U_i:
322367
self.complete[u] = True
323368
U_accum |= U_i
@@ -344,6 +389,9 @@ def _bmssp(self, level: int, B: Float, S: Set[Vertex]) -> Tuple[Float, Set[Verte
344389
self.complete[u] = True
345390
return Bprime, U_accum
346391

392+
if pull_iterations >= max_pull_iterations:
393+
self.counters["iterations_protected"] += 1
394+
347395
return B, U_accum
348396

349397
# ---------- public API ------------------------------------------------
@@ -413,11 +461,15 @@ def path(self, target_original: Vertex) -> List[Vertex]:
413461
return [] # unreachable
414462
src_clone = self.root[start_clone]
415463

416-
# Walk predecessors in clone-space
464+
# Walk predecessors in clone-space with safety limits
417465
chain: List[int] = []
418466
cur: Optional[int] = start_clone
419467
seen = set()
420-
while cur is not None:
468+
iterations = 0
469+
max_iterations = self.G.n * 2 # Safety limit
470+
471+
while cur is not None and iterations < max_iterations:
472+
iterations += 1
421473
chain.append(cur)
422474
if cur == src_clone:
423475
chain.reverse()

0 commit comments

Comments
 (0)