Skip to content

Commit afd1d9f

Browse files
committed
split the dispatch for half and bfloat16
1 parent 4fcf21e commit afd1d9f

11 files changed

Lines changed: 101 additions & 22 deletions

File tree

core/base/block_operator.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ template <typename Fn>
2020
auto dispatch_dense(Fn&& fn, LinOp* v)
2121
{
2222
return run<matrix::Dense, float, double,
23-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
23+
#if GINKGO_ENABLE_HALF
2424
float16, std::complex<float16>,
25+
#endif
26+
#if GINKGO_ENABLE_BFLOAT16
27+
bfloat16, std::complex<bfloat16>,
2528
#endif
2629
std::complex<float>, std::complex<double>>(v,
2730
std::forward<Fn>(fn));

core/base/mtx_io.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,15 +900,19 @@ matrix_data<ValueType, IndexType> read_binary_raw(std::istream& is)
900900
DECLARE_OVERLOAD(double, int32)
901901
DECLARE_OVERLOAD(float, int32)
902902
DECLARE_OVERLOAD(float16, int32)
903+
DECLARE_OVERLOAD(bfloat16, int32)
903904
DECLARE_OVERLOAD(std::complex<double>, int32)
904905
DECLARE_OVERLOAD(std::complex<float>, int32)
905906
DECLARE_OVERLOAD(std::complex<float16>, int32)
907+
DECLARE_OVERLOAD(std::complex<bfloat16>, int32)
906908
DECLARE_OVERLOAD(double, int64)
907909
DECLARE_OVERLOAD(float, int64)
908910
DECLARE_OVERLOAD(float16, int64)
911+
DECLARE_OVERLOAD(bfloat16, int64)
909912
DECLARE_OVERLOAD(std::complex<double>, int64)
910913
DECLARE_OVERLOAD(std::complex<float>, int64)
911914
DECLARE_OVERLOAD(std::complex<float16>, int64)
915+
DECLARE_OVERLOAD(std::complex<bfloat16>, int64)
912916
#undef DECLARE_OVERLOAD
913917
else
914918
{

core/config/dispatch.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ deferred_factory_parameter<ReturnType> dispatch(
105105
using value_type_list_base =
106106
syn::type_list<double, float, std::complex<double>, std::complex<float>>;
107107

108-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
109108
using value_type_list =
110-
syn::type_list<double, float, float16, std::complex<double>,
111-
std::complex<float>, std::complex<float16>>;
112-
#else
113-
using value_type_list = value_type_list_base;
114-
#endif // GINKGO_ENABLE_HALF
109+
syn::type_list<double, float,
110+
#if GINKGO_ENABLE_HALF
111+
float16, std::complex<float16>,
112+
#endif
113+
#if GINKGO_ENABLE_BFLOAT16
114+
bfloat16, std::complex<bfloat16>,
115+
#endif
116+
std::complex<double>, std::complex<float>>;
115117

116118
using index_type_list = syn::type_list<int32, int64>;
117119

core/distributed/helpers.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,21 @@ auto run_matrix(T* linop, F&& f, Args&&... args)
159159
with_same_constness_t<Matrix<float, int32, int32>, T>,
160160
with_same_constness_t<Matrix<float, int32, int64>, T>,
161161
with_same_constness_t<Matrix<float, int64, int64>, T>,
162-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
162+
#if GINKGO_ENABLE_HALF
163163
with_same_constness_t<Matrix<float16, int32, int32>, T>,
164164
with_same_constness_t<Matrix<float16, int32, int64>, T>,
165165
with_same_constness_t<Matrix<float16, int64, int64>, T>,
166166
with_same_constness_t<Matrix<std::complex<float16>, int32, int32>, T>,
167167
with_same_constness_t<Matrix<std::complex<float16>, int32, int64>, T>,
168168
with_same_constness_t<Matrix<std::complex<float16>, int64, int64>, T>,
169+
#endif
170+
#if GINKGO_ENABLE_BFLOAT16
171+
with_same_constness_t<Matrix<bfloat16, int32, int32>, T>,
172+
with_same_constness_t<Matrix<bfloat16, int32, int64>, T>,
173+
with_same_constness_t<Matrix<bfloat16, int64, int64>, T>,
174+
with_same_constness_t<Matrix<std::complex<bfloat16>, int32, int32>, T>,
175+
with_same_constness_t<Matrix<std::complex<bfloat16>, int32, int64>, T>,
176+
with_same_constness_t<Matrix<std::complex<bfloat16>, int64, int64>, T>,
169177
#endif
170178
with_same_constness_t<Matrix<std::complex<double>, int32, int32>, T>,
171179
with_same_constness_t<Matrix<std::complex<double>, int32, int64>, T>,

core/log/solver_progress.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,21 @@ class SolverProgressStore : public SolverProgress {
247247
run<gko::matrix::Dense<double>, gko::matrix::Dense<float>,
248248
gko::matrix::Dense<std::complex<double>>,
249249
gko::matrix::Dense<std::complex<float>>,
250-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
250+
#if GINKGO_ENABLE_HALF
251251
gko::matrix::Dense<gko::float16>,
252252
gko::matrix::Dense<std::complex<gko::float16>>,
253253
gko::WritableToMatrixData<gko::float16, int32>,
254254
gko::WritableToMatrixData<std::complex<gko::float16>, int32>,
255255
gko::WritableToMatrixData<gko::float16, int64>,
256256
gko::WritableToMatrixData<std::complex<gko::float16>, int64>,
257+
#endif
258+
#if GINKGO_ENABLE_BFLOAT16
259+
gko::matrix::Dense<gko::bfloat16>,
260+
gko::matrix::Dense<std::complex<gko::bfloat16>>,
261+
gko::WritableToMatrixData<gko::bfloat16, int32>,
262+
gko::WritableToMatrixData<std::complex<gko::bfloat16>, int32>,
263+
gko::WritableToMatrixData<gko::bfloat16, int64>,
264+
gko::WritableToMatrixData<std::complex<gko::bfloat16>, int64>,
257265
#endif
258266
// fallback for other matrix types
259267
gko::WritableToMatrixData<double, int32>,

core/matrix/dense.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1570,7 +1570,8 @@ void gather_mixed_real_complex(Function fn, LinOp* out)
15701570
{
15711571
#ifdef GINKGO_MIXED_PRECISION
15721572
run<matrix::Dense, ValueType, next_precision<ValueType>,
1573-
next_precision_move<ValueType, 2>>(out, fn);
1573+
next_precision_move<ValueType, 2>, next_precision_move<ValueType, 3>>(
1574+
out, fn);
15741575
#else
15751576
precision_dispatch<ValueType>(fn, out);
15761577
#endif

core/matrix/permutation.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,11 @@ void dispatch_dense(const LinOp* op, Functor fn)
268268
using matrix::Dense;
269269
using std::complex;
270270
run<Dense,
271-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
271+
#if GINKGO_ENABLE_HALF
272272
gko::float16, std::complex<gko::float16>,
273+
#endif
274+
#if GINKGO_ENABLE_BFLOAT16
275+
gko::bfloat16, std::complex<gko::bfloat16>,
273276
#endif
274277
double, float, std::complex<double>, std::complex<float>>(op, fn);
275278
}

core/matrix/row_gatherer.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@ template <typename IndexType>
6666
void RowGatherer<IndexType>::apply_impl(const LinOp* in, LinOp* out) const
6767
{
6868
run<Dense,
69-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
69+
#if GINKGO_ENABLE_HALF
7070
gko::float16, std::complex<gko::float16>,
71+
#endif
72+
#if GINKGO_ENABLE_BFLOAT16
73+
gko::bfloat16, std::complex<gko::bfloat16>,
7174
#endif
7275
float, double, std::complex<float>, std::complex<double>>(
7376
in, [&](auto gather) { gather->row_gather(&row_idxs_, out); });
@@ -78,8 +81,11 @@ void RowGatherer<IndexType>::apply_impl(const LinOp* alpha, const LinOp* in,
7881
const LinOp* beta, LinOp* out) const
7982
{
8083
run<Dense,
81-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
84+
#if GINKGO_ENABLE_HALF
8285
gko::float16, std::complex<gko::float16>,
86+
#endif
87+
#if GINKGO_ENABLE_BFLOAT16
88+
gko::bfloat16, std::complex<gko::bfloat16>,
8389
#endif
8490
float, double, std::complex<float>, std::complex<double>>(
8591
in,

core/solver/multigrid.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,11 @@ void MultigridState::generate(const LinOp* system_matrix_in,
319319
auto mg_level = mg_level_list.at(i);
320320

321321
run<gko::multigrid::EnableMultigridLevel, float, double,
322-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
322+
#if GINKGO_ENABLE_HALF
323323
float16, std::complex<float16>,
324+
#endif
325+
#if GINKGO_ENABLE_BFLOAT16
326+
bfloat16, std::complex<bfloat16>,
324327
#endif
325328
std::complex<float>, std::complex<double>>(
326329
mg_level,
@@ -461,8 +464,11 @@ void MultigridState::run_mg_cycle(multigrid::cycle cycle, size_type level,
461464
}
462465
auto mg_level = multigrid->get_mg_level_list().at(level);
463466
run<gko::multigrid::EnableMultigridLevel, float, double,
464-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
467+
#if GINKGO_ENABLE_HALF
465468
float16, std::complex<float16>,
469+
#endif
470+
#if GINKGO_ENABLE_BFLOAT16
471+
bfloat16, std::complex<bfloat16>,
466472
#endif
467473
std::complex<float>, std::complex<double>>(
468474
mg_level, [&, this](auto mg_level) {
@@ -714,8 +720,11 @@ void Multigrid::generate()
714720
}
715721

716722
run<gko::multigrid::EnableMultigridLevel, float, double,
717-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
723+
#if GINKGO_ENABLE_HALF
718724
float16, std::complex<float16>,
725+
#endif
726+
#if GINKGO_ENABLE_BFLOAT16
727+
bfloat16, std::complex<bfloat16>,
719728
#endif
720729
std::complex<float>, std::complex<double>>(
721730
mg_level,
@@ -755,8 +764,11 @@ void Multigrid::generate()
755764

756765
// generate coarsest solver
757766
run<gko::multigrid::EnableMultigridLevel, float, double,
758-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
767+
#if GINKGO_ENABLE_HALF
759768
float16, std::complex<float16>,
769+
#endif
770+
#if GINKGO_ENABLE_BFLOAT16
771+
bfloat16, std::complex<bfloat16>,
760772
#endif
761773
std::complex<float>, std::complex<double>>(
762774
last_mg_level,
@@ -875,8 +887,11 @@ void Multigrid::apply_with_initial_guess_impl(const LinOp* b, LinOp* x,
875887
};
876888
auto first_mg_level = this->get_mg_level_list().front();
877889
run<gko::multigrid::EnableMultigridLevel, float, double,
878-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
890+
#if GINKGO_ENABLE_HALF
879891
float16, std::complex<float16>,
892+
#endif
893+
#if GINKGO_ENABLE_BFLOAT16
894+
bfloat16, std::complex<bfloat16>,
880895
#endif
881896
std::complex<float>, std::complex<double>>(first_mg_level, lambda, b,
882897
x);
@@ -917,8 +932,11 @@ void Multigrid::apply_with_initial_guess_impl(const LinOp* alpha,
917932
};
918933
auto first_mg_level = this->get_mg_level_list().front();
919934
run<gko::multigrid::EnableMultigridLevel, float, double,
920-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
935+
#if GINKGO_ENABLE_HALF
921936
float16, std::complex<float16>,
937+
#endif
938+
#if GINKGO_ENABLE_BFLOAT16
939+
bfloat16, std::complex<bfloat16>,
922940
#endif
923941
std::complex<float>, std::complex<double>>(first_mg_level, lambda,
924942
alpha, b, beta, x);
@@ -985,8 +1003,11 @@ void Multigrid::apply_dense_impl(const VectorType* b, VectorType* x,
9851003
auto first_mg_level = this->get_mg_level_list().front();
9861004

9871005
run<gko::multigrid::EnableMultigridLevel, float, double,
988-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
1006+
#if GINKGO_ENABLE_HALF
9891007
float16, std::complex<float16>,
1008+
#endif
1009+
#if GINKGO_ENABLE_BFLOAT16
1010+
bfloat16, std::complex<bfloat16>,
9901011
#endif
9911012
std::complex<float>, std::complex<double>>(first_mg_level, lambda, b,
9921013
x);

core/test/base/mtx_io.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,12 +571,18 @@ TEST(MtxReader, ReadsBinary)
571571
test_read(gko::matrix_data<double, gko::int64>{});
572572
test_read(gko::matrix_data<std::complex<float>, gko::int64>{});
573573
test_read(gko::matrix_data<std::complex<double>, gko::int64>{});
574-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
574+
#if GINKGO_ENABLE_HALF
575575
test_read(gko::matrix_data<gko::float16, gko::int32>{});
576576
test_read(gko::matrix_data<std::complex<gko::float16>, gko::int32>{});
577577
test_read(gko::matrix_data<gko::float16, gko::int64>{});
578578
test_read(gko::matrix_data<std::complex<gko::float16>, gko::int64>{});
579579
#endif
580+
#if GINKGO_ENABLE_BFLOAT16
581+
test_read(gko::matrix_data<gko::bfloat16, gko::int32>{});
582+
test_read(gko::matrix_data<std::complex<gko::bfloat16>, gko::int32>{});
583+
test_read(gko::matrix_data<gko::bfloat16, gko::int64>{});
584+
test_read(gko::matrix_data<std::complex<gko::bfloat16>, gko::int64>{});
585+
#endif
580586
}
581587

582588

@@ -632,12 +638,18 @@ TEST(MtxReader, ReadsComplexBinary)
632638
test_read_fail(gko::matrix_data<double, gko::int64>{});
633639
test_read(gko::matrix_data<std::complex<float>, gko::int64>{});
634640
test_read(gko::matrix_data<std::complex<double>, gko::int64>{});
635-
#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
641+
#if GINKGO_ENABLE_HALF
636642
test_read_fail(gko::matrix_data<gko::float16, gko::int32>{});
637643
test_read(gko::matrix_data<std::complex<gko::float16>, gko::int32>{});
638644
test_read_fail(gko::matrix_data<gko::float16, gko::int64>{});
639645
test_read(gko::matrix_data<std::complex<gko::float16>, gko::int64>{});
640646
#endif
647+
#if GINKGO_ENABLE_BFLOAT16
648+
test_read_fail(gko::matrix_data<gko::bfloat16, gko::int32>{});
649+
test_read(gko::matrix_data<std::complex<gko::bfloat16>, gko::int32>{});
650+
test_read_fail(gko::matrix_data<gko::bfloat16, gko::int64>{});
651+
test_read(gko::matrix_data<std::complex<gko::bfloat16>, gko::int64>{});
652+
#endif
641653
}
642654

643655

0 commit comments

Comments
 (0)