Skip to content

Commit edf96d1

Browse files
committed
[advalue] Convert constants if necessary
Unfortunately we cannot rely on the fact that number-like type `T` support arithmetic operations with `int` or `double`. E.g. `T = boost::numeric::interval<S>` only supports arithmetic operations between `T` and `S`. Hence we cannot use `t-1`. While `t-1.` works for `S=double` it still fails for `S=float`. As a remedy we use check if constants of the raw type are compatible and convert the constant to `T` otherwise. Doing the conversion always might be costly for other types `T` that support cheap operations with elementary type (e.g. `T=ADValue<S>` itself).
1 parent bfa0545 commit edf96d1

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

ADOL-C/include/adolc/advalue.h

+17-9
Original file line numberDiff line numberDiff line change
@@ -1027,12 +1027,14 @@ auto adCompose(const Value &x, const Derivatives &f) {
10271027
template <class X> X inv(const X &x) {
10281028
return adCompose(x, [](auto k, auto x) {
10291029
constexpr auto order = decltype(k)::value;
1030+
constexpr auto canUseInt = requires() { 1 / x; };
1031+
using Constant = std::conditional_t<canUseInt, int, decltype(x)>;
10301032
if constexpr (order == 0)
1031-
return 1 / x;
1033+
return Constant(1) / x;
10321034
if constexpr (order == 1)
1033-
return -1 / (x * x);
1035+
return Constant(-1) / (x * x);
10341036
if constexpr (order == 2)
1035-
return 2 / (x * x * x);
1037+
return Constant(2)./ (x * x * x);
10361038
static_assert(order <= 2, "Only derivatives up to order 2 are implemented");
10371039
});
10381040
}
@@ -1125,12 +1127,14 @@ template <class X> auto log(const X &x) {
11251127
using std::pow;
11261128
return adCompose(x, [](auto k, auto x) {
11271129
constexpr auto order = decltype(k)::value;
1130+
constexpr auto canUseInt = requires() { 1 / x; };
1131+
using Constant = std::conditional_t<canUseInt, int, decltype(x)>;
11281132
if constexpr (order == 0)
11291133
return log(x);
11301134
if constexpr (order == 1)
1131-
return 1. / x;
1135+
return Constant(1) / x;
11321136
if constexpr (order == 2)
1133-
return -1. / (x * x);
1137+
return Constant(-1) / (x * x);
11341138
static_assert(order <= 2, "Only derivatives up to order 2 are implemented");
11351139
});
11361140
}
@@ -1143,12 +1147,14 @@ template <class X> auto sqrt(const X &x) {
11431147
using std::sqrt;
11441148
return adCompose(x, [](auto k, auto x) {
11451149
constexpr auto order = decltype(k)::value;
1150+
constexpr auto canUseDouble = requires() { 1. / x; };
1151+
using Constant = std::conditional_t<canUseDouble, double, decltype(x)>;
11461152
if constexpr (order == 0)
11471153
return sqrt(x);
11481154
if constexpr (order == 1)
1149-
return 1. / (2. * sqrt(x));
1155+
return Constant(1. / 2.) / sqrt(x);
11501156
if constexpr (order == 2)
1151-
return -1. / (4. * x * sqrt(x));
1157+
return Constant(-1. / 4.) / (x * sqrt(x));
11521158
static_assert(order <= 2, "Only derivatives up to order 2 are implemented");
11531159
});
11541160
}
@@ -1162,12 +1168,14 @@ auto pow(const X &x, const Y &y) {
11621168
using std::pow;
11631169
return adCompose(x, [y](auto k, auto x) {
11641170
constexpr auto order = decltype(k)::value;
1171+
constexpr auto canUseInt = requires() { y - 1; };
1172+
using Constant = std::conditional_t<canUseInt, int, decltype(y)>;
11651173
if constexpr (order == 0)
11661174
return pow(x, y);
11671175
if constexpr (order == 1)
1168-
return y * pow(x, y - 1);
1176+
return y * pow(x, y - Constant(1));
11691177
if constexpr (order == 2)
1170-
return y * (y - 1.) * pow(x, y - 2.);
1178+
return y * (y - Constant(1)) * pow(x, y - Constant(2));
11711179
static_assert(order <= 2, "Only derivatives up to order 2 are implemented");
11721180
});
11731181
}

0 commit comments

Comments
 (0)