Dispatch to use fp32 distance computation in NN Descent depending on data dimensions #1415
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Closes #1370
Closes #195
From heuristics, chose dim=16 as the threshold for dispatching to a fp32 distance kernel.
We no longer use wmma in the fp32 kernel. Originally wmma was done on matrices of shape [64 x 32] x [32 x 64] per block (multiple iterations if data dimension is larger than 32).
We do manual computation, but since we only target small dimensions, fp32 dispatching ends up being slightly faster end to end with much better recall for small dimensions.
All number below are run on L40 machine and AMD EPYC CPU with 128 cores. Perf and recall is averaged over 5 runs and all time is in seconds. Baseline knn graph is computed using
sklearn.neighbors.NearestNeighborsbrute for method.Max iters=20
For larger dimensions there is an inherent issue with the NN Descent algorithm itself that makes the recall low. This can be improved slightly with more iterations.
Also notice that the e2e time taken is similar or slightly less for using fp32.
Max iters=100
Notice how the blue part, the recall doesn't get better compared to the table above even with more iterations (i.e. why we need the fp32 appraoch for this part)