Skip to content

Commit 4d9dd63

Browse files
committed
Use host pointer in rocm 5.2.0 blas
1 parent c11cd33 commit 4d9dd63

7 files changed

+561
-4
lines changed

blas/tpls/KokkosBlas1_scal_tpl_spec_decl.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,17 @@ KOKKOSBLAS1_CSCAL_TPL_SPEC_DECL_CUBLAS(Kokkos::LayoutLeft, Kokkos::CudaUVMSpace,
256256
namespace KokkosBlas {
257257
namespace Impl {
258258

259+
260+
/* rocBLAS documentation:
261+
"a rocBLAS handle always has one stream.""
262+
"If the handle is switching from one non-default stream to another, the old
263+
stream needs to be synchronized...next...rocblas_set_stream"
264+
Basically this means if we're switching streams, we have to fence the old one
265+
first.
266+
We also set the handle's pointer mode appropriately before invoking BLAS.
267+
268+
// push_pointer_mode
269+
*/
259270
#define KOKKOSBLAS1_XSCAL_TPL_SPEC_DECL_ROCBLAS( \
260271
SCALAR_TYPE, ROCBLAS_SCALAR_TYPE, ROCBLAS_FN, LAYOUT, EXECSPACE, MEMSPACE, \
261272
ETI_SPEC_AVAIL) \
@@ -288,11 +299,18 @@ namespace Impl {
288299
const size_type numElems = X.extent(0); \
289300
if ((numElems < static_cast<size_type>(INT_MAX)) && \
290301
(R.data() == X.data())) { \
302+
std::cerr << __FILE__ << ":" << __LINE__ << " rocBLAS scal(1)!\n"; \
291303
scal_print_specialization<RV, AS, XV>(); \
292304
const int N = static_cast<int>(numElems); \
293305
constexpr int one = 1; \
294306
KokkosBlas::Impl::RocBlasSingleton& s = \
295307
KokkosBlas::Impl::RocBlasSingleton::singleton(); \
308+
hipStream_t cur; \
309+
KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \
310+
rocblas_get_stream(s.handle, &cur)); \
311+
if (cur != space.hip_stream()) { \
312+
execution_space(cur).fence(); \
313+
} \
296314
KOKKOS_ROCBLAS_SAFE_CALL_IMPL( \
297315
rocblas_set_stream(s.handle, space.hip_stream())); \
298316
rocblas_pointer_mode pointer_mode; \
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#ifndef KOKKOSKERNELS_ISKOKKOSCOMPLEX_HPP
18+
#define KOKKOSKERNELS_ISKOKKOSCOMPLEX_HPP
19+
20+
#include <Kokkos_Core.hpp>
21+
22+
namespace KokkosKernels {
23+
namespace Impl {
24+
25+
/// \class is_kokkos_complex
26+
/// \brief is_kokkos_complex<T>::value is true if T is a Kokkos::complex<...>, false
27+
/// otherwise
28+
template <typename>
29+
struct is_kokkos_complex : public std::false_type {};
30+
template <typename... P>
31+
struct is_kokkos_complex<Kokkos::complex<P...>> : public std::true_type {};
32+
template <typename... P>
33+
struct is_kokkos_complex<const Kokkos::complex<P...>> : public std::true_type {};
34+
35+
template <typename... P>
36+
inline constexpr bool is_kokkos_complex_v = is_kokkos_complex<P...>::value;
37+
38+
}
39+
}
40+
41+
42+
#endif // KOKKOSKERNELS_ISKOKKOSCOMPLEX_HPP
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#ifndef KOKKOSKERNELS_UNIFIEDSCALARVIEW_HPP
18+
#define KOKKOSKERNELS_UNIFIEDSCALARVIEW_HPP
19+
20+
#include <type_traits>
21+
22+
#include <Kokkos_Core.hpp>
23+
24+
#include <KokkosKernels_AlwaysFalse.hpp>
25+
#include <KokkosKernels_IsKokkosComplex.hpp>
26+
27+
28+
namespace KokkosKernels {
29+
namespace Impl {
30+
31+
32+
template <typename ScalarLike, typename = void>
33+
struct is_scalar : std::false_type {};
34+
35+
// kokkos complex
36+
template <typename T>
37+
struct is_scalar<T, std::enable_if_t<is_kokkos_complex_v<T>>> : std::true_type {};
38+
39+
// other scalars
40+
template <typename ScalarLike>
41+
struct is_scalar<ScalarLike, std::enable_if_t<std::is_integral_v<ScalarLike> || std::is_floating_point_v<ScalarLike>>> : std::true_type {};
42+
43+
template <typename ScalarLike>
44+
inline constexpr bool is_scalar_v = is_scalar<ScalarLike>::value;
45+
46+
47+
48+
49+
template <typename ScalarLike, typename = void>
50+
struct is_scalar_view : std::false_type {};
51+
52+
// rank 0
53+
template <typename ScalarLike>
54+
struct is_scalar_view<ScalarLike, std::enable_if_t<0 == ScalarLike::rank>> : std::true_type {};
55+
56+
// rank 1 and static extent is 1
57+
template <typename ScalarLike>
58+
struct is_scalar_view<ScalarLike,
59+
std::enable_if_t<
60+
1 == ScalarLike::rank && 1 == ScalarLike::static_extent(0)
61+
>
62+
> : std::true_type {};
63+
64+
template <typename ScalarLike>
65+
inline constexpr bool is_scalar_view_v = is_scalar_view<ScalarLike>::value;
66+
67+
/*! \brief true iff ScalarLike is a scalar or a 0D or 1D view of a single thing
68+
*/
69+
template <typename ScalarLike>
70+
inline constexpr bool is_scalar_or_scalar_view = is_scalar_v<ScalarLike> || is_scalar_view_v<ScalarLike>;
71+
72+
73+
74+
75+
76+
template <typename Value, typename = void>
77+
struct unified_scalar;
78+
79+
template <typename Value>
80+
struct unified_scalar<Value, std::enable_if_t<is_scalar_v<Value>>> {
81+
82+
using type = Value;
83+
using non_const_type = std::remove_const_t<type>;
84+
};
85+
86+
template <typename Value>
87+
struct unified_scalar<Value, std::enable_if_t<is_scalar_view_v<Value>>> {
88+
89+
using type = typename Value::value_type;
90+
using non_const_type = std::remove_const_t<type>;
91+
};
92+
93+
template <typename Value>
94+
using unified_scalar_t = typename unified_scalar<Value>::type;
95+
96+
template <typename Value>
97+
using non_const_unified_scalar_t = typename unified_scalar<Value>::non_const_type;
98+
99+
100+
template <typename Value>
101+
constexpr unified_scalar_t<Value> get_scalar(const Value &v) {
102+
103+
static_assert(is_scalar_or_scalar_view<Value>, "");
104+
105+
unified_scalar_t<Value> ref;
106+
if constexpr (is_scalar_view_v<Value>) {
107+
if (0 == Value::rank) {
108+
ref = *v;
109+
} else if (1 == Value::rank) {
110+
ref = v[0];
111+
} else {
112+
static_assert(KokkosKernels::Impl::always_false_v<Value>, "");
113+
}
114+
} else {
115+
ref = v;
116+
}
117+
return ref;
118+
}
119+
120+
} // namespace Impl
121+
} // namespace KokkosKernels
122+
123+
#endif // KOKKOSKERNELS_UNIFIEDSCALARVIEW_HPP

common/unit_test/Test_Common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <Test_Common_PrintConfiguration.hpp>
2727
#include <Test_Common_Iota.hpp>
2828
#include <Test_Common_LowerBound.hpp>
29+
#include <Test_Common_UnifiedScalarView.hpp>
2930
#include <Test_Common_UpperBound.hpp>
3031

3132
#endif // TEST_COMMON_HPP

0 commit comments

Comments
 (0)