Skip to content

Commit e692617

Browse files
committed
Mullapudi2016-gpu: Copy output buffer data to host
1 parent 9a85cd1 commit e692617

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

test/generator/alias_aottest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ int main(int argc, char **argv) {
5252
output.fill(0);
5353
output.copy_to_host();
5454
alias_Mullapudi2016(input, output);
55+
output.copy_to_host();
5556
input.for_each_element([=](int x) {
5657
assert(output(x) == input(x) + 2016);
5758
});

test/generator/autograd_aottest.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ int main(int argc, char **argv) {
110110
exit(1);
111111
}
112112

113+
grad_loss_out_wrt_a.copy_to_host();
114+
grad_loss_out_wrt_b.copy_to_host();
115+
grad_loss_out_wrt_c.copy_to_host();
116+
dummy_grad_loss_output_wrt_lut.copy_to_host();
117+
dummy_grad_loss_output_wrt_lut_indices.copy_to_host();
118+
dummy_grad_loss_output_lut_wrt_input_a.copy_to_host();
119+
dummy_grad_loss_output_lut_wrt_input_b.copy_to_host();
120+
dummy_grad_loss_output_lut_wrt_input_c.copy_to_host();
121+
grad_loss_output_lut_wrt_lut.copy_to_host();
122+
grad_loss_output_lut_wrt_lut_indices.copy_to_host();
123+
113124
// Although the values are float, all should be exact results,
114125
// so we don't need to worry about comparing vs. an epsilon
115126
grad_loss_out_wrt_a.for_each_element([&](int x) {
@@ -118,18 +129,21 @@ int main(int argc, char **argv) {
118129
float actual = grad_loss_out_wrt_a(x);
119130
assert(expected == actual);
120131
});
132+
121133
grad_loss_out_wrt_b.for_each_element([&](int x) {
122134
// ∂𝐿/∂b = b * 44 * L
123135
float expected = L(x) * b(x) * 44.f;
124136
float actual = grad_loss_out_wrt_b(x);
125137
assert(expected == actual);
126138
});
139+
127140
grad_loss_out_wrt_c.for_each_element([&](int x) {
128141
// ∂𝐿/∂c = 11 * L
129142
float expected = L(x) * 11.f;
130143
float actual = grad_loss_out_wrt_c(x);
131144
assert(expected == actual);
132145
});
146+
133147
dummy_grad_loss_output_wrt_lut.for_each_value([](float f) { assert(f == 0.f); });
134148
dummy_grad_loss_output_wrt_lut_indices.for_each_value([](float f) { assert(f == 0.f); });
135149
dummy_grad_loss_output_lut_wrt_input_a.for_each_value([](float f) { assert(f == 0.f); });

0 commit comments

Comments
 (0)