66// where A is (m x n) with m >> n. Sketching reduces the system from
77// (m x n) to (k x n) with n < k << m, then solves the smaller problem.
88
9- #include < ginkgo/ginkgo.hpp>
10-
119#include < chrono>
1210#include < iomanip>
1311#include < iostream>
1715
1816#include < cxxopts.hpp>
1917
18+ #include < ginkgo/ginkgo.hpp>
19+
2020
2121using ValueType = double ;
2222using RealValueType = gko::remove_complex<ValueType>;
@@ -58,9 +58,9 @@ struct LeastSquaresProblem {
5858 std::unique_ptr<vec> x_true;
5959};
6060
61- LeastSquaresProblem generate_problem (
62- std::shared_ptr< const gko::Executor> exec , gko::size_type m ,
63- gko::size_type n, unsigned int data_seed)
61+ LeastSquaresProblem generate_problem (std::shared_ptr< const gko::Executor> exec,
62+ gko::size_type m , gko::size_type n ,
63+ unsigned int data_seed)
6464{
6565 auto host = exec->get_master ();
6666 std::mt19937 rng (data_seed);
@@ -96,10 +96,9 @@ LeastSquaresProblem generate_problem(
9696
9797// Solve the sketched least-squares problem: min_x || S*A*x - S*b ||_2
9898// via normal equations on the sketched system: (SA)^T SA x = (SA)^T Sb
99- void sketch_and_solve (std::shared_ptr<const gko::Executor> exec,
100- const vec* A, const vec* b, vec* x,
101- gko::LinOp* sketch_op, const std::string& label,
102- int num_reps, long setup_us)
99+ void sketch_and_solve (std::shared_ptr<const gko::Executor> exec, const vec* A,
100+ const vec* b, vec* x, gko::LinOp* sketch_op,
101+ const std::string& label, int num_reps, long setup_us)
103102{
104103 auto m = A->get_size ()[0 ];
105104 auto n = A->get_size ()[1 ];
@@ -140,11 +139,10 @@ void sketch_and_solve(std::shared_ptr<const gko::Executor> exec,
140139
141140 auto solver =
142141 gko::solver::Gmres<ValueType>::build ()
143- .with_criteria (
144- gko::stop::Iteration::build ().with_max_iters (
145- static_cast <gko::uint32>(10 * n)),
146- gko::stop::ResidualNorm<ValueType>::build ()
147- .with_reduction_factor (RealValueType{1e-14 }))
142+ .with_criteria (gko::stop::Iteration::build ().with_max_iters (
143+ static_cast <gko::uint32>(10 * n)),
144+ gko::stop::ResidualNorm<ValueType>::build ()
145+ .with_reduction_factor (RealValueType{1e-14 }))
148146 .on (exec)
149147 ->generate (gko::share (std::move (AtA)));
150148
@@ -195,11 +193,10 @@ void direct_solve(std::shared_ptr<const gko::Executor> exec, const vec* A,
195193
196194 auto solver =
197195 gko::solver::Gmres<ValueType>::build ()
198- .with_criteria (
199- gko::stop::Iteration::build ().with_max_iters (
200- static_cast <gko::uint32>(10 * n)),
201- gko::stop::ResidualNorm<ValueType>::build ()
202- .with_reduction_factor (RealValueType{1e-14 }))
196+ .with_criteria (gko::stop::Iteration::build ().with_max_iters (
197+ static_cast <gko::uint32>(10 * n)),
198+ gko::stop::ResidualNorm<ValueType>::build ()
199+ .with_reduction_factor (RealValueType{1e-14 }))
203200 .on (exec)
204201 ->generate (gko::share (std::move (AtA)));
205202
@@ -222,10 +219,11 @@ void direct_solve(std::shared_ptr<const gko::Executor> exec, const vec* A,
222219 auto rel_res = res_norm->at (0 , 0 ) / b_norm->at (0 , 0 );
223220
224221 std::cout << std::left << std::setw (16 ) << " Direct" << std::right
225- << std::setw (10 ) << " ---" << " " << std::setw (10 ) << " ---"
226- << " " << std::setw (10 ) << solve_us << " us"
227- << std::setw (14 ) << std::scientific << std::setprecision (4 )
228- << rel_res << std::endl;
222+ << std::setw (10 ) << " ---"
223+ << " " << std::setw (10 ) << " ---"
224+ << " " << std::setw (10 ) << solve_us << " us" << std::setw (14 )
225+ << std::scientific << std::setprecision (4 ) << rel_res
226+ << std::endl;
229227}
230228
231229
@@ -284,8 +282,8 @@ int main(int argc, char* argv[])
284282
285283 std::cout << " Problem: " << m << " x " << n << " (overdetermined"
286284 << " , ratio " << m / n << " :1)"
287- << " \n Sketch size: " << k << " (compression "
288- << std::fixed << std:: setprecision (1 )
285+ << " \n Sketch size: " << k << " (compression " << std::fixed
286+ << std::setprecision (1 )
289287 << static_cast <double >(m) / static_cast <double >(k) << " x)"
290288 << " \n Executor: " << exec_name << " Seed: " << seed
291289 << " Reps: " << num_reps << " \n "
@@ -297,23 +295,24 @@ int main(int argc, char* argv[])
297295 // Header
298296 // Setup = time to create the sketch operator (generate random data)
299297 // Sketch = time to compute S*A and S*b
300- // Solve = time to form normal equations + CG solve (on sketched or full system)
298+ // Solve = time to form normal equations + GMRES solve (on sketched or
299+ // full system)
301300 std::cout << std::left << std::setw (16 ) << " Method" << std::right
302301 << std::setw (13 ) << " Setup" << std::setw (13 ) << " Sketch"
303- << std::setw (13 ) << " Solve" << std::setw (14 )
304- << " ||Ax-b||/||b|| " << " \n "
302+ << std::setw (13 ) << " Solve" << std::setw (14 ) << " ||Ax-b||/||b|| "
303+ << " \n "
305304 << std::string (69 , ' -' ) << std::endl;
306305
307306 if (sketch_type == " all" || sketch_type == " gaussian" ) {
308307 exec->synchronize ();
309308 auto t0 = std::chrono::steady_clock::now ();
310- auto gaussian = gko::sketch::GaussianSketch<ValueType>:: create (
311- exec, k, m, seed);
309+ auto gaussian =
310+ gko::sketch::GaussianSketch<ValueType>:: create ( exec, k, m, seed);
312311 exec->synchronize ();
313312 auto t1 = std::chrono::steady_clock::now ();
314- auto setup_us = std::chrono::duration_cast<std::chrono::microseconds>(
315- t1 - t0)
316- .count ();
313+ auto setup_us =
314+ std::chrono::duration_cast<std::chrono::microseconds>( t1 - t0)
315+ .count ();
317316 auto x = vec::create (exec, gko::dim<2 >{n, 1 });
318317 sketch_and_solve (exec, problem.A .get (), problem.b .get (), x.get (),
319318 gaussian.get (), " Gaussian" , num_reps, setup_us);
@@ -326,9 +325,9 @@ int main(int argc, char* argv[])
326325 exec, k, m, seed);
327326 exec->synchronize ();
328327 auto t1 = std::chrono::steady_clock::now ();
329- auto setup_us = std::chrono::duration_cast<std::chrono::microseconds>(
330- t1 - t0)
331- .count ();
328+ auto setup_us =
329+ std::chrono::duration_cast<std::chrono::microseconds>( t1 - t0)
330+ .count ();
332331 auto x = vec::create (exec, gko::dim<2 >{n, 1 });
333332 sketch_and_solve (exec, problem.A .get (), problem.b .get (), x.get (),
334333 cs.get (), " CountSketch" , num_reps, setup_us);
@@ -337,17 +336,17 @@ int main(int argc, char* argv[])
337336 if (sketch_type == " all" || sketch_type == " sparsestack" ) {
338337 exec->synchronize ();
339338 auto t0 = std::chrono::steady_clock::now ();
340-
339+
341340 auto ss = gko::sketch::SparseStack<ValueType, IndexType>::create (
342341 exec, k, m, zeta, seed);
343-
342+
344343 exec->synchronize ();
345344 auto t1 = std::chrono::steady_clock::now ();
346- auto setup_us = std::chrono::duration_cast<std::chrono::microseconds>(
347- t1 - t0)
348- .count ();
345+ auto setup_us =
346+ std::chrono::duration_cast<std::chrono::microseconds>( t1 - t0)
347+ .count ();
349348 auto x = vec::create (exec, gko::dim<2 >{n, 1 });
350-
349+
351350 sketch_and_solve (exec, problem.A .get (), problem.b .get (), x.get (),
352351 ss.get (), " SparseStack" , num_reps, setup_us);
353352 }
0 commit comments