Skip to content

Commit 0363eb5

Browse files
committed
speed up points traversal
1 parent 88e01c8 commit 0363eb5

File tree

5 files changed

+41
-21
lines changed

5 files changed

+41
-21
lines changed

Dockerfile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ RUN bash /tmp/conda.bash
2424

2525
# install python requirements
2626
COPY . /tmp/trimesh
27-
RUN /home/user/conda/bin/pip install /tmp/trimesh[all]
27+
RUN /home/user/conda/bin/pip install /tmp/trimesh[all] pytest
2828

2929
# add user python to path
3030
ENV PATH="/home/user/conda/bin:$PATH"
@@ -34,3 +34,6 @@ ENV XVFB_WHD="1920x1080x24"\
3434
DISPLAY=":99" \
3535
LIBGL_ALWAYS_SOFTWARE="1" \
3636
GALLIUM_DRIVER="llvmpipe"
37+
38+
# make sure build fails if tests are failing
39+
RUN pytest /tmp/trimesh/tests

tests/test_points.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def test_tsp(self):
129129
assert len(idx) == len(points)
130130
assert len(dist) == len(points) - 1
131131

132+
# shouldn't be any negative indexes
133+
assert (idx >= 0).all()
134+
132135
# make sure distances returned are correct
133136
dist_check = g.np.linalg.norm(g.np.diff(points[idx], axis=0), axis=1)
134137
assert g.np.allclose(dist_check, dist)

trimesh/points.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,15 @@ def k_means(points, k, **kwargs):
227227
def tsp(points, start=0):
228228
"""
229229
Find an ordering of points where each is visited and
230-
the next point is the closest in euclidean distance.
230+
the next point is the closest in euclidean distance,
231+
and if there are multiple points with equal distance
232+
go to an arbitrary one.
231233
232234
Assumes every point is visitable from every other point,
233235
i.e. the travelling salesman problem on a fully connected
234-
graph.
236+
graph. It is not a MINIMUM traversal; rather it is a
237+
"not totally goofy traversal, quickly." On random points
238+
this traversal is often ~20x shorter than random ordering.
235239
236240
Parameters
237241
---------------
@@ -249,32 +253,41 @@ def tsp(points, start=0):
249253
"""
250254
# points should be float
251255
points = np.asanyarray(points, dtype=np.float64)
256+
257+
if len(points.shape) != 2:
258+
raise ValueError('points must be (n, dimension)!')
259+
252260
# start should be an index
253261
start = int(start)
254262

255263
# a mask of unvisited points by index
256264
unvisited = np.ones(len(points), dtype=np.bool)
257265
unvisited[start] = False
266+
258267
# traversal of points by index
259-
traversal = [start]
268+
traversal = np.zeros(len(points), dtype=np.int64) - 1
269+
traversal[0] = start
260270
# list of distances
261-
distances = []
271+
distances = np.zeros(len(points) - 1, dtype=np.float64)
262272
# a mask of indexes in order
263273
index_mask = np.arange(len(points), dtype=np.int64)
264274

265-
# bound our traversal in case it's dumb
266-
for i in range(len(points) + 2):
267-
268-
# we should always exit via this break
269-
if not unvisited.any():
270-
break
275+
# in the loop we want to call distances.sum(axis=1)
276+
# a lot and it's actually kind of slow for "reasons"
277+
# but dot products with ones are equivilant and roughly
278+
# 2x faster
279+
sum_ones = np.ones(points.shape[1])
271280

272-
# which point are we currently at
281+
# loop through all points
282+
for i in range(len(points) - 1):
283+
# which point are we currently on
273284
current = points[traversal[i]]
274285

275286
# do NlogN distance query
276-
# use sum instead of np.linalg.norm as it is slightly faster
277-
dist = ((points[unvisited] - current) ** 2).sum(axis=1) ** 0.5
287+
# use dot instead of .sum(axis=1) or np.linalg.norm
288+
# as it is much faster, also don't square root
289+
dist = np.dot((points[unvisited] - current) ** 2,
290+
sum_ones)
278291

279292
# minimum distance index
280293
min_index = dist.argmin()
@@ -283,13 +296,12 @@ def tsp(points, start=0):
283296
# update the mask
284297
unvisited[successor] = False
285298
# append the index to the traversal
286-
traversal.append(successor)
299+
traversal[i + 1] = successor
287300
# append the distance
288-
distances.append(dist[min_index])
301+
distances[i] = dist[min_index]
289302

290-
# make sure results are numpy arrays of correct dtype
291-
traversal = np.array(traversal, dtype=np.int64)
292-
distances = np.array(distances, dtype=np.float64)
303+
# we were comparing squared distance so root result
304+
distances **= 0.5
293305

294306
return traversal, distances
295307

trimesh/util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ def unitize(vectors,
7575
if len(vectors.shape) == 2:
7676
# for (m, d) arrays take the per- row unit vector
7777
# using sqrt and avoiding exponents is slightly faster
78-
norm = np.sqrt((vectors * vectors).sum(axis=1))
78+
# also dot with ones is faser than .sum(axis=1)
79+
norm = np.sqrt(np.dot(vectors * vectors,
80+
np.ones(vectors.shape[1])))
7981
# non-zero norms
8082
valid = norm > threshold
8183
# in-place reciprocal of nonzero norms

trimesh/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.34.3'
1+
__version__ = '2.34.4'

0 commit comments

Comments
 (0)