@@ -20,38 +20,35 @@ namespace cu {
2020template <typename Op, typename In, typename Out>
2121constexpr bool supports_unary_op () {
2222 if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
23- std::is_same_v<Op, Sign>) {
23+ std::is_same_v<Op, Sign> || std::is_same_v<Op, Square> ) {
2424 return std::is_same_v<In, Out>;
2525 }
26- if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
27- std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
28- std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
29- std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
30- std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
31- std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
26+ if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
27+ std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
28+ std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
29+ std::is_same_v<Op, Sigmoid>) {
3230 return std::is_same_v<In, Out> && is_floating_v<In>;
3331 }
34- if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
35- std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
36- return std::is_same_v<In, Out> && is_inexact_v<In>;
37- }
3832 if (std::is_same_v<Op, BitwiseInvert>) {
3933 return std::is_same_v<In, Out> && std::is_integral_v<In> &&
4034 !std::is_same_v<In, bool >;
4135 }
42- if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
43- std::is_same_v<Op, Square>) {
36+ if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
4437 return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t >;
4538 }
4639 if (std::is_same_v<Op, Conjugate>) {
4740 return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t >;
4841 }
49- if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
50- std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
51- std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
52- std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
53- return std::is_same_v<In, Out> &&
54- (is_floating_v<In> || std::is_same_v<In, complex64_t >);
42+ if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
43+ std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
44+ std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
45+ std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
46+ std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
47+ std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
48+ std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
49+ std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
50+ std::is_same_v<Op, Tanh>) {
51+ return std::is_same_v<In, Out> && is_inexact_v<In>;
5552 }
5653 if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
5754 return std::is_same_v<In, complex64_t > && std::is_same_v<Out, float >;
0 commit comments