Skip to content

Commit

Permalink
Mullapudi2016-gpu: Copy output buffer data to host
Browse files: browse the repository at this point in the history
  • Loading branch information
antonysigma committed Aug 23, 2024
1 parent 33d0974 commit 475d52f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/generator/alias_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ int main(int argc, char **argv) {
output.fill(0);
output.copy_to_host();
alias_Mullapudi2016(input, output);
output.copy_to_host();
input.for_each_element([=](int x) {
assert(output(x) == input(x) + 2016);
});
Expand Down
14 changes: 14 additions & 0 deletions test/generator/autograd_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ int main(int argc, char **argv) {
exit(1);
}

grad_loss_out_wrt_a.copy_to_host();
grad_loss_out_wrt_b.copy_to_host();
grad_loss_out_wrt_c.copy_to_host();
dummy_grad_loss_output_wrt_lut.copy_to_host();
dummy_grad_loss_output_wrt_lut_indices.copy_to_host();
dummy_grad_loss_output_lut_wrt_input_a.copy_to_host();
dummy_grad_loss_output_lut_wrt_input_b.copy_to_host();
dummy_grad_loss_output_lut_wrt_input_c.copy_to_host();
grad_loss_output_lut_wrt_lut.copy_to_host();
grad_loss_output_lut_wrt_lut_indices.copy_to_host();

// Although the values are float, all should be exact results,
// so we don't need to worry about comparing vs. an epsilon
grad_loss_out_wrt_a.for_each_element([&](int x) {
Expand All @@ -118,18 +129,21 @@ int main(int argc, char **argv) {
float actual = grad_loss_out_wrt_a(x);
assert(expected == actual);
});

grad_loss_out_wrt_b.for_each_element([&](int x) {
// ∂𝐿/∂b = b * 44 * L
float expected = L(x) * b(x) * 44.f;
float actual = grad_loss_out_wrt_b(x);
assert(expected == actual);
});

grad_loss_out_wrt_c.for_each_element([&](int x) {
// ∂𝐿/∂c = 11 * L
float expected = L(x) * 11.f;
float actual = grad_loss_out_wrt_c(x);
assert(expected == actual);
});

dummy_grad_loss_output_wrt_lut.for_each_value([](float f) { assert(f == 0.f); });
dummy_grad_loss_output_wrt_lut_indices.for_each_value([](float f) { assert(f == 0.f); });
dummy_grad_loss_output_lut_wrt_input_a.for_each_value([](float f) { assert(f == 0.f); });
Expand Down

0 comments on commit 475d52f

Please sign in to comment.