Skip to content

Commit 87d259a

Browse files
authored
Improve speed of BottomUp (#309)
1 parent f9952f7 commit 87d259a

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

src/ruptures/detection/bottomup.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ def __init__(self, model="l2", custom_cost=None, min_size=2, jump=5, params=None
3737

3838
def _grow_tree(self):
3939
"""Grow the entire binary tree."""
40-
partition = [(0, self.n_samples)]
40+
partition = [(-self.n_samples, (0, self.n_samples))]
4141
stop = False
4242
while not stop: # recursively divide the signal
4343
stop = True
44-
start, end = max(partition, key=lambda t: t[1] - t[0])
44+
_, (start, end) = partition[0]
4545
mid = (start + end) * 0.5
4646
bkps = list()
4747
for bkp in range(start, end):
@@ -50,15 +50,15 @@ def _grow_tree(self):
5050
bkps.append(bkp)
5151
if len(bkps) > 0: # at least one admissible breakpoint was found
5252
bkp = min(bkps, key=lambda x: abs(x - mid))
53-
partition.remove((start, end))
54-
partition.append((start, bkp))
55-
partition.append((bkp, end))
53+
heapq.heappop(partition)
54+
heapq.heappush(partition, (-bkp + start, (start, bkp)))
55+
heapq.heappush(partition, (-end + bkp, (bkp, end)))
5656
stop = False
5757

58-
partition.sort()
58+
partition.sort(key=lambda x: x[1])
5959
# compute segment costs
6060
leaves = list()
61-
for start, end in partition:
61+
for _, (start, end) in partition:
6262
val = self.cost.error(start, end)
6363
leaf = Bnode(start, end, val)
6464
leaves.append(leaf)
@@ -87,6 +87,7 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None):
8787
dict: partition dict {(start, end): cost value,...}
8888
"""
8989
leaves = sorted(self.leaves)
90+
keys = [leaf.start for leaf in leaves]
9091
removed = set()
9192
merged = []
9293
for left, right in pairwise(leaves):
@@ -121,10 +122,13 @@ def _seg(self, n_bkps=None, pen=None, epsilon=None):
121122
if not stop:
122123
# updates the list of leaves (i.e. segments of the partitions)
123124
# find the merged segments indexes
124-
keys = [leaf.start for leaf in leaves]
125125
left_idx = bisect_left(keys, leaf.left.start)
126-
leaves[left_idx] = leaf # replace leaf.left
127-
del leaves[left_idx + 1] # remove leaf.right
126+
# replace leaf.left
127+
leaves[left_idx] = leaf
128+
keys[left_idx] = leaf.start
129+
# remove leaf.right
130+
del leaves[left_idx + 1]
131+
del keys[left_idx + 1]
128132
# add to the set of removed segments.
129133
removed.add(leaf.left)
130134
removed.add(leaf.right)

0 commit comments

Comments
 (0)