Skip to content

Commit 2d76deb

Browse files
lucbvyasahi-hpc
andauthored
Ifpack2 btds gemm (kokkos#2829)
* Revert "batched/dense/impl/KokkosBatched_Gemm_Team_Impl.hpp (kokkos#2626)" This reverts commit 82605a9. Signed-off-by: Luc Berger-Vergiat <[email protected]> * Revert "ConjTrans support for batched team gemm (kokkos#2580)" This reverts commit 386663c. Signed-off-by: Luc Berger-Vergiat <[email protected]> * Batch - Dense: Re-applying GEMM fixes for extent/stride We have been liberally querrying the extent and stride of Views without checking if the rank of the view is high enough to return a valid value. The fixes lead to failures in BTDS which might point to a bug that relied on the old UB behavior of Kokkos::View x( Signed-off-by: Luc Berger-Vergiat <[email protected]> * Applying clang-format Signed-off-by: Luc Berger-Vergiat <[email protected]> * Batched - Utils: fixing extent and stride calculation Making sure we return 1 as the minimum stride and extent even if a view has a rank lower than the stride/extent querried. Signed-off-by: Luc Berger-Vergiat <[email protected]> * Update batched/KokkosBatched_Util.hpp Co-authored-by: yasahi-hpc <[email protected]> Signed-off-by: Luc Berger <[email protected]> * Update batched/KokkosBatched_Util.hpp Co-authored-by: yasahi-hpc <[email protected]> Signed-off-by: Luc Berger <[email protected]> --------- Signed-off-by: Luc Berger-Vergiat <[email protected]> Signed-off-by: Luc Berger <[email protected]> Co-authored-by: yasahi-hpc <[email protected]>
1 parent f4e7905 commit 2d76deb

11 files changed

+362
-854
lines changed

batched/KokkosBatched_Util.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,13 +682,13 @@ KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) {
682682
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");
683683

684684
if (r == 0) {
685-
int V_extent_0 = V_rank == 0 ? 0 : v.extent_int(0);
685+
int V_extent_0 = V_rank < 1 ? 1 : v.extent_int(0);
686686
return V_extent_0;
687687
} else if (r == 1) {
688-
int V_extent_1 = V_rank == 0 ? 0 : V_rank == 1 ? 1 : v.extent_int(1);
688+
int V_extent_1 = V_rank < 2 ? 1 : v.extent_int(1);
689689
return V_extent_1;
690690
} else {
691-
return -1;
691+
return 1;
692692
}
693693
}
694694

@@ -699,13 +699,13 @@ KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) {
699699
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");
700700

701701
if (r == 0) {
702-
std::size_t V_stride_0 = V_rank == 0 ? 0 : v.stride(0);
702+
std::size_t V_stride_0 = V_rank < 1 ? 1 : v.stride(0);
703703
return V_stride_0;
704704
} else if (r == 1) {
705-
std::size_t V_stride_1 = V_rank == 0 ? 0 : V_rank == 1 ? 1 : v.stride(1);
705+
std::size_t V_stride_1 = V_rank < 2 ? 1 : v.stride(1);
706706
return V_stride_1;
707707
} else {
708-
return 0;
708+
return 1;
709709
}
710710
}
711711
} // namespace Impl

batched/dense/impl/KokkosBatched_Gemm_Common_Impl.hpp

Lines changed: 0 additions & 79 deletions
This file was deleted.

batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,64 @@
55

66
#include "KokkosBlas_util.hpp"
77
#include "KokkosBatched_Util.hpp"
8-
#include "KokkosBatched_Gemm_Common_Impl.hpp"
98
#include "KokkosBatched_Gemm_Serial_Internal.hpp"
109

1110
namespace KokkosBatched {
11+
namespace Impl {
12+
template <typename ArgTransA, typename ArgTransB, typename AViewType, typename BViewType, typename CViewType>
13+
KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewType &A,
14+
[[maybe_unused]] const BViewType &B,
15+
[[maybe_unused]] const CViewType &C) {
16+
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::gemm: AViewType is not a Kokkos::View.");
17+
static_assert(Kokkos::is_view_v<BViewType>, "KokkosBatched::gemm: BViewType is not a Kokkos::View.");
18+
static_assert(Kokkos::is_view_v<CViewType>, "KokkosBatched::gemm: CViewType is not a Kokkos::View.");
19+
20+
static_assert(AViewType::rank <= 2, "KokkosBatched::gemm: AViewType must have rank 0, 1 or 2.");
21+
static_assert(BViewType::rank <= 2, "KokkosBatched::gemm: BViewType must have rank 0, 1 or 2.");
22+
static_assert(CViewType::rank <= 2, "KokkosBatched::gemm: CViewType must have rank 0, 1 or 2.");
23+
24+
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
25+
const int m = C.extent(0), n = C.extent(1);
26+
const int lda = A.extent(0);
27+
const int ldb = B.extent(0);
28+
29+
const int ka = std::is_same_v<ArgTransA, Trans::NoTranspose> ? A.extent(1) : A.extent(0);
30+
const int kb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? B.extent(0) : B.extent(1);
31+
32+
if (ka != kb) {
33+
Kokkos::printf(
34+
"KokkosBatched::gemm: Dimensions of A and B do not match: A: %d x %d, "
35+
"B: %d x %d\n",
36+
A.extent(0), A.extent(1), B.extent(0), B.extent(1));
37+
return 1;
38+
}
39+
40+
const int nrowa = std::is_same_v<ArgTransA, Trans::NoTranspose> ? m : ka;
41+
const int nrowb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? kb : n;
42+
43+
if (lda < Kokkos::max(1, nrowa)) {
44+
Kokkos::printf(
45+
"KokkosBatched::gemm: leading dimension of A must not be smaller than "
46+
"max(1, nrowa): "
47+
"lda = %d, nrowa = %d\n",
48+
lda, nrowa);
49+
return 1;
50+
}
51+
if (ldb < Kokkos::max(1, nrowb)) {
52+
Kokkos::printf(
53+
"KokkosBatched::gemm: leading dimension of B must not be smaller than "
54+
"max(1, nrowb): "
55+
"ldb = %d, nrowb = %d\n",
56+
ldb, nrowb);
57+
return 1;
58+
}
59+
60+
#endif
61+
62+
return 0;
63+
}
64+
} // namespace Impl
65+
1266
///
1367
/// Serial Impl
1468
/// ===========

0 commit comments

Comments
 (0)