Skip to content

Commit aa9dc73

Browse files
authored
Merge pull request #1 from remydubois/feature/argpartition
Feature/argpartition
2 parents f1f4b17 + 69f777f commit aa9dc73

11 files changed

Lines changed: 145 additions & 34 deletions

File tree

README.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LSNMS
22
Speeding up Non Maximum Suppresion ran on very large images by a several folds factor, using a sparse implementation of NMS.
3-
This project describes a "sparse" implementation of Non Maximum Suppression, useful in the case of very high dimensional images data, when the amount of predicted instances to prune becomes considerable (> 10,000 objects).
3+
This project becomes useful in the case of very high dimensional images data, when the amount of predicted instances to prune becomes considerable (> 10,000 objects).
44

55
<p float="center">
66
<center><img src="https://raw.githubusercontent.com/remydubois/lsnms/main/assets/images/timings_medium_image.png?token=AEJMSVNEIBF2PMWIVASMKATAMIKHS" width="700" />
@@ -121,22 +121,19 @@ tree = BallTree(data, leaf_size=16)
121121

122122

123123
## Performances
124-
The BallTree implemented in this repo was timed against scikit-learn's `neighbors` one.
124+
The BallTree implemented in this repo was timed against scikit-learn's `neighbors` one. Note that runtimes are not fair to compare since sklearn implementation allows for node to contain
125+
between `leaf_size` and `2 * leaf_size` datapoints. To account for this, I timed my implementation against sklearn tree with `int(0.67 * leaf_size)` as `leaf_size`.
125126
### Tree building time
126127
<p float="center">
127128
<center><img src="https://github.com/remydubois/lsnms/blob/main/assets/images/building_timings.png" width="700" />
128129
<figcaption>Trees building times comparison</figcaption></center>
129130
</p>
130131

131-
The (minor) slow down observed against sklearn implementation is probably related to the node-splitting process. I used the median cutoff (compute median, then assign datapoints depending on their value above or below median) but it is suboptimal: a proper pivot algorithm could easily be implemented.
132132

133133
### Tree query time
134134
<p float="center">
135135
<center><img src="https://github.com/remydubois/lsnms/blob/main/assets/images/query_timings.png" width="700" />
136136
<figcaption>Trees query times comparison (single query, radius=100) in a 1000x1000 space</figcaption></center>
137137
</p>
138138

139-
Query time are somehow identical. However, my implementation does seem to not scale as well as scikit-learn's one, a minor slowdown could be observed for extremely large datasets (million-ish data points).
140-
141-
### Warnings
142-
Because input data needs to be typed: the dimensionality of the process is fixed in advance. This BallTree implementation can not work on 3D and above data (although it is a one-liner fix).
139+
Query time are somehow identical.

assets/images/building_timings.png

5.06 KB
Loading

assets/images/query_timings.png

5.76 KB
Loading

changelog.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Changelog
2+
=========
3+
4+
5+
(unreleased)
6+
------------
7+
- - added changelog, upgraded to version 0.1.1. [Rémy Dubois]
8+
- - black. [Rémy Dubois]
9+
- - improved the node splitting method - Updated the runtimes comparison
10+
versus sklearn - fixed little typos in the readme - fixed typechecks
11+
in the trees. [Rémy Dubois]
12+
- -readme. [Rémy Dubois]
13+
- Typo + fixed image urls. [Rémy Dubois]
14+
- - poetry. [Rémy Dubois]
15+
- - poetry. [Rémy Dubois]
16+
- - poetry. [Rémy Dubois]
17+
- Initial commit. [Rémy Dubois]
18+
19+

lsnms/balltree.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def __init__(self, data, leaf_size=16, indices=None):
5454
# Stores the data
5555
self.data = data
5656

57+
if len(self.data) == 0:
58+
raise ValueError("Empty data")
59+
5760
# Stores indices of each data point
5861
if indices is None:
5962
self.indices = np.arange(len(data))
@@ -139,8 +142,8 @@ def query_radius(self, X, max_radius):
139142
"""
140143
if X.ndim > 1:
141144
raise ValueError("query_radius only works on single query point.")
142-
if len(X) != 2:
143-
raise ValueError("Query point must be two-dimensional")
145+
if X.shape[-1] != self.dimensionality:
146+
raise ValueError("Tree and query dimensionality do not match")
144147
# Initialize empty list of int64
145148
# Needs to be typed
146149
buffer = [0][:0]

lsnms/kdtree.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def __init__(self, data, leaf_size=16, axis=0, indices=None):
5757
self.axis = axis
5858
self.dimensionality = data.shape[-1]
5959

60+
if len(self.data) == 0:
61+
raise ValueError("Empty data")
62+
6063
# Stores indices of each data point
6164
if indices is None:
6265
self.indices = np.arange(len(data))
@@ -145,8 +148,8 @@ def query_radius(self, X, max_radius):
145148
"""
146149
if X.ndim > 1:
147150
raise ValueError("query_radius only works on single query point.")
148-
if len(X) != 2:
149-
raise ValueError("Query point must be two-dimensional")
151+
if X.shape[-1] != self.dimensionality:
152+
raise ValueError("Tree and query dimensionality do not match")
150153
# Initialize empty list of int64
151154
# Needs to be typed
152155
buffer = [0][:0]

lsnms/util.py

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def max_spread_axis(data):
142142
def split_along_axis(data, axis):
143143
"""
144144
Splits the data along axis in two datasets of equal size.
145-
Note that this could probably be optimized further, by implementing the median algorithm from
146-
scratch.
145+
This method uses an adapted re-implementation of `np.argpartition`
147146
148147
Parameters
149148
----------
@@ -157,17 +156,7 @@ def split_along_axis(data, axis):
157156
Tuple[np.array]
158157
Left data point indices, right data point indices
159158
"""
160-
indices = np.arange(len(data))
161-
cap = np.median(data[:, axis])
162-
mask = data[:, axis] <= cap
163-
n_left = mask.sum()
164-
# Account for the case where all positions along this axis are equal: split in the middle
165-
if n_left == len(data) or n_left == 0:
166-
left = indices[: len(indices) // 2]
167-
right = indices[len(indices) // 2 :]
168-
else:
169-
left = indices[mask]
170-
right = indices[np.logical_not(mask)]
159+
left, right = median_argsplit(data[:, axis])
171160
return left, right
172161

173162

@@ -217,3 +206,99 @@ def englobing_box(data):
217206
bounds.insert(j, data[:, j].min())
218207
bounds.insert(2 * j + 1, data[:, j].max())
219208
return np.array(bounds)
209+
210+
211+
@njit
212+
def _partition(A, low, high, indices):
213+
"""
214+
This is straight from numba master:
215+
https://github.com/numba/numba/blob/b5bd9c618e20985acb0b300d52d57595ef6f5442/numba/np/arraymath.py#L1155
216+
I modified it so the swaps operate on the indices as well, because I need a argpartition
217+
"""
218+
mid = (low + high) >> 1
219+
# NOTE: the pattern of swaps below for the pivot choice and the
220+
# partitioning gives good results (i.e. regular O(n log n))
221+
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
222+
# risk breaking this property.
223+
# Use median of three {low, middle, high} as the pivot
224+
if A[mid] < A[low]:
225+
A[low], A[mid] = A[mid], A[low]
226+
indices[low], indices[mid] = indices[mid], indices[low]
227+
if A[high] < A[mid]:
228+
A[high], A[mid] = A[mid], A[high]
229+
indices[high], indices[mid] = indices[mid], indices[high]
230+
if A[mid] < A[low]:
231+
A[low], A[mid] = A[mid], A[low]
232+
indices[low], indices[mid] = indices[mid], indices[low]
233+
pivot = A[mid]
234+
235+
A[high], A[mid] = A[mid], A[high]
236+
indices[high], indices[mid] = indices[mid], indices[high]
237+
i = low
238+
j = high - 1
239+
while True:
240+
while i < high and A[i] < pivot:
241+
i += 1
242+
while j >= low and pivot < A[j]:
243+
j -= 1
244+
if i >= j:
245+
break
246+
A[i], A[j] = A[j], A[i]
247+
indices[i], indices[j] = indices[j], indices[i]
248+
i += 1
249+
j -= 1
250+
# Put the pivot back in its final place (all items before `i`
251+
# are smaller than the pivot, all items at/after `i` are larger)
252+
# print(A)
253+
A[i], A[high] = A[high], A[i]
254+
indices[i], indices[high] = indices[high], indices[i]
255+
256+
return i
257+
258+
259+
@njit
260+
def _select(arry, k, low, high):
261+
"""
262+
This is straight from numba master:
263+
https://github.com/numba/numba/blob/b5bd9c618e20985acb0b300d52d57595ef6f5442/numba/np/arraymath.py#L1155
264+
Select the k'th smallest element in array[low:high + 1].
265+
"""
266+
indices = np.arange(len(arry))
267+
i = _partition(arry, low, high, indices)
268+
while i != k:
269+
if i < k:
270+
low = i + 1
271+
i = _partition(arry, low, high, indices)
272+
else:
273+
high = i - 1
274+
i = _partition(arry, low, high, indices)
275+
return indices, i
276+
277+
278+
@njit
279+
def median_argsplit(arry):
280+
"""
281+
Splits `arry` into two sets of indices, indicating values
282+
above and below the pivot value. Often, pivot is the median.
283+
284+
This is approx. three folds faster than computing the median,
285+
then find indices of values below (left indices) and above (right indices)
286+
287+
Parameters
288+
----------
289+
arry : np.array
290+
One dimensional values array
291+
292+
Returns
293+
-------
294+
Tuple[np.array]
295+
Indices of values below median, indices of values above median
296+
"""
297+
low = 0
298+
high = len(arry) - 1
299+
k = len(arry) >> 1
300+
tmp_arry = arry.flatten()
301+
indices, i = _select(tmp_arry, k, low, high)
302+
left = indices[:k]
303+
right = indices[k:]
304+
return left, right

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "lsnms"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
description = "Large Scale Non Maximum Suppression"
55
authors = ["Rémy Dubois <remydubois14@gmail.com>"]
66
license = "MIT"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="lsnms",
5-
version="0.1.0",
5+
version="0.1.1",
66
description="Large Scale Non Maximum Suppression",
77
author="Rémy Dubois",
88
install_requires=["numpy==1.19.5", "numba==0.53.1"],

tests/timings_balltree.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ def test_tree_query_timing():
1515
ns = np.arange(1000, 200000, 10000)
1616
ts = []
1717
naive_ts = []
18+
leaf_size = 64
1819
repeats = 100
1920
for n in ns:
2021
data = np.random.uniform(0, 1000, (n, 2))
21-
sk_tree = skBT(data, leaf_size=16)
22-
tree = BallTree(data, leaf_size=16)
22+
sk_tree = skBT(data, leaf_size=leaf_size)
23+
tree = BallTree(data, leaf_size=int(leaf_size * 0.67))
2324
_ = tree.query_radius(data[0], 200.0)
2425
timer = Timer(lambda: tree.query_radius(data[0], 100.0))
2526
ts.append(timer.timeit(number=repeats) / repeats * 1000)
@@ -42,13 +43,14 @@ def test_tree_building_timing():
4243

4344
ns = np.arange(1000, 300000, 25000)
4445
ts = []
46+
leaf_size = 64
4547
naive_ts = []
4648
for n in ns:
4749
data = np.random.uniform(0, n, (n, 2))
4850
_ = BallTree(data, 16)
49-
timer = Timer(lambda: BallTree(data, 16))
51+
timer = Timer(lambda: BallTree(data, leaf_size))
5052
ts.append(timer.timeit(number=5) / 5)
51-
naive_timer = Timer(lambda: skBT(data, 16))
53+
naive_timer = Timer(lambda: skBT(data, int(leaf_size * 0.67)))
5254
naive_ts.append(naive_timer.timeit(5) / 5)
5355

5456
with plt.xkcd():

0 commit comments

Comments
 (0)