|
5 | 5 |
|
6 | 6 | #include "KokkosBlas_util.hpp" |
7 | 7 | #include "KokkosBatched_Util.hpp" |
8 | | -#include "KokkosBatched_Gemm_Common_Impl.hpp" |
9 | 8 | #include "KokkosBatched_Gemm_Serial_Internal.hpp" |
10 | 9 |
|
11 | 10 | namespace KokkosBatched { |
| 11 | +namespace Impl { |
| 12 | +template <typename ArgTransA, typename ArgTransB, typename AViewType, typename BViewType, typename CViewType> |
| 13 | +KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewType &A, |
| 14 | + [[maybe_unused]] const BViewType &B, |
| 15 | + [[maybe_unused]] const CViewType &C) { |
| 16 | + static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::gemm: AViewType is not a Kokkos::View."); |
| 17 | + static_assert(Kokkos::is_view_v<BViewType>, "KokkosBatched::gemm: BViewType is not a Kokkos::View."); |
| 18 | + static_assert(Kokkos::is_view_v<CViewType>, "KokkosBatched::gemm: CViewType is not a Kokkos::View."); |
| 19 | + |
| 20 | + static_assert(AViewType::rank <= 2, "KokkosBatched::gemm: AViewType must have rank 0, 1 or 2."); |
| 21 | + static_assert(BViewType::rank <= 2, "KokkosBatched::gemm: BViewType must have rank 0, 1 or 2."); |
| 22 | + static_assert(CViewType::rank <= 2, "KokkosBatched::gemm: CViewType must have rank 0, 1 or 2."); |
| 23 | + |
| 24 | +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) |
| 25 | + const int m = C.extent(0), n = C.extent(1); |
| 26 | + const int lda = A.extent(0); |
| 27 | + const int ldb = B.extent(0); |
| 28 | + |
| 29 | + const int ka = std::is_same_v<ArgTransA, Trans::NoTranspose> ? A.extent(1) : A.extent(0); |
| 30 | + const int kb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? B.extent(0) : B.extent(1); |
| 31 | + |
| 32 | + if (ka != kb) { |
| 33 | + Kokkos::printf( |
| 34 | + "KokkosBatched::gemm: Dimensions of A and B do not match: A: %d x %d, " |
| 35 | + "B: %d x %d\n", |
| 36 | + A.extent(0), A.extent(1), B.extent(0), B.extent(1)); |
| 37 | + return 1; |
| 38 | + } |
| 39 | + |
| 40 | + const int nrowa = std::is_same_v<ArgTransA, Trans::NoTranspose> ? m : ka; |
| 41 | + const int nrowb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? kb : n; |
| 42 | + |
| 43 | + if (lda < Kokkos::max(1, nrowa)) { |
| 44 | + Kokkos::printf( |
| 45 | + "KokkosBatched::gemm: leading dimension of A must not be smaller than " |
| 46 | + "max(1, nrowa): " |
| 47 | + "lda = %d, nrowa = %d\n", |
| 48 | + lda, nrowa); |
| 49 | + return 1; |
| 50 | + } |
| 51 | + if (ldb < Kokkos::max(1, nrowb)) { |
| 52 | + Kokkos::printf( |
| 53 | + "KokkosBatched::gemm: leading dimension of B must not be smaller than " |
| 54 | + "max(1, nrowb): " |
| 55 | + "ldb = %d, nrowb = %d\n", |
| 56 | + ldb, nrowb); |
| 57 | + return 1; |
| 58 | + } |
| 59 | + |
| 60 | +#endif |
| 61 | + |
| 62 | + return 0; |
| 63 | +} |
| 64 | +} // namespace Impl |
| 65 | + |
12 | 66 | /// |
13 | 67 | /// Serial Impl |
14 | 68 | /// =========== |
|
0 commit comments