|
10 | 10 | import scipy.sparse as sp |
11 | 11 |
|
12 | 12 | from scipy.linalg import toeplitz |
13 | | -from scipy.sparse.linalg import eigs, LinearOperator |
| 13 | +from scipy.sparse.linalg import LinearOperator |
14 | 14 |
|
15 | | -from arnoldi.krylov_schur import partial_schur |
16 | 15 | from arnoldi.utils import arg_largest_magnitude, arg_largest_real |
17 | 16 |
|
18 | 17 |
|
19 | 18 | HERE = os.path.dirname(__file__) |
20 | 19 | sys.path.insert(0, HERE) |
21 | 20 |
|
22 | 21 | from utils import ( |
23 | | - WHICH_TO_SORT, MatvecCounter, find_best_matching, load_suitesparse_mat, |
24 | | - print_residuals |
| 22 | + WHICH_TO_SORT, EigensolverParameters, MatvecCounter, arnoldi_py_eig, arpack_eig, |
| 23 | + find_best_matching, load_suitesparse_mat, print_residuals |
25 | 24 | ) |
26 | 25 |
|
27 | 26 |
|
@@ -108,95 +107,51 @@ def main(): |
108 | 107 | # complex types, so we cast both solvers to the same dtype. |
109 | 108 | A = A_raw.astype(np.complex128) |
110 | 109 |
|
111 | | - max_dim = args.max_dim if args.max_dim is not None else min(max(2 * nev + 1, 20), n) |
112 | | - |
| 110 | + parameters = EigensolverParameters.from_cli_args(args, n) |
| 111 | + parameters.p = p |
113 | 112 | print(f"Matrix: {args.mat_file}") |
114 | 113 | print(f" shape={n}x{n}, nnz={nnz}, dtype={A.dtype}") |
115 | | - print( |
116 | | - f" nev={nev}, tol={tol}, max_dim={max_dim}, " |
117 | | - f"max_restarts={max_it}, which={which}" |
118 | | - ) |
| 114 | + print(parameters) |
119 | 115 |
|
120 | 116 | # ------------------------------------------------------------------ |
121 | 117 | # ARPACK |
122 | 118 | # ------------------------------------------------------------------ |
123 | 119 | print(f"\n--- Running ARPACK ---") |
124 | | - arpack_counter = MatvecCounter(A) |
125 | | - t0 = time.perf_counter() |
126 | | - arpack_vals, arpack_vecs = eigs( |
127 | | - arpack_counter, |
128 | | - k=nev, |
129 | | - which=which, |
130 | | - ncv=max_dim, |
131 | | - tol=tol, |
132 | | - maxiter=max_it, |
133 | | - ) |
134 | | - arpack_elapsed = time.perf_counter() - t0 |
135 | | - |
136 | | - # Sort by descending real part for consistent display |
137 | | - idx = np.argsort(-arpack_vals.real) |
138 | | - arpack_vals = arpack_vals[idx] |
139 | | - arpack_vecs = arpack_vecs[:, idx] |
140 | | - |
141 | | - matvecs = arpack_counter.matvecs |
142 | | - n_iters = (matvecs - max_dim) // (max_dim - nev) |
143 | | - print(f" matvecs={arpack_counter.matvecs}, elapsed={arpack_elapsed:.2f}s for {n_iters} iterations") |
| 120 | + arpack_vals, arpack_vecs, arpack_stats = arpack_eig(A, parameters) |
| 121 | + print(f" matvecs={arpack_stats.matvecs}, elapsed={arpack_stats.elapsed:.2f}s for {arpack_stats.restarts} iterations") |
144 | 122 |
|
145 | 123 | # ------------------------------------------------------------------ |
146 | 124 | # partial_schur |
147 | 125 | # ------------------------------------------------------------------ |
148 | | - print(f"\n--- Running partial_schur (p = {p}) ---") |
149 | | - ps_counter = MatvecCounter(A) |
150 | | - t0 = time.perf_counter() |
151 | | - pQ, pT, history = partial_schur( |
152 | | - ps_counter, |
153 | | - nev, |
154 | | - max_dim=max_dim, |
155 | | - stopping_criterion=tol, |
156 | | - max_restarts=max_it, |
157 | | - sort_function=sort_function, |
158 | | - p=p, |
159 | | - ) |
160 | | - ps_elapsed = time.perf_counter() - t0 |
161 | | - |
162 | | - # Extract eigenpairs from the partial Schur form |
163 | | - ps_eig_vals, S = np.linalg.eig(pT) |
164 | | - ps_eig_vecs = pQ @ S |
165 | | - |
166 | | - idx = np.argsort(-ps_eig_vals.real) |
167 | | - ps_eig_vals = ps_eig_vals[idx] |
168 | | - ps_eig_vecs = ps_eig_vecs[:, idx] |
169 | | - |
170 | | - ps_matvecs = int(np.max(history.matvecs)) |
171 | | - n_iters = np.max(history.restarts) |
172 | | - print(f" matvecs={ps_matvecs}, elapsed={ps_elapsed:.2f}s") |
173 | | - print(f" matvecs={ps_matvecs}, elapsed={ps_elapsed:.2f}s for {n_iters} iterations") |
| 126 | + print(f"\n--- Running partial_schur (p = {parameters.p}) ---") |
| 127 | + ps_vals, ps_vecs, ps_stats = arnoldi_py_eig(A, parameters) |
| 128 | + print(f" matvecs={ps_stats.matvecs}, elapsed={ps_stats.elapsed:.2f}s for {ps_stats.restarts} iterations") |
174 | 129 |
|
175 | 130 | # ------------------------------------------------------------------ |
176 | 131 | # True residuals |
177 | 132 | # ------------------------------------------------------------------ |
178 | 133 | print_residuals("ARPACK", A, arpack_vals, arpack_vecs) |
179 | | - print_residuals("partial_schur", A, ps_eig_vals, ps_eig_vecs) |
| 134 | + print_residuals("partial_schur", A, ps_vals, ps_vecs) |
180 | 135 |
|
181 | 136 | # ------------------------------------------------------------------ |
182 | 137 | # Matvec comparison |
183 | 138 | # ------------------------------------------------------------------ |
184 | | - arpack_mv = arpack_counter.matvecs |
185 | | - pct = (ps_matvecs - arpack_mv) / arpack_mv * 100 |
| 139 | + arpack_matvecs = arpack_stats.matvecs |
| 140 | + ps_matvecs = ps_stats.matvecs |
| 141 | + pct = (ps_matvecs - arpack_matvecs) / arpack_matvecs * 100 |
186 | 142 | direction = "more" if pct >= 0 else "fewer" |
187 | 143 |
|
188 | | - print(f"\n--- Matvec comparison ---") |
189 | | - print(f" ARPACK: {arpack_mv} matvecs ({arpack_elapsed:.2f}s)") |
190 | | - print(f" partial_schur: {ps_matvecs} matvecs ({ps_elapsed:.2f}s)") |
| 144 | + print(f"\n--- Perf comparison ---") |
| 145 | + print(f" ARPACK: {arpack_matvecs} matvecs in {arpack_stats.restarts} iterations ({arpack_stats.elapsed:.2f}s)") |
| 146 | + print(f" partial_schur: {ps_matvecs} matvecs in {ps_stats.restarts} iterations ({ps_stats.elapsed:.2f}s)") |
191 | 147 | print(f" partial_schur uses {abs(pct):.1f}% {direction} matvecs than ARPACK") |
192 | | - print(history) |
193 | 148 |
|
194 | | - print(arpack_vals) |
195 | | - print(ps_eig_vals) |
| 149 | + # print(arpack_vals) |
| 150 | + # print(ps_vals) |
196 | 151 |
|
197 | 152 | # Ensure the eigenvalues match. This check + ensure normalized residuals |
198 | 153 | # are close to 0 should be enough to ensure the output is correct. |
199 | | - x, y = find_best_matching(arpack_vals, ps_eig_vals) |
| 154 | + x, y = find_best_matching(arpack_vals, ps_vals) |
200 | 155 | np.testing.assert_allclose(x, y, rtol=tol) |
201 | 156 |
|
202 | 157 |
|
|
0 commit comments