Skip to content

Commit a9e5790

Browse files
authored
Add HeapSet.peekn (#6947)
1 parent 599708e commit a9e5790

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

distributed/collections.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import heapq
4+
import itertools
45
import weakref
56
from collections import OrderedDict, UserDict
67
from collections.abc import Callable, Hashable, Iterator
@@ -99,6 +100,20 @@ def peek(self) -> T:
99100
return value
100101
heapq.heappop(self._heap)
101102

103+
def peekn(self, n: int) -> Iterator[T]:
    """Iterate over the ``n`` smallest elements without removing them.

    Yields nothing when ``n <= 0`` or when the set is empty. If the set
    holds fewer than ``n`` elements, all of them are yielded.

    ``n == 1`` is served by ``peek()`` and avoids sorting; any larger ``n``
    takes the first ``n`` items of ``self.sorted()``, i.e. it pays for
    sorting the whole set regardless of how small ``n`` is.
    """
    # Guard the empty set up front so that n == 1 behaves like n > 1
    # (an empty iterator) instead of depending on what peek() does there.
    if n <= 0 or not self._data:
        return  # empty iterator
    if n == 1:
        yield self.peek()
    else:
        # NOTE: we could pop N items off the queue, then push them back.
        # But copying the list N times is probably slower than just sorting it
        # with fast C code.
        # If we had a `heappop` that sliced the list instead of popping from it,
        # we could implement an optimized version for small `n`s.
        yield from itertools.islice(self.sorted(), n)
116+
102117
def pop(self) -> T:
103118
if not self._data:
104119
raise KeyError("pop from an empty set")

distributed/tests/test_collections.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __hash__(self):
3838
def __eq__(self, other):
3939
return isinstance(other, C) and other.k == self.k
4040

41+
def __repr__(self):
    """Return a short ``C(<k>, <i>)`` label built from the ``k`` and ``i`` attributes."""
    k_val = self.k
    i_val = self.i
    return f"C({k_val}, {i_val})"
43+
4144

4245
def test_heapset():
4346
heap = HeapSet(key=operator.attrgetter("i"))
@@ -131,6 +134,21 @@ def test_heapset():
131134
heap.add(cx)
132135
assert cx in heap
133136

137+
# Test peekn()
138+
heap.add(cy)
139+
heap.add(cw)
140+
heap.add(cz)
141+
heap.add(cx)
142+
assert list(heap.peekn(3)) == [cy, cx, cz]
143+
heap.remove(cz)
144+
assert list(heap.peekn(10)) == [cy, cx, cw]
145+
assert list(heap.peekn(0)) == []
146+
assert list(heap.peekn(-1)) == []
147+
heap.remove(cy)
148+
assert list(heap.peekn(1)) == [cx]
149+
heap.remove(cw)
150+
assert list(heap.peekn(1)) == [cx]
151+
134152
# Test resilience to failure in key()
135153
bad_key = C("bad_key", 0)
136154
del bad_key.i

0 commit comments

Comments
 (0)