// LinearAlgebraKernel.cpp (forked from pytorch/pytorch)
#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/Loops.h>
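
// CPU kernels for at::addr (the rank-1 update beta * self + alpha * vec1 (x) vec2)
// and at::linalg_vector_norm, wired up through the dispatch stubs registered at
// the bottom of this file.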
namespace at { namespace native { namespace {

void addr_kernel(TensorIterator &iter,
                 const Scalar& beta, const Scalar& alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    // when beta is false, values in self should be ignored,
    // nans and infs in self should not propagate.
    if (beta_val == false) {
      cpu_kernel(iter,
        [=](scalar_t self_val,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return alpha_val && vec1_val && vec2_val;
        }
      );
    } else {
      cpu_kernel(iter,
        [=](scalar_t self_val,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
        }
      );
    }
    return;
  }
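
  // Numeric types follow the usual formula: a scalar lambda plus a vectorized
  // lambda, so cpu_kernel_vec can use Vec256 SIMD lanes where available.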
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
    iter.dtype(), "addr_cpu", [&]() {
      using Vec = Vec256<scalar_t>;

      auto beta_val = beta.to<scalar_t>();
      auto alpha_val = alpha.to<scalar_t>();

      auto beta_vec = Vec(beta_val);
      auto alpha_vec = Vec(alpha_val);

      const scalar_t zero_val(0);
      // when beta == 0, values in self should be ignored,
      // nans and infs in self should not propagate.
      if (beta_val == zero_val) {
        cpu_kernel_vec(iter,
          [=](scalar_t self_val,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return alpha_val * vec1_val * vec2_val;
          },
          [=](Vec self_vec,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return alpha_vec * vec1_vec * vec2_vec;
          }
        );
      } else {
        cpu_kernel_vec(iter,
          [=](scalar_t self_val,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return beta_val * self_val + alpha_val * vec1_val * vec2_val;
          },
          [=](Vec self_vec,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return beta_vec * self_vec + alpha_vec * vec1_vec * vec2_vec;
          }
        );
      }
    }
  );
}
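
// Element-wise, addr computes
//   out[i][j] = beta * self[i][j] + alpha * vec1[i] * vec2[j]
// A minimal usage sketch against the public ATen API (values are illustrative,
// not taken from this file):
//
//   at::Tensor vec1 = at::tensor({1., 2.});
//   at::Tensor vec2 = at::tensor({3., 4.});
//   at::Tensor self = at::zeros({2, 2});
//   at::Tensor out = at::addr(self, vec1, vec2);  // [[3, 4], [6, 8]]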

template <typename scalar_t, typename acc_t=typename scalar_value_type<scalar_t>::type>
void linalg_vector_norm_kernel_cpu_impl(TensorIterator& iter, Scalar ord) {
  double ord_val;
  if (ord.isFloatingPoint()) {
    ord_val = ord.to<double>();
  } else {
    TORCH_CHECK(false, "linalg.vector_norm expects ord to be float");
  }
  acc_t init_val = (ord_val == -INFINITY) ? std::numeric_limits<acc_t>::infinity() : static_cast<acc_t>(0);
  if (iter.numel() == 0) {
    iter.output().fill_((ord_val < 0) ? INFINITY : 0);
    return;
  }

  if (ord_val == 0) {
    binary_kernel_reduce(iter, NormZeroOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == 1) {
    binary_kernel_reduce(iter, NormOneOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == 2) {
    binary_kernel_reduce(iter, NormTwoOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == INFINITY) {
    binary_kernel_reduce(iter, AbsMaxOps<scalar_t, acc_t>(), init_val);
  } else if (ord_val == -INFINITY) {
    binary_kernel_reduce(iter, AbsMinOps<scalar_t, acc_t>(), init_val);
  } else {
    binary_kernel_reduce(iter, NormOps<scalar_t, acc_t> { static_cast<acc_t>(ord_val) }, init_val);
  }

  // For complex outputs, the above kernels do not touch the imaginary values,
  // so we must zero them out
  if (isComplexType(iter.output().scalar_type())) {
    at::imag(iter.output()).zero_();
  }
}
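
// A minimal sketch of how this kernel is reached from the public API, assuming
// the generated at::linalg_vector_norm overload with defaulted dim/keepdim/dtype
// (illustrative; note this revision requires a floating-point ord):
//
//   at::Tensor x = at::tensor({3., -4.});
//   at::linalg_vector_norm(x, 2.0);       // 5
//   at::linalg_vector_norm(x, INFINITY);  // 4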
static void linalg_vector_norm_kernel_cpu(TensorIterator& iter, Scalar ord) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "linalg_vector_norm_cpu", [&] {
    linalg_vector_norm_kernel_cpu_impl<scalar_t>(iter, ord);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(addr_stub, &addr_kernel);
REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cpu);

}} // namespace at::native
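
// REGISTER_DISPATCH wires the anonymous-namespace kernels above into the stubs
// declared in the included ATen/native/LinearAlgebra.h; other backends (e.g.
// CUDA) register their own implementations against the same stubs, and the
// stub selects the kernel matching the tensors' device at call time.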