support separate r_differentiable and c_differentiable
Han Wang committed Feb 6, 2024
1 parent 6c12380 commit c02b22c
Showing 10 changed files with 222 additions and 67 deletions.
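In short, this commit splits the single differentiable flag on OutputVariableDef into r_differentiable (derivative with respect to atomic coordinates, i.e. force) and c_differentiable (derivative with respect to the cell tensor, i.e. virial). The snippet below is an illustrative sketch of declaring an energy output before and after the change, assuming OutputVariableDef is importable from the module path of the files touched below; it is not part of the commit itself.

# Illustrative sketch only; the import path is assumed from the files changed in this commit.
from deepmd.dpmodel.output_def import OutputVariableDef

# Before this commit, one flag requested both derivatives:
#   OutputVariableDef("energy", [1], reduciable=True, differentiable=True)

# After it, coordinate and cell derivatives are requested independently:
energy_def = OutputVariableDef(
    "energy",
    [1],
    reduciable=True,
    r_differentiable=True,   # d(energy)/d(coord) -> force
    c_differentiable=True,   # d(energy)/d(cell) -> virial; requires r_differentiable=True
)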
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/invar_fitting.py
@@ -199,7 +199,11 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
self.var_name, [self.dim_out], reduciable=True, differentiable=True
self.var_name,
[self.dim_out],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
),
]
)
5 changes: 4 additions & 1 deletion deepmd/dpmodel/model/make_base_atomic_model.py
@@ -107,7 +107,10 @@ def do_grad_(
) -> bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return self.fitting_output_def()[var_name].differentiable
return (
self.fitting_output_def()[var_name].r_differentiable
or self.fitting_output_def()[var_name].c_differentiable
)

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")
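With the split, do_grad_ now answers "was any derivative requested for this variable". A stand-alone restatement of that predicate (sketch only, not the actual class method):

# Sketch: the combined check used by do_grad_, given a variable definition vdef
def any_derivative_requested(vdef) -> bool:
    return vdef.r_differentiable or vdef.c_differentiable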
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/pair_tab_model.py
@@ -70,7 +70,11 @@ def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
name="energy",
shape=[1],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
)
]
)
10 changes: 8 additions & 2 deletions deepmd/dpmodel/model/transform_output.py
@@ -31,10 +31,13 @@ def fit_output_to_model_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name-holders
model_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
model_ret[kk_derv_c] = None
return model_ret

@@ -57,10 +60,13 @@ def communicate_extended_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# name holders
new_ret[kk_derv_r] = None
if vdef.c_differentiable:
assert vdef.r_differentiable
kk_derv_r, kk_derv_c = get_deriv_name(kk)
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
45 changes: 31 additions & 14 deletions deepmd/dpmodel/output_def.py
@@ -68,9 +68,11 @@ def __call__(
if dd.reduciable:
rk = get_reduce_name(kk)
check_var(ret[rk], self.md[rk])
if dd.differentiable:
if dd.r_differentiable:
dnr, dnc = get_deriv_name(kk)
check_var(ret[dnr], self.md[dnr])
if dd.c_differentiable:
assert dd.r_differentiable
check_var(ret[dnc], self.md[dnc])
return ret

@@ -160,9 +162,12 @@ class OutputVariableDef:
dipole should be [3], polarizabilty should be [3,3].
reduciable
If the variable is reduced.
differentiable
r_differentiable
If the variable is differentiated with respect to coordinates
of atoms and cell tensor (pbc case). Only reduciable variable
of atoms. Only reduciable variable are differentiable.
c_differentiable
If the variable is differentiated with respect to the
cell tensor (pbc case). Only reduciable variable
are differentiable.
category : int
The category of the output variable.
@@ -173,19 +178,25 @@ def __init__(
name: str,
shape: List[int],
reduciable: bool = False,
differentiable: bool = False,
r_differentiable: bool = False,
c_differentiable: bool = False,
atomic: bool = True,
category: int = OutputVariableCategory.OUT.value,
):
self.name = name
self.shape = list(shape)
self.atomic = atomic
self.reduciable = reduciable
self.differentiable = differentiable
if not self.reduciable and self.differentiable:
raise ValueError("only reduciable variable are differentiable")
self.r_differentiable = r_differentiable
self.c_differentiable = c_differentiable
if self.c_differentiable and not self.r_differentiable:
raise ValueError("c differentiable requires r_differentiable")
if not self.reduciable and self.r_differentiable:
raise ValueError("only reduciable variable are r differentiable")
if not self.reduciable and self.c_differentiable:
raise ValueError("only reduciable variable are c differentiable")
if self.reduciable and not self.atomic:
raise ValueError("only reduciable variable should be atomic")
raise ValueError("a reduciable variable should be atomic")
self.category = category


@@ -358,7 +369,8 @@ def do_reduce(
rk,
vv.shape,
reduciable=False,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=False,
category=apply_operation(vv, OutputVariableOperation.REDU),
)
@@ -371,21 +383,26 @@ def do_derivative(
def_derv_r: Dict[str, OutputVariableDef] = {}
def_derv_c: Dict[str, OutputVariableDef] = {}
for kk, vv in def_outp_data.items():
if vv.differentiable:
rkr, rkc = get_deriv_name(kk)
rkr, rkc = get_deriv_name(kk)
if vv.r_differentiable:
def_derv_r[rkr] = OutputVariableDef(
rkr,
vv.shape + [3], # noqa: RUF005
reduciable=False,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_R),
)
if vv.c_differentiable:
assert vv.r_differentiable
rkr, rkc = get_deriv_name(kk)
def_derv_c[rkc] = OutputVariableDef(
rkc,
vv.shape + [3, 3], # noqa: RUF005
vv.shape + [9], # noqa: RUF005
reduciable=True,
differentiable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
category=apply_operation(vv, OutputVariableOperation.DERV_C),
)
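The constructor now enforces the relations between the flags: c_differentiable requires r_differentiable, and either derivative requires reduciable. A hypothetical test snippet exercising those rules (assumes pytest and the import path used above):

import pytest
from deepmd.dpmodel.output_def import OutputVariableDef

# cell differentiability without coordinate differentiability is rejected
with pytest.raises(ValueError):
    OutputVariableDef("e", [1], reduciable=True, r_differentiable=False, c_differentiable=True)

# derivatives of a non-reduciable variable are rejected
with pytest.raises(ValueError):
    OutputVariableDef("e", [1], reduciable=False, r_differentiable=True)

# a valid energy-like definition requesting both force and virial
OutputVariableDef("e", [1], reduciable=True, r_differentiable=True, c_differentiable=True)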
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/pair_tab_model.py
@@ -82,7 +82,11 @@ def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
name="energy", shape=[1], reduciable=True, differentiable=True
name="energy",
shape=[1],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
)
]
)
64 changes: 43 additions & 21 deletions deepmd/pt/model/model/transform_output.py
@@ -56,6 +56,7 @@ def task_deriv_one(
atom_energy: torch.Tensor,
energy: torch.Tensor,
extended_coord: torch.Tensor,
do_virial: bool = True,
do_atomic_virial: bool = False,
):
faked_grad = torch.ones_like(energy)
@@ -65,13 +66,16 @@
)[0]
assert extended_force is not None
extended_force = -extended_force
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
extended_virial = extended_virial + extended_virial_corr
# to [...,3,3] -> [...,9]
extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005
if do_virial:
extended_virial = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
extended_virial = extended_virial + extended_virial_corr
# to [...,3,3] -> [...,9]
extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005
else:
extended_virial = None
return extended_force, extended_virial


@@ -97,6 +101,7 @@ def take_deriv(
svv: torch.Tensor,
vdef: OutputVariableDef,
coord_ext: torch.Tensor,
do_virial: bool = False,
do_atomic_virial: bool = False,
):
size = 1
@@ -110,16 +115,25 @@
for vvi, svvi in zip(split_vv1, split_svv1):
# nf x nloc x 3, nf x nloc x 9
ffi, aviri = task_deriv_one(
vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial
vvi,
svvi,
coord_ext,
do_virial=do_virial,
do_atomic_virial=do_atomic_virial,
)
# nf x nloc x 1 x 3, nf x nloc x 1 x 9
ffi = ffi.unsqueeze(-2)
aviri = aviri.unsqueeze(-2)
split_ff.append(ffi)
split_avir.append(aviri)
if do_virial:
assert aviri is not None
aviri = aviri.unsqueeze(-2)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
avir = torch.concat(split_avir, dim=-2)
if do_virial:
avir = torch.concat(split_avir, dim=-2)
else:
avir = None
return ff, avir


@@ -141,18 +155,23 @@ def fit_output_to_model_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
model_ret[kk_redu] = torch.sum(vv, dim=atom_axis)
if vdef.differentiable:
if vdef.r_differentiable:
kk_derv_r, kk_derv_c = get_deriv_name(kk)
dr, dc = take_deriv(
vv,
model_ret[kk_redu],
vdef,
coord_ext,
do_virial=vdef.c_differentiable,
do_atomic_virial=do_atomic_virial,
)
model_ret[kk_derv_r] = dr
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(model_ret[kk_derv_c], dim=1)
if vdef.c_differentiable:
assert dc is not None
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(
model_ret[kk_derv_c], dim=1
)
return model_ret


@@ -174,12 +193,12 @@ def communicate_extended_output(
if vdef.reduciable:
kk_redu = get_reduce_name(kk)
new_ret[kk_redu] = model_ret[kk_redu]
if vdef.differentiable:
# nf x nloc
vldims = get_leading_dims(vv, vdef)
# nf x nall
mldims = list(mapping.shape)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
# nf x nloc
vldims = get_leading_dims(vv, vdef)
# nf x nall
mldims = list(mapping.shape)
kk_derv_r, kk_derv_c = get_deriv_name(kk)
if vdef.r_differentiable:
# vdim x 3
derv_r_ext_dims = list(vdef.shape) + [3] # noqa:RUF005
mapping = mapping.view(mldims + [1] * len(derv_r_ext_dims)).expand(
@@ -196,10 +215,13 @@
src=model_ret[kk_derv_r],
reduce="sum",
)
if vdef.c_differentiable:
assert vdef.r_differentiable
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005
# nf x nloc x nvar x 3 -> nf x nloc x nvar x 9
mapping = torch.tile(
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
mapping,
[1] * (len(mldims) + len(vdef.shape)) + [3],
)
virial = torch.zeros(
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
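In the PyTorch backend the virial branch now runs only when do_virial is set, i.e. when the variable is c_differentiable. The per-atom virial is the outer product of the (extended) force and coordinates, optionally corrected when atomic virials are requested, and flattened from a trailing 3x3 block to 9 components. A minimal sketch of that contraction, assuming plain tensors of shape [nf, nall, 3]:

import torch

def per_atom_virial(extended_force: torch.Tensor, extended_coord: torch.Tensor) -> torch.Tensor:
    # outer product f_i r_i^T per atom: [..., 3, 1] @ [..., 1, 3] -> [..., 3, 3]
    v = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)
    # flatten the trailing 3x3 block to 9 components, matching the layout used downstream
    return v.view(list(v.shape[:-2]) + [9])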
12 changes: 10 additions & 2 deletions deepmd/pt/model/task/denoise.py
@@ -75,10 +75,18 @@ def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
"updated_coord", [3], reduciable=False, differentiable=False
"updated_coord",
[3],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
OutputVariableDef(
"logits", [-1], reduciable=False, differentiable=False
"logits",
[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
]
)
20 changes: 17 additions & 3 deletions deepmd/pt/model/task/ener.py
@@ -162,7 +162,11 @@ def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
self.var_name, [self.dim_out], reduciable=True, differentiable=True
self.var_name,
[self.dim_out],
reduciable=True,
r_differentiable=True,
c_differentiable=True,
),
]
)
@@ -459,9 +463,19 @@ def __init__(
def output_def(self):
return FittingOutputDef(
[
OutputVariableDef("energy", [1], reduciable=True, differentiable=False),
OutputVariableDef(
"dforce", [3], reduciable=False, differentiable=False
"energy",
[1],
reduciable=True,
r_differentiable=False,
c_differentiable=False,
),
OutputVariableDef(
"dforce",
[3],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
),
]
)
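The second output_def changed in this file belongs to a fitting that predicts forces directly as a separate output (dforce) instead of differentiating the energy, so under the new API both derivative flags stay False. A sketch of that style of declaration (import path assumed as above):

from deepmd.dpmodel.output_def import FittingOutputDef, OutputVariableDef

# "Direct" style: energy is reduced but never differentiated; force is its own output.
direct_def = FittingOutputDef(
    [
        OutputVariableDef(
            "energy", [1], reduciable=True,
            r_differentiable=False, c_differentiable=False,
        ),
        OutputVariableDef(
            "dforce", [3], reduciable=False,
            r_differentiable=False, c_differentiable=False,
        ),
    ]
)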