implement dy_dtwo for gradgrad
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz committed Oct 10, 2023
1 parent 63a4aee commit aad7830
Showing 5 changed files with 46 additions and 16 deletions.
1 change: 1 addition & 0 deletions deepmd/op/_tabulate_grad.py
@@ -68,6 +68,7 @@ def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
op.inputs[4],
dy,
dy_,
+ dy_dtwo,
op.inputs[6],
is_sorted=op.get_attr("is_sorted"),
)
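Note: with attention enabled, the first-order gradient op has three outputs (dy_dem_x, dy_dem, dy_dtwo), and TensorFlow hands its registered second-order gradient one upstream tensor per output. This hunk threads the third tensor, dy_dtwo, through to the C++ grad-grad kernel instead of dropping it. A hedged toy sketch of when such a grad-of-grad hook fires, using ordinary TF1-style ops as a stand-in for the fused op (names and shapes are illustrative):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float64, [3])
    t = tf.placeholder(tf.float64, [3])           # plays the role of two_embed
    y = tf.reduce_sum(tf.square(x) * (1.0 + t))   # toy stand-in for the fused op
    (gx,) = tf.gradients(y, x)                    # inserts the first-order grad op
    (ggt,) = tf.gradients(gx, t)                  # differentiating the grad op itself
                                                  # is the path that invokes a
                                                  # *_grad_grad kernel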
2 changes: 2 additions & 0 deletions source/lib/include/tabulate.h
@@ -39,6 +39,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
+ const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -147,6 +148,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
+ const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
16 changes: 13 additions & 3 deletions source/lib/src/gpu/tabulate.cu
@@ -366,6 +366,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
+ const FPTYPE* dz_dy_dtwo,
const FPTYPE lower,
const FPTYPE upper,
const FPTYPE max,
@@ -413,9 +414,15 @@
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
xx) *
xx;
+ FPTYPE two_grad = 0.;
if (enable_se_atten) {
FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
ii * last_layer_size + thread_idx];
+ // dz_dy_dtwo * res * em
+ // res above should be used instead of res + res * t below
+ two_grad = dz_dy_dtwo[block_idx * nnei * last_layer_size +
+                       ii * last_layer_size + thread_idx] *
+            res;
res += res * t;
res_grad += res_grad * t;
}
@@ -443,8 +450,8 @@
for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] +=
- (nnei - breakpoint) *
- (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
+ (nnei - breakpoint) * (em[em_index] * (res_grad * dz_xx + two_grad) +
+                        dz_dy_dem[em_index] * res);
}
mark_table_idx = table_idx;
if (unloop) {
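Why this is the right term: the attention factor enters the forward pass linearly (res becomes res + res * t), so the first-order gradient with respect to two_embed is the upstream gradient contracted with the embedding row em and scaled by the pre-attention res. That makes dy_dtwo linear in the upstream tensor, and the second-order pass adds em * (dz_dy_dtwo * res) to dz_dy, which is the two_grad term accumulated above. A hedged NumPy check of that linearity (sizes and values are illustrative, not from the kernel):

    import numpy as np

    rng = np.random.default_rng(0)
    em = rng.normal(size=4)                      # one neighbor's 4-component em row
    res, dz_dy_dtwo = rng.normal(), rng.normal()

    def dy_dtwo(dy):                             # first-order grad w.r.t. two_embed
        return res * em.dot(dy)

    analytic = dz_dy_dtwo * res * em             # em * two_grad, as in the kernel
    dy0 = rng.normal(size=4)
    eps = 1e-6
    fd = np.array([dz_dy_dtwo * (dy_dtwo(dy0 + eps * e) - dy_dtwo(dy0)) / eps
                   for e in np.eye(4)])
    assert np.allclose(fd, analytic, atol=1e-5)  # finite differences agree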
@@ -813,6 +820,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
+ const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -825,7 +833,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
- dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+ dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, dz_dy_dtwo,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
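The launch shape is one CUDA block per atom and one thread per output channel, with a per-block shared accumulator of MM rows of last_layer_size values; this is also why the op later enforces last_layer_size <= 1024 (the CUDA per-block thread limit). Illustrative arithmetic only; the concrete sizes and MM = 4 are assumptions:

    nloc, last_layer_size = 192, 128
    MM, fptype_bytes = 4, 8                      # MM: row tile (assumed); 8 B per double
    blocks, threads_per_block = nloc, last_layer_size
    shared_mem_bytes = fptype_bytes * MM * last_layer_size
    assert threads_per_block <= 1024             # mirrors the OP_REQUIRES check below
    print(blocks, threads_per_block, shared_mem_bytes)   # 192 128 4096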
@@ -1036,6 +1044,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
const float* two_embed,
const float* dz_dy_dem_x,
const float* dz_dy_dem,
+ const float* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -1049,6 +1058,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
const double* two_embed,
const double* dz_dy_dem_x,
const double* dz_dy_dem,
+ const double* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
31 changes: 22 additions & 9 deletions source/lib/src/tabulate.cc
@@ -235,7 +235,7 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
// fill from jj to nnei
for (int jj2 = jj; jj2 < nnei; jj2++) {
dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size +
- kk] += res * dotllrr;
+ kk] += resold * dotllrr;
}
}
} else {
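The one-line fix above (res -> resold) guards against an in-place update: by the time dy_dtwo is accumulated, res has already been overwritten with res + res * t, while the derivative of res * (1 + t) with respect to t is the pre-attention value. A self-contained check of the pitfall (plain Python, illustrative values):

    res_old, t, eps = 0.7, 0.3, 1e-6
    res = res_old + res_old * t                           # in-place update, as in the kernel
    fd = ((res_old + res_old * (t + eps)) - res) / eps    # finite-difference d/dt
    assert abs(fd - res_old) < 1e-8                       # matches resold ...
    assert abs(fd - res) > 1e-2                           # ... not the updated res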
@@ -267,6 +267,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
+ const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -317,9 +318,15 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx;
+ FPTYPE two_grad = 0.;
if (enable_se_atten) {
FPTYPE t = two_embed[ii * nnei * last_layer_size +
jj * last_layer_size + kk];
+ // dz_dy_dtwo * var * ll
+ // var above should be used instead of var + var * t below
+ two_grad = dz_dy_dtwo[ii * nnei * last_layer_size +
+                       jj * last_layer_size + kk] *
+            var;
var += var * t;
var_grad += var_grad * t;
}
@@ -346,22 +353,26 @@
*/
if (unloop) {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
- (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
+ (nnei - jj) *
+     (var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0]);
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
- (nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]);
+ (nnei - jj) *
+     (var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1]);
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
- (nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]);
+ (nnei - jj) *
+     (var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2]);
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
- (nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]);
+ (nnei - jj) *
+     (var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3]);
} else {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
- var * hh[0] + dz_xx * var_grad * ll[0];
+ var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0];
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
- var * hh[1] + dz_xx * var_grad * ll[1];
+ var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1];
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
- var * hh[2] + dz_xx * var_grad * ll[2];
+ var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2];
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
- var * hh[3] + dz_xx * var_grad * ll[3];
+ var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3];
}
}
if (unloop) {
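The (nnei - jj) factor in the unloop branch is the sorted-neighbor shortcut: once the remaining neighbors are identical padding, every further iteration of the jj loop would add the same value, so it is added once and scaled by the count. A tiny NumPy illustration of that collapse (values are arbitrary):

    import numpy as np

    term = np.array([0.2, -0.5, 0.1, 0.4])   # per-neighbor contribution, constant from jj on
    nnei, jj = 10, 6
    looped = sum(term for _ in range(jj, nnei))
    collapsed = (nnei - jj) * term
    assert np.allclose(looped, collapsed)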
Expand Down Expand Up @@ -711,6 +722,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<float>(
const float* two_embed,
const float* dz_dy_dem_x,
const float* dz_dy_dem,
+ const float* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -724,6 +736,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<double>(
const double* two_embed,
const double* dz_dy_dem_x,
const double* dz_dy_dem,
+ const double* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
12 changes: 8 additions & 4 deletions source/op/tabulate_multi_device.cc
@@ -100,6 +100,7 @@ REGISTER_OP("TabulateFusionSeAttenGradGrad")
.Input("two_embed: T")
.Input("dz_dy_dem_x: T")
.Input("dz_dy_dem: T")
.Input("dz_dy_dtwo: T")
.Input("descriptor: T")
.Output("dz_dy: T")
.Attr("is_sorted: bool = true");
@@ -329,6 +330,7 @@ class TabulateFusionSeAGradGradOp : public OpKernel {
const FPTYPE* two_embed = nullptr;
const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat<FPTYPE>().data();
const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat<FPTYPE>().data();
+ const FPTYPE* dz_dy_dtwo = nullptr;
const int nloc = em_tensor.shape().dim_size(0);
const int nnei = em_tensor.shape().dim_size(1);
const int last_layer_size = descriptor_tensor.shape().dim_size(2);
@@ -337,7 +339,7 @@
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_grad_gpu(
dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
- nloc, nnei, last_layer_size, is_sorted);
+ dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
OP_REQUIRES(context, (last_layer_size <= 1024),
errors::InvalidArgument(
@@ -346,7 +348,7 @@
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_grad_cpu(
dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
- nloc, nnei, last_layer_size, is_sorted);
+ dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted);
}
}

@@ -522,6 +524,7 @@ class TabulateFusionSeAttenGradGradOp : public OpKernel {
const Tensor& two_embed_tensor = context->input(context_input_index++);
const Tensor& dz_dy_dem_x_tensor = context->input(context_input_index++);
const Tensor& dz_dy_dem_tensor = context->input(context_input_index++);
const Tensor& dz_dy_dtwo_tensor = context->input(context_input_index++);
const Tensor& descriptor_tensor = context->input(context_input_index++);
// set size of the sample
OP_REQUIRES(context, (dz_dy_dem_x_tensor.shape().dims() == 2),
@@ -544,6 +547,7 @@
const FPTYPE* two_embed = two_embed_tensor.flat<FPTYPE>().data();
const FPTYPE* dz_dy_dem_x = dz_dy_dem_x_tensor.flat<FPTYPE>().data();
const FPTYPE* dz_dy_dem = dz_dy_dem_tensor.flat<FPTYPE>().data();
+ const FPTYPE* dz_dy_dtwo = dz_dy_dtwo_tensor.flat<FPTYPE>().data();
const int nloc = em_tensor.shape().dim_size(0);
const int nnei = em_tensor.shape().dim_size(1);
const int last_layer_size = descriptor_tensor.shape().dim_size(2);
@@ -552,7 +556,7 @@
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
deepmd::tabulate_fusion_se_a_grad_grad_gpu(
dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
- nloc, nnei, last_layer_size, is_sorted);
+ dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
OP_REQUIRES(context, (last_layer_size <= 1024),
errors::InvalidArgument(
@@ -561,7 +565,7 @@
} else if (device == "CPU") {
deepmd::tabulate_fusion_se_a_grad_grad_cpu(
dz_dy, table, table_info, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
- nloc, nnei, last_layer_size, is_sorted);
+ dz_dy_dtwo, nloc, nnei, last_layer_size, is_sorted);
}
}

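Both op kernels funnel into the same deepmd::tabulate_fusion_se_a_grad_grad_* entry points: the plain se_a op above passes null pointers for two_embed and dz_dy_dtwo, the se_atten op passes real buffers, and the library branches on whether attention data is present (the enable_se_atten flag seen in the kernels). A plain-Python mock of that dispatch pattern, not the real API:

    def grad_grad(two_embed, dz_dy_dtwo):
        enable_se_atten = two_embed is not None
        if enable_se_atten and dz_dy_dtwo is None:
            raise ValueError("the attention path needs dz_dy_dtwo")
        return "atten" if enable_se_atten else "plain"

    assert grad_grad(None, None) == "plain"       # TabulateFusionSeAGradGradOp
    assert grad_grad([0.1], [0.2]) == "atten"     # TabulateFusionSeAttenGradGradOp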
