5 changes: 3 additions & 2 deletions modules/nvidia_plugin/docs/cuda_opset.md
@@ -90,7 +90,7 @@ The semantics match corresponding nGraph operation classes declared in `namespac
| [Mod](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Mod_1.md) | Supported |
| [MVN](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/normalization/MVN_6.md) | Supported |
| [Multiply](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Multiply_1.md) | Supported* |
-| [Negative](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Negative_1.md) | Not Supported |
+| [Negative](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Negative_1.md) | Supported |
| [NonMaxSuppression](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/sort/NonMaxSuppression_5.md) | Not Supported |
| [NonZero](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/condition/NonZero_3.md) | Not Supported |
| [NormalizeL2](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/normalization/NormalizeL2_1.md) | Not Supported |
@@ -134,11 +134,12 @@ The semantics match corresponding nGraph operation classes declared in `namespac
| [ShapeOf](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/shape/ShapeOf_3.md) | Not Supported |
| [ShuffleChannels](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/ShuffleChannels_1.md) | Not Supported |
| [Sigmoid](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/activation/Sigmoid_1.md) | Supported |
-| [Sign](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Sign_1.md) | Not Supported |
+| [Sign](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Sign_1.md) | Supported |
| [Sin](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Sin_1.md) | Supported |
| [Sinh](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/arithmetic/Sinh_1.md) | Supported |
| [SoftMax](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/activation/SoftMax_1.md) | Supported |
| [SoftPlus](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/activation/SoftPlus_4.md) | Not Supported |
+| [SoftSign](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/activation/SoftSign_1.md) | Supported |
| [SpaceToBatch](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/SpaceToBatch_2.md) | Not Supported |
| [SpaceToDepth](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/SpaceToDepth_1.md) | Not Supported |
| [Split](https://github.com/openvinotoolkit/openvino/blob/master/docs/ops/movement/Split_1.md) | Supported* |
40 changes: 40 additions & 0 deletions modules/nvidia_plugin/src/cuda/math.cuh
@@ -166,6 +166,46 @@ inline __device__ T log(T a) {
return static_cast<T>(::logf(static_cast<float>(a)));
}

template <typename T>
inline __device__ T sign_float(T x) {
static_assert(std::is_floating_point<T>::value, "T should be a floating-point type");
if (x < 0.0f) return -1.0f;
if (x > 0.0f) return 1.0f;
return 0.0f;
}

template <>
inline __device__ __half sign_float<__half>(__half x) {
const __half zero = __float2half(0.0f);
if (x < zero) return __half(-1.0f);
if (x > zero) return __half(1.0f);
return zero;
}

#ifdef CUDA_HAS_BF16_TYPE
template <>
inline __device__ __nv_bfloat16 sign_float<__nv_bfloat16>(__nv_bfloat16 x) {
const __nv_bfloat16 zero = __float2bfloat16(0.0f);
if (x < zero) return __nv_bfloat16(-1.0f);
if (x > zero) return __nv_bfloat16(1.0f);
return zero;
}
#endif

template <typename T>
inline __device__ T sign_int(T x) {
static_assert(std::is_integral<T>::value && !std::is_unsigned<T>::value,
"T should be integer type");
return static_cast<T>((x > 0) - (x < 0));
}

template <typename T>
inline __device__ T sign_uint(T x) {
static_assert(std::is_integral<T>::value && std::is_unsigned<T>::value,
"T should be unsigned integer type");
return static_cast<T>(x > 0);
}

#ifdef __CUDACC__
/* ==================== __half ===================== */
template <>
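The integer helpers rely on the branchless `(x > 0) - (x < 0)` idiom: each comparison yields a `bool` that promotes to `int`, so the difference is -1, 0, or +1 with no divergent branches. A minimal host-side sketch of the same identities (illustrative only, not part of the PR):

```cpp
#include <cassert>

// Mirrors cumath::sign_int / cumath::sign_uint on the host: the bool results
// of the two comparisons promote to int, giving -1, 0, or +1 branchlessly.
template <typename T>
T sign_int_ref(T x) {
    return static_cast<T>((x > 0) - (x < 0));
}

int main() {
    assert(sign_int_ref(-7) == -1);
    assert(sign_int_ref(0) == 0);
    assert(sign_int_ref(42) == 1);
    // For unsigned types x < 0 is always false, so the expression
    // degenerates to (x > 0), exactly what sign_uint computes.
    assert(sign_int_ref(5u) == 1u);
    assert(sign_int_ref(0u) == 0u);
    return 0;
}
```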
@@ -47,7 +47,7 @@ __global__ void elementwise_binary_broadcasting(const T* in0,

#endif // __CUDACC__

-template <typename ElementTypes, template <typename> typename OP>
+template <typename ElementTypes, template <typename... TArgs> typename OP>
class ElementwiseBinary {
public:
ElementwiseBinary(Type_t element_type, size_t out_num_elements, size_t max_threads_per_block)
@@ -33,7 +33,7 @@ __global__ void elementwise_unary(const T* in, size_t num_elements, T* out, Args

#endif // __CUDACC__

-template <typename ElementTypes, template <typename> typename OP>
+template <typename ElementTypes, template <typename... TArgs> typename OP>
class ElementwiseUnary {
public:
ElementwiseUnary(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
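Why both holders switch to `template <typename... TArgs> typename OP`: a template template parameter declared with exactly one type parameter only binds operation templates of that exact shape, but `NegativeOpImpl<T>` takes one parameter while `SignOpImpl<T, Enable>` takes two. A parameter pack matches either shape. A minimal sketch with simplified names (not the plugin's code):

```cpp
template <typename T>
struct OneParamOp {};                 // shaped like NegativeOpImpl<T>

template <typename T, typename Enable = void>
struct TwoParamOp {};                 // shaped like SignOpImpl<T, Enable>

// A template parameter pack matches any number of type parameters, so the
// same holder binds both operation shapes.
template <template <typename... TArgs> typename OP>
struct Holder {
    OP<float> op;  // trailing defaulted parameters are filled in on use
};

int main() {
    Holder<OneParamOp> a;
    Holder<TwoParamOp> b;  // rejected by the old `template <typename> typename OP`
                           // form under pre-C++17 matching rules
    (void)a;
    (void)b;
    return 0;
}
```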
29 changes: 29 additions & 0 deletions modules/nvidia_plugin/src/kernels/negative.cu
@@ -0,0 +1,29 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "negative.hpp"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

namespace cumath = CUDA::math;

template <typename T>
struct NegativeOpImpl {
__device__ static inline T op(T x) {
return -x;
}
};

Negative::Negative(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
: impl_{element_type, max_threads_per_block, num_elements} {}

void Negative::operator()(cudaStream_t stream, const void* in0, void* out) const {
impl_(stream, in0, out);
}

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
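A hedged usage sketch of the new kernel wrapper; the `Type_t::f32` enumerator spelling and the caller's stream/buffer setup are assumptions, not shown in this PR:

```cpp
// Hypothetical caller: configure once with element type and launch limits,
// then invoke per call with raw device pointers on the given stream.
void run_negative(cudaStream_t stream, const void* d_in, void* d_out,
                  size_t num_elements, size_t max_threads_per_block) {
    ov::nvidia_gpu::kernel::Negative negative{
        ov::nvidia_gpu::kernel::Type_t::f32,  // assumed enumerator spelling
        max_threads_per_block, num_elements};
    negative(stream, d_in, d_out);  // launches the elementwise_unary kernel
}
```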
31 changes: 31 additions & 0 deletions modules/nvidia_plugin/src/kernels/negative.hpp
@@ -0,0 +1,31 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "details/cuda_type_traits.hpp"
#include "details/elementwise_unary.cuh"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

template <typename T>
struct NegativeOpImpl;
/**
* Elementwise Negative operation
*/
class Negative {
public:
Negative(Type_t element_type, size_t max_threads_per_block, size_t num_elements);

void operator()(cudaStream_t stream, const void* in0, void* out) const;

private:
ElementwiseUnary<AllElementTypesSwitch, NegativeOpImpl> impl_;
};

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
71 changes: 71 additions & 0 deletions modules/nvidia_plugin/src/kernels/sign.cu
@@ -0,0 +1,71 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "sign.hpp"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

namespace cumath = CUDA::math;

template <typename T, typename Enable = void>
struct SignOpImpl {
__device__ static inline T op(T x);
};

template <>
struct SignOpImpl<char> {
__device__ static inline char op(char x) {
return cumath::sign_int(x);
}
};

template <typename T>
struct SignOpImpl<T, typename std::enable_if<std::is_integral<T>::value &&
!std::is_unsigned<T>::value>::type> {
__device__ static inline T op(T x) {
return cumath::sign_int(x);
}
};

template <typename T>
struct SignOpImpl<T, typename std::enable_if<std::is_integral<T>::value &&
std::is_unsigned<T>::value>::type> {
__device__ static inline T op(T x) {
return cumath::sign_uint(x);
}
};

template <typename T>
struct SignOpImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
__device__ static inline T op(T x) {
return cumath::sign_float(x);
}
};

template <>
struct SignOpImpl<__nv_bfloat16> {
__device__ static inline __nv_bfloat16 op(__nv_bfloat16 x) {
return cumath::sign_float(x);
}
};

template <>
struct SignOpImpl<__half> {
__device__ static inline __half op(__half x) {
return cumath::sign_float(x);
}
};

Sign::Sign(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
: impl_{element_type, max_threads_per_block, num_elements} {}

void Sign::operator()(cudaStream_t stream, const void* in0, void* out) const {
impl_(stream, in0, out);
}

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
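The full specialization for plain `char` is deliberate: `char`, `signed char`, and `unsigned char` are three distinct types, and whether plain `char` is signed is implementation-defined, so pinning it explicitly keeps the dispatch deterministic across platforms. Full specializations also take precedence over the `enable_if` partial specializations. A host-side sketch of the same selection logic (illustrative, simplified names):

```cpp
#include <cstdio>
#include <type_traits>

template <typename T, typename Enable = void>
struct Dispatch {
    static constexpr const char* name() { return "primary (op undefined)"; }
};

template <typename T>
struct Dispatch<T, typename std::enable_if<std::is_integral<T>::value &&
                                           !std::is_unsigned<T>::value>::type> {
    static constexpr const char* name() { return "signed-int"; }
};

template <typename T>
struct Dispatch<T, typename std::enable_if<std::is_integral<T>::value &&
                                           std::is_unsigned<T>::value>::type> {
    static constexpr const char* name() { return "unsigned-int"; }
};

template <typename T>
struct Dispatch<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
    static constexpr const char* name() { return "float"; }
};

// The full specialization wins over the partial ones whenever T is plain
// char, regardless of whether char is signed on the target platform.
template <>
struct Dispatch<char> {
    static constexpr const char* name() { return "char (explicit)"; }
};

int main() {
    std::printf("%s %s %s %s\n",
                Dispatch<int>::name(),       // signed-int
                Dispatch<unsigned>::name(),  // unsigned-int
                Dispatch<float>::name(),     // float
                Dispatch<char>::name());     // char (explicit)
    return 0;
}
```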
31 changes: 31 additions & 0 deletions modules/nvidia_plugin/src/kernels/sign.hpp
@@ -0,0 +1,31 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "details/cuda_type_traits.hpp"
#include "details/elementwise_unary.cuh"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

template <typename T, typename>
struct SignOpImpl;
/**
* Elementwise Sign operation
*/
class Sign {
public:
Sign(Type_t element_type, size_t max_threads_per_block, size_t num_elements);

void operator()(cudaStream_t stream, const void* in0, void* out) const;

private:
ElementwiseUnary<AllElementTypesSwitch, SignOpImpl> impl_;
};

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
46 changes: 46 additions & 0 deletions modules/nvidia_plugin/src/kernels/soft_sign.cu
@@ -0,0 +1,46 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "soft_sign.hpp"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

namespace cumath = CUDA::math;

template <typename T>
__device__ constexpr T one = static_cast<T>(1);

template <typename T>
struct SoftSignOpImpl {
__device__ static inline T op(T x) {
return x / (one<T> + cumath::abs(x));
}
};

template <>
struct SoftSignOpImpl<__nv_bfloat16> {
__device__ static inline __nv_bfloat16 op(__nv_bfloat16 x) {
return x / (__nv_bfloat16(1.0f) + cumath::abs(x));
}
};

template <>
struct SoftSignOpImpl<__half> {
__device__ static inline __half op(__half x) {
return x / (__half(1.0f) + cumath::abs(x));
}
};

SoftSign::SoftSign(Type_t element_type, size_t max_threads_per_block, size_t num_elements)
: impl_{element_type, max_threads_per_block, num_elements} {}

void SoftSign::operator()(cudaStream_t stream, const void* in0, void* out) const {
impl_(stream, in0, out);
}

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
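SoftSign computes f(x) = x / (1 + |x|), an odd function bounded in (-1, 1). A quick host-side reference check of those properties (illustrative only, not part of the PR):

```cpp
#include <cassert>
#include <cmath>

// Host reference for the device-side SoftSignOpImpl above.
float soft_sign_ref(float x) { return x / (1.0f + std::fabs(x)); }

int main() {
    assert(soft_sign_ref(0.0f) == 0.0f);
    assert(std::fabs(soft_sign_ref(3.0f) - 0.75f) < 1e-6f);   // 3 / (1 + 3)
    assert(std::fabs(soft_sign_ref(-3.0f) + 0.75f) < 1e-6f);  // odd: f(-x) == -f(x)
    assert(std::fabs(soft_sign_ref(1e6f)) < 1.0f);            // bounded in (-1, 1)
    return 0;
}
```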
31 changes: 31 additions & 0 deletions modules/nvidia_plugin/src/kernels/soft_sign.hpp
@@ -0,0 +1,31 @@
// Copyright (C) 2021-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "details/cuda_type_traits.hpp"
#include "details/elementwise_unary.cuh"

namespace ov {
namespace nvidia_gpu {
namespace kernel {

template <typename T>
struct SoftSignOpImpl;
/**
* Elementwise SoftSign operation
*/
class SoftSign {
public:
SoftSign(Type_t element_type, size_t max_threads_per_block, size_t num_elements);

void operator()(cudaStream_t stream, const void* in0, void* out) const;

private:
ElementwiseUnary<AllElementTypesSwitch, SoftSignOpImpl> impl_;
};

} // namespace kernel
} // namespace nvidia_gpu
} // namespace ov
15 changes: 15 additions & 0 deletions modules/nvidia_plugin/src/ops/negative.cpp
@@ -0,0 +1,15 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "negative.hpp"

#include <cuda_operation_registry.hpp>

namespace ov {
namespace nvidia_gpu {

OPERATION_REGISTER(NegativeOp, Negative);

} // namespace nvidia_gpu
} // namespace ov
22 changes: 22 additions & 0 deletions modules/nvidia_plugin/src/ops/negative.hpp
@@ -0,0 +1,22 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cuda_operation_base.hpp>

#include "elementwise_unary.hpp"
#include "kernels/negative.hpp"
#include "openvino/op/negative.hpp"

namespace ov {
namespace nvidia_gpu {

class NegativeOp : public ElementwiseUnaryOp<ov::op::v0::Negative, kernel::Negative> {
public:
using ElementwiseUnaryOp::ElementwiseUnaryOp;
};

} // namespace nvidia_gpu
} // namespace ov
15 changes: 15 additions & 0 deletions modules/nvidia_plugin/src/ops/sign.cpp
@@ -0,0 +1,15 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "sign.hpp"

#include <cuda_operation_registry.hpp>

namespace ov {
namespace nvidia_gpu {

OPERATION_REGISTER(SignOp, Sign);

} // namespace nvidia_gpu
} // namespace ov
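`OPERATION_REGISTER` presumably binds the OpenVINO node type name (`Negative`, `Sign`) to the plugin operation class at static-initialization time; its actual expansion is not shown in this PR. A generic sketch of the self-registration idiom such macros typically follow (hypothetical, simplified, not the plugin's code):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

struct OperationBase { virtual ~OperationBase() = default; };

using Factory = std::function<std::unique_ptr<OperationBase>()>;

// Meyers singleton avoids static-initialization-order problems between
// translation units that register operations.
std::map<std::string, Factory>& registry() {
    static std::map<std::string, Factory> r;
    return r;
}

template <typename Op>
struct Registrar {
    explicit Registrar(std::string name) {
        registry().emplace(std::move(name), [] { return std::make_unique<Op>(); });
    }
};

// A macro like OPERATION_REGISTER(CLASS, NAME) would then expand to a
// file-scope registrar, e.g. static Registrar<CLASS> reg_##CLASS{#NAME};
struct DemoNegativeOp : OperationBase {};              // stand-in for NegativeOp
static Registrar<DemoNegativeOp> reg_demo{"Negative"};

int main() { return registry().count("Negative") == 1 ? 0 : 1; }
```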