Fix UMAP graph thresholding #6595
base: branch-25.06
Conversation
Thank you for this PR, Victor! I have a question before I move on with further reviews. :)
cpp/src/umap/runner.cuh (Outdated)

```cpp
value_t threshold = get_threshold(handle, fss_graph, inputs.n, params->n_epochs);
perform_thresholding(handle, fss_graph, threshold);

raft::sparse::op::coo_remove_zeros<value_t>(&fss_graph, graph, stream);
}
```
`graph` is the final output exposed at the Python level. It looks like umap-learn runs the fuzzy simplicial set operation (link, corresponding to our `FuzzySimplSetImpl::symmetrize`), then eliminates zeros, and that becomes the final output.

I think if we do `perform_thresholding` here, we end up storing the graph after thresholding as the final graph, which seems different from what umap-learn has?

I understand that we need to do the thresholding before the embedding init, but I think by doing this we end up with a different output graph (which should correspond to umap-learn's `self.graph_`). Please correct me if I'm wrong!
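To make the comparison concrete, here is a minimal sketch of the umap-learn order of operations being described: symmetrize the fuzzy simplicial set, eliminate zeros, and store that result as the final graph, with no thresholding involved. The names here are illustrative, not cuML's actual code.

```python
import numpy as np
import scipy.sparse as sp

def build_final_graph(knn_weights: sp.coo_matrix) -> sp.coo_matrix:
    # Symmetrize via the fuzzy set union: A + A^T - A o A^T (Hadamard product),
    # which is what FuzzySimplSetImpl::symmetrize corresponds to.
    a = knn_weights.tocsr()
    graph = a + a.T - a.multiply(a.T)
    # Eliminating zeros, not thresholding, is the last step before graph_ is stored.
    graph.eliminate_zeros()
    return graph.tocoo()

rows = np.array([0, 1, 2])
cols = np.array([1, 2, 0])
vals = np.array([0.9, 0.5, 0.0])  # one explicit zero to be eliminated
g = build_final_graph(sp.coo_matrix((vals, (rows, cols)), shape=(3, 3)))
```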
My thinking was that once the graph reaches the Python layer, we can no longer make assumptions about how the user intends to use it, so we should avoid modifying or potentially corrupting it. Since our goal is to trim the graph, that would require creating a copy, which could significantly increase VRAM usage. I acknowledge this approach doesn't exactly align with how umap-learn handles it, but I assumed that most users would primarily be interested in a fuzzy simplicial set graph that retains only the most important connections. I had even considered performing the trimming before the symmetrization step, since the Thrust operations there appear to cause a memory spike, if I've understood correctly.

However, I just realized that because the trimming threshold depends on the number of training epochs, trimming before storing the graph might work well with the UMAP estimator, but not as effectively with `fuzzy_simplicial_set`, which lacks access to the number of epochs. In that case, creating a copy might indeed be the only viable option, unless the graph is stored as a SciPy array on the Cython side?
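The epoch dependence mentioned above can be sketched as follows. This assumes umap-learn-style thresholding, where weights below `max(weight) / n_epochs` would receive zero epochs of optimization and can therefore be dropped before the embedding; the function name and the copy-then-trim arrangement are illustrative, not cuML's implementation.

```python
import numpy as np
import scipy.sparse as sp

def threshold_for_embedding(graph: sp.coo_matrix, n_epochs: int) -> sp.csr_matrix:
    # Work on a copy so the user-facing stored graph stays un-thresholded.
    trimmed = graph.tocsr(copy=True)
    # The cutoff depends on n_epochs, which fuzzy_simplicial_set does not know.
    cutoff = trimmed.data.max() / float(n_epochs)
    trimmed.data[trimmed.data < cutoff] = 0.0
    trimmed.eliminate_zeros()
    return trimmed

g = sp.coo_matrix((np.array([1.0, 0.004, 0.5]),
                   (np.array([0, 1, 2]), np.array([1, 2, 0]))), shape=(3, 3))
t = threshold_for_embedding(g, n_epochs=200)  # cutoff = 1.0 / 200 = 0.005
```

With 200 epochs the 0.004 entry falls below the cutoff and is dropped from the trimmed copy, while the original graph keeps all three entries.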
> In that case, creating a copy might indeed be the only viable option, unless the graph is stored as a SciPy array on the Cython side?
When Corey and I discussed this a month or two ago, IIRC we came to an agreement that the `graph_` attribute could (and probably should) be returned to the user on host. It's mostly there for introspection or for umap-learn compatibility. We do want to make sure the graph is being treated roughly the same as umap-learn treats it, so that zero-code-change methods referencing it (e.g. merging of estimators) work roughly the same. Since we never reference it ourselves, having it on host would be fine and would decrease the device memory pressure at this point.
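A minimal sketch of the host-side `graph_` idea: once the COO component arrays have been copied from device to host (the device-to-host copy step is assumed here, e.g. via a CuPy `.get()`), the graph can live as a SciPy matrix and the device buffers can be freed. The helper name is hypothetical.

```python
import numpy as np
import scipy.sparse as sp

def graph_to_host(rows: np.ndarray, cols: np.ndarray,
                  vals: np.ndarray, n_samples: int) -> sp.csr_matrix:
    # rows/cols/vals are host numpy arrays after the device->host copy;
    # storing graph_ like this keeps it off the GPU entirely.
    return sp.coo_matrix((vals, (rows, cols)),
                         shape=(n_samples, n_samples)).tocsr()

host_graph = graph_to_host(np.array([0, 1]), np.array([1, 0]),
                           np.array([0.9, 0.9]), n_samples=2)
```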
Answers #6539

This PR:
- keeps the stored output graph un-thresholded, whether it is produced by the `fit` method or within the `fuzzy_simplicial_set` function.
- performs the thresholding inside the `simplicial_set_embedding` function, using a copied graph only when necessary to avoid unnecessary overhead.