Skip to content

Commit 4f0036b

Browse files
yhmtsaiMarcelKoch
andcommitted
improve the precision list type traits
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
1 parent 66ad642 commit 4f0036b

31 files changed

Lines changed: 339 additions & 431 deletions

core/base/batch_multi_vector.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ void MultiVector<ValueType>::move_to(
299299
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
300300
template <typename ValueType>
301301
void MultiVector<ValueType>::convert_to(
302-
MultiVector<next_precision_move<ValueType, 2>>* result) const
302+
MultiVector<next_precision<ValueType, 2>>* result) const
303303
{
304304
result->values_ = this->values_;
305305
result->set_size(this->get_size());
@@ -308,7 +308,7 @@ void MultiVector<ValueType>::convert_to(
308308

309309
template <typename ValueType>
310310
void MultiVector<ValueType>::move_to(
311-
MultiVector<next_precision_move<ValueType, 2>>* result)
311+
MultiVector<next_precision<ValueType, 2>>* result)
312312
{
313313
this->convert_to(result);
314314
}
@@ -318,7 +318,7 @@ void MultiVector<ValueType>::move_to(
318318
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
319319
template <typename ValueType>
320320
void MultiVector<ValueType>::convert_to(
321-
MultiVector<next_precision_move<ValueType, 3>>* result) const
321+
MultiVector<next_precision<ValueType, 3>>* result) const
322322
{
323323
result->values_ = this->values_;
324324
result->set_size(this->get_size());
@@ -327,7 +327,7 @@ void MultiVector<ValueType>::convert_to(
327327

328328
template <typename ValueType>
329329
void MultiVector<ValueType>::move_to(
330-
MultiVector<next_precision_move<ValueType, 3>>* result)
330+
MultiVector<next_precision<ValueType, 3>>* result)
331331
{
332332
this->convert_to(result);
333333
}

core/distributed/matrix.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
307307
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
308308
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
309309
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
310-
Matrix<next_precision_move<value_type, 2>, local_index_type,
311-
global_index_type>* result) const
310+
Matrix<next_precision<value_type, 2>, local_index_type, global_index_type>*
311+
result) const
312312
{
313313
GKO_ASSERT(this->get_communicator().size() ==
314314
result->get_communicator().size());
@@ -326,8 +326,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
326326

327327
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
328328
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
329-
Matrix<next_precision_move<value_type, 2>, local_index_type,
330-
global_index_type>* result)
329+
Matrix<next_precision<value_type, 2>, local_index_type, global_index_type>*
330+
result)
331331
{
332332
GKO_ASSERT(this->get_communicator().size() ==
333333
result->get_communicator().size());
@@ -348,8 +348,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
348348
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
349349
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
350350
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
351-
Matrix<next_precision_move<value_type, 3>, local_index_type,
352-
global_index_type>* result) const
351+
Matrix<next_precision<value_type, 3>, local_index_type, global_index_type>*
352+
result) const
353353
{
354354
GKO_ASSERT(this->get_communicator().size() ==
355355
result->get_communicator().size());
@@ -367,8 +367,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
367367

368368
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
369369
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
370-
Matrix<next_precision_move<value_type, 3>, local_index_type,
371-
global_index_type>* result)
370+
Matrix<next_precision<value_type, 3>, local_index_type, global_index_type>*
371+
result)
372372
{
373373
GKO_ASSERT(this->get_communicator().size() ==
374374
result->get_communicator().size());

core/distributed/vector.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ void Vector<ValueType>::move_to(Vector<next_precision<ValueType>>* result)
299299
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
300300
template <typename ValueType>
301301
void Vector<ValueType>::convert_to(
302-
Vector<next_precision_move<ValueType, 2>>* result) const
302+
Vector<next_precision<ValueType, 2>>* result) const
303303
{
304304
GKO_ASSERT(this->get_communicator().size() ==
305305
result->get_communicator().size());
@@ -309,8 +309,7 @@ void Vector<ValueType>::convert_to(
309309

310310

311311
template <typename ValueType>
312-
void Vector<ValueType>::move_to(
313-
Vector<next_precision_move<ValueType, 2>>* result)
312+
void Vector<ValueType>::move_to(Vector<next_precision<ValueType, 2>>* result)
314313
{
315314
this->convert_to(result);
316315
}
@@ -320,7 +319,7 @@ void Vector<ValueType>::move_to(
320319
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
321320
template <typename ValueType>
322321
void Vector<ValueType>::convert_to(
323-
Vector<next_precision_move<ValueType, 3>>* result) const
322+
Vector<next_precision<ValueType, 3>>* result) const
324323
{
325324
GKO_ASSERT(this->get_communicator().size() ==
326325
result->get_communicator().size());
@@ -330,8 +329,7 @@ void Vector<ValueType>::convert_to(
330329

331330

332331
template <typename ValueType>
333-
void Vector<ValueType>::move_to(
334-
Vector<next_precision_move<ValueType, 3>>* result)
332+
void Vector<ValueType>::move_to(Vector<next_precision<ValueType, 3>>* result)
335333
{
336334
this->convert_to(result);
337335
}

core/matrix/batch_csr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void Csr<ValueType, IndexType>::move_to(
266266
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
267267
template <typename ValueType, typename IndexType>
268268
void Csr<ValueType, IndexType>::convert_to(
269-
Csr<next_precision_move<ValueType, 2>, IndexType>* result) const
269+
Csr<next_precision<ValueType, 2>, IndexType>* result) const
270270
{
271271
result->values_ = this->values_;
272272
result->col_idxs_ = this->col_idxs_;
@@ -277,7 +277,7 @@ void Csr<ValueType, IndexType>::convert_to(
277277

278278
template <typename ValueType, typename IndexType>
279279
void Csr<ValueType, IndexType>::move_to(
280-
Csr<next_precision_move<ValueType, 2>, IndexType>* result)
280+
Csr<next_precision<ValueType, 2>, IndexType>* result)
281281
{
282282
this->convert_to(result);
283283
}
@@ -287,7 +287,7 @@ void Csr<ValueType, IndexType>::move_to(
287287
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
288288
template <typename ValueType, typename IndexType>
289289
void Csr<ValueType, IndexType>::convert_to(
290-
Csr<next_precision_move<ValueType, 3>, IndexType>* result) const
290+
Csr<next_precision<ValueType, 3>, IndexType>* result) const
291291
{
292292
result->values_ = this->values_;
293293
result->col_idxs_ = this->col_idxs_;
@@ -298,7 +298,7 @@ void Csr<ValueType, IndexType>::convert_to(
298298

299299
template <typename ValueType, typename IndexType>
300300
void Csr<ValueType, IndexType>::move_to(
301-
Csr<next_precision_move<ValueType, 3>, IndexType>* result)
301+
Csr<next_precision<ValueType, 3>, IndexType>* result)
302302
{
303303
this->convert_to(result);
304304
}

core/matrix/batch_dense.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,15 @@ void Dense<ValueType>::move_to(Dense<next_precision<ValueType>>* result)
262262
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
263263
template <typename ValueType>
264264
void Dense<ValueType>::convert_to(
265-
Dense<next_precision_move<ValueType, 2>>* result) const
265+
Dense<next_precision<ValueType, 2>>* result) const
266266
{
267267
result->values_ = this->values_;
268268
result->set_size(this->get_size());
269269
}
270270

271271

272272
template <typename ValueType>
273-
void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 2>>* result)
273+
void Dense<ValueType>::move_to(Dense<next_precision<ValueType, 2>>* result)
274274
{
275275
this->convert_to(result);
276276
}
@@ -280,15 +280,15 @@ void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 2>>* result)
280280
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
281281
template <typename ValueType>
282282
void Dense<ValueType>::convert_to(
283-
Dense<next_precision_move<ValueType, 3>>* result) const
283+
Dense<next_precision<ValueType, 3>>* result) const
284284
{
285285
result->values_ = this->values_;
286286
result->set_size(this->get_size());
287287
}
288288

289289

290290
template <typename ValueType>
291-
void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 3>>* result)
291+
void Dense<ValueType>::move_to(Dense<next_precision<ValueType, 3>>* result)
292292
{
293293
this->convert_to(result);
294294
}

core/matrix/batch_ell.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ void Ell<ValueType, IndexType>::move_to(
286286
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
287287
template <typename ValueType, typename IndexType>
288288
void Ell<ValueType, IndexType>::convert_to(
289-
Ell<next_precision_move<ValueType, 2>, IndexType>* result) const
289+
Ell<next_precision<ValueType, 2>, IndexType>* result) const
290290
{
291291
result->values_ = this->values_;
292292
result->col_idxs_ = this->col_idxs_;
@@ -297,7 +297,7 @@ void Ell<ValueType, IndexType>::convert_to(
297297

298298
template <typename ValueType, typename IndexType>
299299
void Ell<ValueType, IndexType>::move_to(
300-
Ell<next_precision_move<ValueType, 2>, IndexType>* result)
300+
Ell<next_precision<ValueType, 2>, IndexType>* result)
301301
{
302302
this->convert_to(result);
303303
}
@@ -307,7 +307,7 @@ void Ell<ValueType, IndexType>::move_to(
307307
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
308308
template <typename ValueType, typename IndexType>
309309
void Ell<ValueType, IndexType>::convert_to(
310-
Ell<next_precision_move<ValueType, 3>, IndexType>* result) const
310+
Ell<next_precision<ValueType, 3>, IndexType>* result) const
311311
{
312312
result->values_ = this->values_;
313313
result->col_idxs_ = this->col_idxs_;
@@ -318,7 +318,7 @@ void Ell<ValueType, IndexType>::convert_to(
318318

319319
template <typename ValueType, typename IndexType>
320320
void Ell<ValueType, IndexType>::move_to(
321-
Ell<next_precision_move<ValueType, 3>, IndexType>* result)
321+
Ell<next_precision<ValueType, 3>, IndexType>* result)
322322
{
323323
this->convert_to(result);
324324
}

core/matrix/coo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ void Coo<ValueType, IndexType>::move_to(
236236
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
237237
template <typename ValueType, typename IndexType>
238238
void Coo<ValueType, IndexType>::convert_to(
239-
Coo<next_precision_move<ValueType, 2>, IndexType>* result) const
239+
Coo<next_precision<ValueType, 2>, IndexType>* result) const
240240
{
241241
result->values_ = this->values_;
242242
result->row_idxs_ = this->row_idxs_;
@@ -247,7 +247,7 @@ void Coo<ValueType, IndexType>::convert_to(
247247

248248
template <typename ValueType, typename IndexType>
249249
void Coo<ValueType, IndexType>::move_to(
250-
Coo<next_precision_move<ValueType, 2>, IndexType>* result)
250+
Coo<next_precision<ValueType, 2>, IndexType>* result)
251251
{
252252
this->convert_to(result);
253253
}
@@ -257,7 +257,7 @@ void Coo<ValueType, IndexType>::move_to(
257257
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
258258
template <typename ValueType, typename IndexType>
259259
void Coo<ValueType, IndexType>::convert_to(
260-
Coo<next_precision_move<ValueType, 3>, IndexType>* result) const
260+
Coo<next_precision<ValueType, 3>, IndexType>* result) const
261261
{
262262
result->values_ = this->values_;
263263
result->row_idxs_ = this->row_idxs_;
@@ -268,7 +268,7 @@ void Coo<ValueType, IndexType>::convert_to(
268268

269269
template <typename ValueType, typename IndexType>
270270
void Coo<ValueType, IndexType>::move_to(
271-
Coo<next_precision_move<ValueType, 3>, IndexType>* result)
271+
Coo<next_precision<ValueType, 3>, IndexType>* result)
272272
{
273273
this->convert_to(result);
274274
}

core/matrix/csr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ void Csr<ValueType, IndexType>::move_to(
327327
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
328328
template <typename ValueType, typename IndexType>
329329
void Csr<ValueType, IndexType>::convert_to(
330-
Csr<next_precision_move<ValueType, 2>, IndexType>* result) const
330+
Csr<next_precision<ValueType, 2>, IndexType>* result) const
331331
{
332332
result->values_ = this->values_;
333333
result->col_idxs_ = this->col_idxs_;
@@ -339,7 +339,7 @@ void Csr<ValueType, IndexType>::convert_to(
339339

340340
template <typename ValueType, typename IndexType>
341341
void Csr<ValueType, IndexType>::move_to(
342-
Csr<next_precision_move<ValueType, 2>, IndexType>* result)
342+
Csr<next_precision<ValueType, 2>, IndexType>* result)
343343
{
344344
this->convert_to(result);
345345
}
@@ -349,7 +349,7 @@ void Csr<ValueType, IndexType>::move_to(
349349
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
350350
template <typename ValueType, typename IndexType>
351351
void Csr<ValueType, IndexType>::convert_to(
352-
Csr<next_precision_move<ValueType, 3>, IndexType>* result) const
352+
Csr<next_precision<ValueType, 3>, IndexType>* result) const
353353
{
354354
result->values_ = this->values_;
355355
result->col_idxs_ = this->col_idxs_;
@@ -361,7 +361,7 @@ void Csr<ValueType, IndexType>::convert_to(
361361

362362
template <typename ValueType, typename IndexType>
363363
void Csr<ValueType, IndexType>::move_to(
364-
Csr<next_precision_move<ValueType, 3>, IndexType>* result)
364+
Csr<next_precision<ValueType, 3>, IndexType>* result)
365365
{
366366
this->convert_to(result);
367367
}

core/matrix/dense.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ void Dense<ValueType>::move_to(Dense<next_precision<ValueType>>* result)
606606
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
607607
template <typename ValueType>
608608
void Dense<ValueType>::convert_to(
609-
Dense<next_precision_move<ValueType, 2>>* result) const
609+
Dense<next_precision<ValueType, 2>>* result) const
610610
{
611611
if (result->get_size() != this->get_size()) {
612612
result->set_size(this->get_size());
@@ -621,7 +621,7 @@ void Dense<ValueType>::convert_to(
621621

622622

623623
template <typename ValueType>
624-
void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 2>>* result)
624+
void Dense<ValueType>::move_to(Dense<next_precision<ValueType, 2>>* result)
625625
{
626626
this->convert_to(result);
627627
}
@@ -631,7 +631,7 @@ void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 2>>* result)
631631
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
632632
template <typename ValueType>
633633
void Dense<ValueType>::convert_to(
634-
Dense<next_precision_move<ValueType, 3>>* result) const
634+
Dense<next_precision<ValueType, 3>>* result) const
635635
{
636636
if (result->get_size() != this->get_size()) {
637637
result->set_size(this->get_size());
@@ -646,7 +646,7 @@ void Dense<ValueType>::convert_to(
646646

647647

648648
template <typename ValueType>
649-
void Dense<ValueType>::move_to(Dense<next_precision_move<ValueType, 3>>* result)
649+
void Dense<ValueType>::move_to(Dense<next_precision<ValueType, 3>>* result)
650650
{
651651
this->convert_to(result);
652652
}
@@ -1570,8 +1570,7 @@ void gather_mixed_real_complex(Function fn, LinOp* out)
15701570
{
15711571
#ifdef GINKGO_MIXED_PRECISION
15721572
run<matrix::Dense, ValueType, next_precision<ValueType>,
1573-
next_precision_move<ValueType, 2>, next_precision_move<ValueType, 3>>(
1574-
out, fn);
1573+
next_precision<ValueType, 2>, next_precision<ValueType, 3>>(out, fn);
15751574
#else
15761575
precision_dispatch<ValueType>(fn, out);
15771576
#endif

core/matrix/diagonal.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ void Diagonal<ValueType>::move_to(Diagonal<next_precision<ValueType>>* result)
166166
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
167167
template <typename ValueType>
168168
void Diagonal<ValueType>::convert_to(
169-
Diagonal<next_precision_move<ValueType, 2>>* result) const
169+
Diagonal<next_precision<ValueType, 2>>* result) const
170170
{
171171
result->values_ = this->values_;
172172
result->set_size(this->get_size());
@@ -175,7 +175,7 @@ void Diagonal<ValueType>::convert_to(
175175

176176
template <typename ValueType>
177177
void Diagonal<ValueType>::move_to(
178-
Diagonal<next_precision_move<ValueType, 2>>* result)
178+
Diagonal<next_precision<ValueType, 2>>* result)
179179
{
180180
this->convert_to(result);
181181
}
@@ -185,7 +185,7 @@ void Diagonal<ValueType>::move_to(
185185
#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
186186
template <typename ValueType>
187187
void Diagonal<ValueType>::convert_to(
188-
Diagonal<next_precision_move<ValueType, 3>>* result) const
188+
Diagonal<next_precision<ValueType, 3>>* result) const
189189
{
190190
result->values_ = this->values_;
191191
result->set_size(this->get_size());
@@ -194,7 +194,7 @@ void Diagonal<ValueType>::convert_to(
194194

195195
template <typename ValueType>
196196
void Diagonal<ValueType>::move_to(
197-
Diagonal<next_precision_move<ValueType, 3>>* result)
197+
Diagonal<next_precision<ValueType, 3>>* result)
198198
{
199199
this->convert_to(result);
200200
}

0 commit comments

Comments
 (0)