Skip to content

Commit e063006

Browse files
committed
Add test_modify_list
Currently fails.
1 parent c14abdc commit e063006

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

enzyme/test/Integration/ReverseMode/stl_list.cpp

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#include <list>
99

1010

11-
template<typename T>
12-
extern double __enzyme_fwddiff(void*, int, T&, T&);
13-
template<typename T>
14-
extern double __enzyme_autodiff(void*, int, T&, T&);
11+
template<typename ...T>
12+
extern double __enzyme_fwddiff(void*, T...);
13+
template<typename ...T>
14+
extern double __enzyme_autodiff(void*, T...);
1515

1616

1717
double test_iterate_list(std::list<double>& vals) {
@@ -23,24 +23,62 @@ double test_iterate_list(std::list<double>& vals) {
2323
return result;
2424
}
2525

26+
struct S {
27+
S(double r) : x(r) {};
28+
double x = 0.0;
29+
};
30+
31+
double test_modify_list(std::list<S> vals, double x) {
32+
vals.front().x = x;
33+
34+
// iterate over list
35+
double result = 0.0;
36+
for (const auto& val : vals) {
37+
result += val.x * val.x;
38+
}
39+
return result;
40+
}
41+
2642
void test_forward_list() {
43+
// diff all values of list
2744
{
2845
std::list<double> vals = {1.0, 2.0, 3.0};
2946
std::list<double> dvals = {1.0, 1.0, 1.0};
3047

3148
double ret = __enzyme_fwddiff((void*)test_iterate_list, enzyme_dup, vals, dvals);
3249
APPROX_EQ(ret, 12., 1e-10);
3350
}
51+
52+
// list is const, then first value set to active
53+
{
54+
std::list<S> vals = {S{1.0}, S{2.0}, S{3.0}};
55+
double x = 3.0;
56+
double dx = 1.0;
57+
58+
double ret = __enzyme_fwddiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, x, dx);
59+
APPROX_EQ(ret, 6., 1e-10);
60+
}
3461
}
3562

3663
void test_reverse_list() {
64+
// diff all values of list
3765
{
3866
std::list<double> vals = {1.0, 2.0, 3.0};
3967
std::list<double> dvals = {1.0, 1.0, 1.0};
4068

4169
double ret = __enzyme_autodiff((void*)test_iterate_list, enzyme_dup, vals, dvals);
4270
APPROX_EQ(ret, 12., 1e-10);
4371
}
72+
73+
// list is const, then first value set to active
74+
{
75+
std::list<S> vals = {S{1.0}, S{2.0}, S{3.0}};
76+
double x = 3.0;
77+
double dx = 1.0;
78+
79+
double ret = __enzyme_autodiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, x, dx);
80+
APPROX_EQ(ret, 6., 1e-10);
81+
}
4482
}
4583

4684

0 commit comments

Comments
 (0)