forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReduceLogicKernel.cu
44 lines (37 loc) · 1.26 KB
/
ReduceLogicKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
namespace at { namespace native {
// Logical-AND reduction over a uint8_t (bool-like) tensor.
// The identity element passed to gpu_reduce_kernel is `true` (last argument),
// so an empty reduction yields 1, matching `all()` semantics.
void and_kernel_cuda(TensorIterator& iter) {
  auto and_op = func_wrapper<uint8_t>(
      [] GPU_LAMBDA(uint8_t lhs, uint8_t rhs) -> uint8_t {
        return (lhs && rhs) ? 1 : 0;
      });
  gpu_reduce_kernel<uint8_t, uint8_t>(iter, and_op, true);
}
// Logical-OR reduction over a uint8_t (bool-like) tensor.
// The identity element passed to gpu_reduce_kernel is `false` (last argument),
// so an empty reduction yields 0, matching `any()` semantics.
void or_kernel_cuda(TensorIterator& iter) {
  auto or_op = func_wrapper<uint8_t>(
      [] GPU_LAMBDA(uint8_t lhs, uint8_t rhs) -> uint8_t {
        return (lhs || rhs) ? 1 : 0;
      });
  gpu_reduce_kernel<uint8_t, uint8_t>(iter, or_op, false);
}
// Register the CUDA implementations with ATen's dispatch stubs so the
// device-generic and/or reduction entry points route to these kernels
// when executed on CUDA tensors.
REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
REGISTER_DISPATCH(or_stub, &or_kernel_cuda);
// Returns true iff `self` and `src` are exactly equal: same dimension names
// (for named tensors), same sizes, and element-wise equal values.
// Raises (TORCH_CHECK) if the two tensors live on different devices.
bool cuda_equal(const Tensor& self, const Tensor &src) {
// Named-tensor check: differing dimension names make the tensors unequal
// before any value comparison happens.
if (!at::namedinference::are_names_equal(
self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) {
return false;
}
// Names already compared above; suppress name propagation/checks for the
// remaining ops in this scope.
at::NoNamesGuard guard;
TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on "
"different devices. Got: ", self.device(), " and ", src.device());
// Shape mismatch means not equal; this also guards the element-wise eq below.
if (self.sizes() != src.sizes()) {
return false;
}
// Two empty tensors of identical shape are trivially equal; avoids
// launching a reduction over zero elements.
if (self.numel() == 0) {
return true;
}
// Element-wise eq, reduced with all(); .item() synchronizes and copies the
// scalar result back to the host.
return at::native::eq(self, src).all().item().to<bool>();
}
}} // namespace at::native