Skip to content

Commit 79a2668

Browse files
authored
Merge pull request #26 from cournape/feat/benchmarking
Feat/benchmarking
2 parents bbf7d58 + 6a07cc1 commit 79a2668

File tree

5 files changed

+523
-329
lines changed

5 files changed

+523
-329
lines changed

scripts/compare-against-arpack.py

Lines changed: 21 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,17 @@
1010
import scipy.sparse as sp
1111

1212
from scipy.linalg import toeplitz
13-
from scipy.sparse.linalg import eigs, LinearOperator
13+
from scipy.sparse.linalg import LinearOperator
1414

15-
from arnoldi.krylov_schur import partial_schur
1615
from arnoldi.utils import arg_largest_magnitude, arg_largest_real
1716

1817

1918
HERE = os.path.dirname(__file__)
2019
sys.path.insert(0, HERE)
2120

2221
from utils import (
23-
WHICH_TO_SORT, MatvecCounter, find_best_matching, load_suitesparse_mat,
24-
print_residuals
22+
WHICH_TO_SORT, EigensolverParameters, MatvecCounter, arnoldi_py_eig, arpack_eig,
23+
find_best_matching, load_suitesparse_mat, print_residuals
2524
)
2625

2726

@@ -108,95 +107,51 @@ def main():
108107
# complex types, so we cast both solvers to the same dtype.
109108
A = A_raw.astype(np.complex128)
110109

111-
max_dim = args.max_dim if args.max_dim is not None else min(max(2 * nev + 1, 20), n)
112-
110+
parameters = EigensolverParameters.from_cli_args(args, n)
111+
parameters.p = p
113112
print(f"Matrix: {args.mat_file}")
114113
print(f" shape={n}x{n}, nnz={nnz}, dtype={A.dtype}")
115-
print(
116-
f" nev={nev}, tol={tol}, max_dim={max_dim}, "
117-
f"max_restarts={max_it}, which={which}"
118-
)
114+
print(parameters)
119115

120116
# ------------------------------------------------------------------
121117
# ARPACK
122118
# ------------------------------------------------------------------
123119
print(f"\n--- Running ARPACK ---")
124-
arpack_counter = MatvecCounter(A)
125-
t0 = time.perf_counter()
126-
arpack_vals, arpack_vecs = eigs(
127-
arpack_counter,
128-
k=nev,
129-
which=which,
130-
ncv=max_dim,
131-
tol=tol,
132-
maxiter=max_it,
133-
)
134-
arpack_elapsed = time.perf_counter() - t0
135-
136-
# Sort by descending real part for consistent display
137-
idx = np.argsort(-arpack_vals.real)
138-
arpack_vals = arpack_vals[idx]
139-
arpack_vecs = arpack_vecs[:, idx]
140-
141-
matvecs = arpack_counter.matvecs
142-
n_iters = (matvecs - max_dim) // (max_dim - nev)
143-
print(f" matvecs={arpack_counter.matvecs}, elapsed={arpack_elapsed:.2f}s for {n_iters} iterations")
120+
arpack_vals, arpack_vecs, arpack_stats = arpack_eig(A, parameters)
121+
print(f" matvecs={arpack_stats.matvecs}, elapsed={arpack_stats.elapsed:.2f}s for {arpack_stats.restarts} iterations")
144122

145123
# ------------------------------------------------------------------
146124
# partial_schur
147125
# ------------------------------------------------------------------
148-
print(f"\n--- Running partial_schur (p = {p}) ---")
149-
ps_counter = MatvecCounter(A)
150-
t0 = time.perf_counter()
151-
pQ, pT, history = partial_schur(
152-
ps_counter,
153-
nev,
154-
max_dim=max_dim,
155-
stopping_criterion=tol,
156-
max_restarts=max_it,
157-
sort_function=sort_function,
158-
p=p,
159-
)
160-
ps_elapsed = time.perf_counter() - t0
161-
162-
# Extract eigenpairs from the partial Schur form
163-
ps_eig_vals, S = np.linalg.eig(pT)
164-
ps_eig_vecs = pQ @ S
165-
166-
idx = np.argsort(-ps_eig_vals.real)
167-
ps_eig_vals = ps_eig_vals[idx]
168-
ps_eig_vecs = ps_eig_vecs[:, idx]
169-
170-
ps_matvecs = int(np.max(history.matvecs))
171-
n_iters = np.max(history.restarts)
172-
print(f" matvecs={ps_matvecs}, elapsed={ps_elapsed:.2f}s")
173-
print(f" matvecs={ps_matvecs}, elapsed={ps_elapsed:.2f}s for {n_iters} iterations")
126+
print(f"\n--- Running partial_schur (p = {parameters.p}) ---")
127+
ps_vals, ps_vecs, ps_stats = arnoldi_py_eig(A, parameters)
128+
print(f" matvecs={ps_stats.matvecs}, elapsed={ps_stats.elapsed:.2f}s for {ps_stats.restarts} iterations")
174129

175130
# ------------------------------------------------------------------
176131
# True residuals
177132
# ------------------------------------------------------------------
178133
print_residuals("ARPACK", A, arpack_vals, arpack_vecs)
179-
print_residuals("partial_schur", A, ps_eig_vals, ps_eig_vecs)
134+
print_residuals("partial_schur", A, ps_vals, ps_vecs)
180135

181136
# ------------------------------------------------------------------
182137
# Matvec comparison
183138
# ------------------------------------------------------------------
184-
arpack_mv = arpack_counter.matvecs
185-
pct = (ps_matvecs - arpack_mv) / arpack_mv * 100
139+
arpack_matvecs = arpack_stats.matvecs
140+
ps_matvecs = ps_stats.matvecs
141+
pct = (ps_matvecs - arpack_matvecs) / arpack_matvecs * 100
186142
direction = "more" if pct >= 0 else "fewer"
187143

188-
print(f"\n--- Matvec comparison ---")
189-
print(f" ARPACK: {arpack_mv} matvecs ({arpack_elapsed:.2f}s)")
190-
print(f" partial_schur: {ps_matvecs} matvecs ({ps_elapsed:.2f}s)")
144+
print(f"\n--- Perf comparison ---")
145+
print(f" ARPACK: {arpack_matvecs} matvecs in {arpack_stats.restarts} iterations ({arpack_stats.elapsed:.2f}s)")
146+
print(f" partial_schur: {ps_matvecs} matvecs in {ps_stats.restarts} iterations ({ps_stats.elapsed:.2f}s)")
191147
print(f" partial_schur uses {abs(pct):.1f}% {direction} matvecs than ARPACK")
192-
print(history)
193148

194-
print(arpack_vals)
195-
print(ps_eig_vals)
149+
# print(arpack_vals)
150+
# print(ps_vals)
196151

197152
# Ensure the eigenvalues match. This check + ensure normalized residuals
198153
# are close to 0 should be enough to ensure the output is correct.
199-
x, y = find_best_matching(arpack_vals, ps_eig_vals)
154+
x, y = find_best_matching(arpack_vals, ps_vals)
200155
np.testing.assert_allclose(x, y, rtol=tol)
201156

202157

0 commit comments

Comments
 (0)