@@ -227,11 +227,15 @@ def k_means(points, k, **kwargs):
227227def tsp (points , start = 0 ):
228228 """
229229 Find an ordering of points where each is visited and
230- the next point is the closest in euclidean distance.
230+ the next point is the closest in euclidean distance,
231+ and if there are multiple points with equal distance
232+ go to an arbitrary one.
231233
232234 Assumes every point is visitable from every other point,
233235 i.e. the travelling salesman problem on a fully connected
234- graph.
236+ graph. It is not a MINIMUM traversal; rather it is a
237+ "not totally goofy traversal, quickly." On random points
238+ this traversal is often ~20x shorter than random ordering.
235239
236240 Parameters
237241 ---------------
@@ -249,32 +253,41 @@ def tsp(points, start=0):
249253 """
250254 # points should be float
251255 points = np .asanyarray (points , dtype = np .float64 )
256+
257+ if len (points .shape ) != 2 :
258+ raise ValueError ('points must be (n, dimension)!' )
259+
252260 # start should be an index
253261 start = int (start )
254262
255263 # a mask of unvisited points by index
256264 unvisited = np .ones (len (points ), dtype = np .bool )
257265 unvisited [start ] = False
266+
258267 # traversal of points by index
259- traversal = [start ]
268+ traversal = np .zeros (len (points ), dtype = np .int64 ) - 1
269+ traversal [0 ] = start
260270 # list of distances
261- distances = []
271+ distances = np . zeros ( len ( points ) - 1 , dtype = np . float64 )
262272 # a mask of indexes in order
263273 index_mask = np .arange (len (points ), dtype = np .int64 )
264274
265- # bound our traversal in case it's dumb
266- for i in range (len (points ) + 2 ):
267-
268- # we should always exit via this break
269- if not unvisited .any ():
270- break
275+ # in the loop we want to call distances.sum(axis=1)
276+ # a lot and it's actually kind of slow for "reasons"
277+ # but dot products with ones are equivalent and roughly
278+ # 2x faster
279+ sum_ones = np .ones (points .shape [1 ])
271280
272- # which point are we currently at
281+ # loop through all points
282+ for i in range (len (points ) - 1 ):
283+ # which point are we currently on
273284 current = points [traversal [i ]]
274285
275286 # do a brute-force distance query against every unvisited point
276- # use sum instead of np.linalg.norm as it is slightly faster
277- dist = ((points [unvisited ] - current ) ** 2 ).sum (axis = 1 ) ** 0.5
287+ # use dot instead of .sum(axis=1) or np.linalg.norm
288+ # as it is much faster, also don't square root
289+ dist = np .dot ((points [unvisited ] - current ) ** 2 ,
290+ sum_ones )
278291
279292 # minimum distance index
280293 min_index = dist .argmin ()
@@ -283,13 +296,12 @@ def tsp(points, start=0):
283296 # update the mask
284297 unvisited [successor ] = False
285298 # append the index to the traversal
286- traversal . append ( successor )
299+ traversal [ i + 1 ] = successor
287300 # append the distance
288- distances . append ( dist [min_index ])
301+ distances [ i ] = dist [min_index ]
289302
290- # make sure results are numpy arrays of correct dtype
291- traversal = np .array (traversal , dtype = np .int64 )
292- distances = np .array (distances , dtype = np .float64 )
303+ # we were comparing squared distance so root result
304+ distances **= 0.5
293305
294306 return traversal , distances
295307
0 commit comments