Skip to content

Lanczos Solver which=SA,SM,LA,LM argument #2628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: branch-25.06
Choose a base branch
from

Conversation

aamijar
Copy link
Member

@aamijar aamijar commented Apr 8, 2025

Resolves #2624
Resolves #2483

Copy link

copy-pr-bot bot commented Apr 8, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aamijar aamijar force-pushed the lanczos-solver-which-argument branch from 3585778 to 6327b2f Compare April 8, 2025 18:41
Copy link

copy-pr-bot bot commented Apr 8, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@aamijar aamijar added enhancement New feature or request non-breaking Non-breaking change improvement Improvement / enhancement to an existing function and removed enhancement New feature or request labels Apr 8, 2025
@aamijar aamijar marked this pull request as ready for review April 8, 2025 23:59
@aamijar aamijar requested review from a team as code owners April 8, 2025 23:59
@@ -20,6 +20,8 @@

namespace raft::sparse::solver {

enum LANCZOS_WHICH { LA, LM, SA, SM };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation here on the meaning (smallest/largest - Algebraic/Magnitude)

@@ -192,6 +205,15 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
config_float.ncv = ncv
config_float.tolerance = tol
config_float.seed = seed

if which.lower() == "sa":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Factorize this if-else machine in a function to reuse

@@ -100,7 +108,7 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \


@auto_sync_handle
def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
def eigsh(A, k=6, which="SA", v0=None, ncv=None, maxiter=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main use case for us is SA but the default param should be the same as cupy and scipy in order to match the API.

auto sm_eigenvectors =
raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, nEigVecs);

if (which == LANCZOS_WHICH::SA) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this additional treatment of which happen in lanczos_solve_ritz? The current function is already long and this will follow more closely cupy's code architecture.
And lanczos_solve_ritz already accepts a which arguments that's unused right now.
(+ avoid code duplication)

}

// Re-sort these indices by algebraic value to maintain algebraic ordering
thrust::sort(thrust::device,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cupy's version is doing only one argsort. Is this one absolutely necessary?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change python
Projects
Status: In Progress
2 participants