@@ -110,6 +110,17 @@ int main(int argc, char **argv) {
110
110
exit (1 );
111
111
}
112
112
113
+ grad_loss_out_wrt_a.copy_to_host ();
114
+ grad_loss_out_wrt_b.copy_to_host ();
115
+ grad_loss_out_wrt_c.copy_to_host ();
116
+ dummy_grad_loss_output_wrt_lut.copy_to_host ();
117
+ dummy_grad_loss_output_wrt_lut_indices.copy_to_host ();
118
+ dummy_grad_loss_output_lut_wrt_input_a.copy_to_host ();
119
+ dummy_grad_loss_output_lut_wrt_input_b.copy_to_host ();
120
+ dummy_grad_loss_output_lut_wrt_input_c.copy_to_host ();
121
+ grad_loss_output_lut_wrt_lut.copy_to_host ();
122
+ grad_loss_output_lut_wrt_lut_indices.copy_to_host ();
123
+
113
124
// Although the values are float, all should be exact results,
114
125
// so we don't need to worry about comparing vs. an epsilon
115
126
grad_loss_out_wrt_a.for_each_element ([&](int x) {
@@ -118,18 +129,21 @@ int main(int argc, char **argv) {
118
129
float actual = grad_loss_out_wrt_a (x);
119
130
assert (expected == actual);
120
131
});
132
+
121
133
grad_loss_out_wrt_b.for_each_element ([&](int x) {
122
134
// ∂𝐿/∂b = b * 44 * L
123
135
float expected = L (x) * b (x) * 44 .f ;
124
136
float actual = grad_loss_out_wrt_b (x);
125
137
assert (expected == actual);
126
138
});
139
+
127
140
grad_loss_out_wrt_c.for_each_element ([&](int x) {
128
141
// ∂𝐿/∂c = 11 * L
129
142
float expected = L (x) * 11 .f ;
130
143
float actual = grad_loss_out_wrt_c (x);
131
144
assert (expected == actual);
132
145
});
146
+
133
147
dummy_grad_loss_output_wrt_lut.for_each_value ([](float f) { assert (f == 0 .f ); });
134
148
dummy_grad_loss_output_wrt_lut_indices.for_each_value ([](float f) { assert (f == 0 .f ); });
135
149
dummy_grad_loss_output_lut_wrt_input_a.for_each_value ([](float f) { assert (f == 0 .f ); });
0 commit comments