Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Oct 9, 2023
1 parent 0a3b011 commit 1bc9a2e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions source/lib/tests/test_tabulate_se_a.cc
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) {
std::vector<double> dy_dem_x(em_x.size());
std::vector<double> dy_dem(em.size());
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);
std::vector<double> dy_dtwo(nloc * nnei * last_layer_size);
double *dy_dtwo = nullptr;
deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0],
&dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], &em[0],
nullptr, &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
Expand All @@ -742,7 +743,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) {
}

deepmd::tabulate_fusion_se_a_grad_cpu<double>(
&dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0],
&dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], &em[0],
&two_embed[0], &dy[0], nloc, nnei, last_layer_size);
EXPECT_EQ(dy_dem_x.size(), nloc * nnei);
EXPECT_EQ(dy_dem.size(), nloc * nnei * 4);
Expand Down Expand Up @@ -803,9 +804,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) {
std::vector<double> dy_dem_x(em_x.size(), 0.0);
std::vector<double> dy_dem(em.size(), 0.0);
std::vector<double> dy(nloc * nnei * last_layer_size, 1.0);
std::vector<double> dy_dtwo(nloc * nnei * last_layer_size, 0.0);

double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL,
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo = nullptr;
*em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo_dev = nullptr;
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::malloc_device_memory_sync(table_dev, table);
Expand Down Expand Up @@ -833,8 +835,9 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) {
deepmd::malloc_device_memory_sync(two_embed_dev, two_embed);
deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x);
deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem);
deepmd::malloc_device_memory_sync(dy_dtwo_dev, dy_dtwo);
deepmd::tabulate_fusion_se_a_grad_gpu<double>(
dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev,
dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev, em_dev,
two_embed_dev, dy_dev, nloc, nnei, last_layer_size);
deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x);
deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);
Expand Down

0 comments on commit 1bc9a2e

Please sign in to comment.