From 1bc9a2ee64ede39ad6ec8c01375efb7a276f7058 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 9 Oct 2023 19:04:32 -0400 Subject: [PATCH] fix tests --- source/lib/tests/test_tabulate_se_a.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/source/lib/tests/test_tabulate_se_a.cc b/source/lib/tests/test_tabulate_se_a.cc index 41a3b09734..6829177f08 100644 --- a/source/lib/tests/test_tabulate_se_a.cc +++ b/source/lib/tests/test_tabulate_se_a.cc @@ -726,9 +726,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { std::vector dy_dem_x(em_x.size()); std::vector dy_dem(em.size()); std::vector dy(nloc * nnei * last_layer_size, 1.0); + std::vector dy_dtwo(nloc * nnei * last_layer_size); double *dy_dtwo = nullptr; deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0], + &dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], &em[0], nullptr, &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); @@ -742,7 +743,7 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_cpu) { } deepmd::tabulate_fusion_se_a_grad_cpu( - &dy_dem_x[0], &dy_dem[0], dy_dtwo, &table[0], &info[0], &em_x[0], &em[0], + &dy_dem_x[0], &dy_dem[0], &dy_dtwo[0], &table[0], &info[0], &em_x[0], &em[0], &two_embed[0], &dy[0], nloc, nnei, last_layer_size); EXPECT_EQ(dy_dem_x.size(), nloc * nnei); EXPECT_EQ(dy_dem.size(), nloc * nnei * 4); @@ -803,9 +804,10 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { std::vector dy_dem_x(em_x.size(), 0.0); std::vector dy_dem(em.size(), 0.0); std::vector dy(nloc * nnei * last_layer_size, 1.0); + std::vector dy_dtwo(nloc * nnei * last_layer_size, 0.0); double *dy_dem_x_dev = NULL, *dy_dem_dev = NULL, *table_dev = NULL, - *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo = nullptr; + *em_x_dev = NULL, *em_dev = NULL, *dy_dev = NULL, *dy_dtwo_dev = nullptr; deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); deepmd::malloc_device_memory_sync(table_dev, table); @@ -833,8 +835,9 @@ TEST_F(TestTabulateSeA, tabulate_fusion_se_a_grad_gpu) { deepmd::malloc_device_memory_sync(two_embed_dev, two_embed); deepmd::malloc_device_memory_sync(dy_dem_x_dev, dy_dem_x); deepmd::malloc_device_memory_sync(dy_dem_dev, dy_dem); + deepmd::malloc_device_memory_sync(dy_dtwo_dev, dy_dtwo); deepmd::tabulate_fusion_se_a_grad_gpu( - dy_dem_x_dev, dy_dem_dev, dy_dtwo, table_dev, &info[0], em_x_dev, em_dev, + dy_dem_x_dev, dy_dem_dev, dy_dtwo_dev, table_dev, &info[0], em_x_dev, em_dev, two_embed_dev, dy_dev, nloc, nnei, last_layer_size); deepmd::memcpy_device_to_host(dy_dem_x_dev, dy_dem_x); deepmd::memcpy_device_to_host(dy_dem_dev, dy_dem);