Hi, when I try to compile the following code:

```cpp
#include <enzyme/enzyme>
#include <cstdio>

struct objective {
    double operator()(double x) { return x * 10; }
};

int main() {
    // ok
    // auto y = enzyme::autodiff<enzyme::Reverse>(objective{}, enzyme::Active<double>(3.1));

    // not ok
    objective fn;
    auto y = enzyme::autodiff<enzyme::Reverse>(fn, enzyme::Active<double>(3.1));
    auto y1 = enzyme::get<0>(enzyme::get<0>(y));
    std::printf("dmul %f\n", y1);
}
```

it fails with a compiler error:

```
error: 'f' declared as a pointer to a reference of type 'objective &'
```
I believe this could (hopefully) be fixed by changing the parameter type of `f` from `function*` to `std::remove_reference_t<function>*`, i.e. something like:

```cpp
template<typename function, typename RT, typename ...T>
struct templated_call<function, RT(T...)> {
    static RT wrap(std::remove_reference_t<function>* __restrict__ f, T... args) {
        return (*f)(args...);
    }
};
```

in this code from `Enzyme/enzyme/include/enzyme/utils` (lines 330 to 335 at commit 1961293):

```cpp
template<typename function, typename RT, typename ...T>
struct templated_call<function, RT(T...)> {
    static RT wrap(function* __restrict__ f, T... args) {
        return (*f)(args...);
    }
};
```
I guess the tests only cover the `// ok` case, where the callable is passed as a temporary, see e.g.:

```cpp
auto y = enzyme::autodiff<enzyme::Reverse>(overload{}, enzyme::Active<double>(3.1));
```
I can set up a pull request. Alternatively, if someone who knows a better fix or is more familiar with the code (I am not familiar with it at all) could fix it accordingly, I would be very happy :)
Thank you.