Commit 784911c
tried to pass jit but failed
Han Wang committed Feb 11, 2024
1 parent 44c939b
Showing 10 changed files with 114 additions and 50 deletions.
7 changes: 7 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -126,6 +126,7 @@ def __init__(
use_aparam_as_mask: bool = False,
spin: Any = None,
distinguish_types: bool = False,
do_hessian: bool = False,
):
# seed, uniform_seed are not included
if tot_ener_zero:
@@ -159,6 +160,7 @@ def __init__(
self.use_aparam_as_mask = use_aparam_as_mask
self.spin = spin
self.distinguish_types = distinguish_types
self.do_hessian = do_hessian
if self.spin is not None:
raise NotImplementedError("spin is not supported")

@@ -204,10 +206,15 @@ def output_def(self):
reduciable=True,
r_differentiable=True,
c_differentiable=True,
r_hessian=self.do_hessian,
),
]
)

def require_hessian(self, yes=False):
"""Set requirement for the calculation of Hessian."""
self.do_hessian = yes

def __setitem__(self, key, value):
if key in ["bias_atom_e"]:
self.bias_atom_e = value
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/dp_atomic_model.py
@@ -68,6 +68,10 @@ def distinguish_types(self) -> bool:
"""
return self.descriptor.distinguish_types()

def require_hessian(self, yes=False):
"""Set requirement for the calculation of Hessian."""
self.fitting.require_hessian(yes=yes)

def forward_atomic(
self,
extended_coord: np.ndarray,
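The `require_hessian`/`do_hessian` plumbing above repeats at each level: the atomic model delegates to its fitting, and the fitting rebuilds its output definition with `r_hessian` taken from the flag. A minimal self-contained sketch of the pattern, using stand-in classes rather than the real deepmd API:

    class OutputVariableDef:
        """Stand-in for deepmd's output variable definition."""

        def __init__(self, name: str, shape: list, r_hessian: bool = False):
            self.name = name
            self.shape = shape
            self.r_hessian = r_hessian


    class Fitting:
        def __init__(self, do_hessian: bool = False):
            self.do_hessian = do_hessian

        def require_hessian(self, yes: bool = False) -> None:
            """Set requirement for the calculation of Hessian."""
            self.do_hessian = yes

        def output_def(self) -> dict:
            # r_hessian is read each time the output def is built, so
            # toggling the flag takes effect on the next call
            return {"energy": OutputVariableDef("energy", [1], r_hessian=self.do_hessian)}


    class AtomicModel:
        def __init__(self, fitting: Fitting):
            self.fitting = fitting

        def require_hessian(self, yes: bool = False) -> None:
            """Set requirement for the calculation of Hessian."""
            self.fitting.require_hessian(yes=yes)


    model = AtomicModel(Fitting())
    model.require_hessian(yes=True)
    assert model.fitting.output_def()["energy"].r_hessian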
3 changes: 2 additions & 1 deletion deepmd/dpmodel/utils/env_mat.py
@@ -53,7 +53,8 @@ def _make_env_mat(
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
weight = weight * np.expand_dims(mask, -1)
env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight
return env_mat_se_a, diff * np.expand_dims(mask, -1), weight


3 changes: 2 additions & 1 deletion deepmd/pt/model/descriptor/env_mat.py
@@ -25,7 +25,8 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float):
t0 = 1 / length
t1 = diff / length**2
weight = compute_smooth_weight(length, ruct_smth, rcut)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight * mask.unsqueeze(-1)
weight = weight * mask.unsqueeze(-1)
env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight
return env_mat_se_a, diff * mask.unsqueeze(-1), weight


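Both the numpy and torch versions of the environment-matrix builder get the same refactor: the neighbor mask is folded into `weight` before use, so the weight that is returned is now the masked one, while the environment matrix itself is unchanged. A toy numpy check of that equivalence, with random stand-in arrays in place of the real `t0`, `t1`, `weight`, and `mask`:

    import numpy as np

    rng = np.random.default_rng(0)
    nf, nloc, nnei = 1, 2, 3
    t0 = rng.random((nf, nloc, nnei, 1))      # stand-in for 1 / length
    t1 = rng.random((nf, nloc, nnei, 3))      # stand-in for diff / length**2
    weight = rng.random((nf, nloc, nnei, 1))  # stand-in smooth weight
    # deterministic mask with at least one masked-out neighbor
    mask = np.arange(nf * nloc * nnei).reshape((nf, nloc, nnei)) % 2 == 0

    # old formulation: mask applied only to the final product
    old_env = np.concatenate([t0, t1], axis=-1) * weight * np.expand_dims(mask, -1)
    # new formulation: mask folded into the weight that is also returned
    new_weight = weight * np.expand_dims(mask, -1)
    new_env = np.concatenate([t0, t1], axis=-1) * new_weight

    assert np.allclose(old_env, new_env)        # env mat is unchanged
    assert not np.allclose(weight, new_weight)  # returned weight now carries the mask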
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/dp_atomic_model.py
@@ -104,6 +104,10 @@ def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
return self.descriptor.distinguish_types()

def require_hessian(self, yes=False):
"""Set requirement for the calculation of Hessian."""
self.fitting_net.require_hessian(yes=yes)

def serialize(self) -> dict:
return {
"type_map": self.type_map,
95 changes: 66 additions & 29 deletions deepmd/pt/model/model/make_model.py
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
from typing import (
Dict,
List,
@@ -64,10 +63,14 @@ def model_output_def(self):
"""Get the output def for the model."""
return ModelOutputDef(self.fitting_output_def())

# cannot use the name forward. torch script does not work
# wrapper for computing hessian. We only provide hessian calculation
# for the forward interface, thus the jacobian is not used to compute
# hessian, but computing from scratch.
# Cannot use the name forward. torch script does not work
#
# Wrapper for computing hessian. We only provide hessian calculation
# for the forward interface, not for forward_lower
#
# Low efficiency: hessian is computed from scratch, we do not use
# the jacobian.
@torch.jit.export
def forward_common(
self,
coord,
@@ -110,14 +113,17 @@ def forward_common(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
hess = self._cal_hessian_all(
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
)
ret.update(hess)
vdef = self.fitting_output_def()
hess_yes = [vdef[kk].r_hessian for kk in vdef.keys()]
if any(hess_yes):
hess = self._cal_hessian_all(
coord,
atype,
box=box,
fparam=fparam,
aparam=aparam,
)
ret.update(hess)
return ret

# cannot use the name forward. torch script does not work
@@ -192,6 +198,7 @@ def forward_common_(
)
return model_predict

@torch.jit.export
def forward_common_lower(
self,
extended_coord,
@@ -336,29 +343,38 @@ def _format_nlist(
assert nlist.shape[-1] == nnei
return nlist

## FAILED TO JIT this method
## torch/autograd/functional does not support jit script.
## tested torch version: 2.2
@torch.jit.ignore
def _cal_hessian_all(
self,
coord,
atype,
coord: torch.Tensor,
atype: torch.Tensor,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
):
) -> Dict[str, torch.Tensor]:
nf, nloc = atype.shape
coord = coord.view([nf, (nloc * 3)])
box = box.view([nf, 9]) if box is not None else None
fparam = fparam.view([nf, -1]) if fparam is not None else None
aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None
fdef = self.fitting_output_def()
# keys of values that require hessian
hess_keys: List[str] = []
for kk in fdef.keys():
if fdef[kk].r_hessian:
hess_keys.append(kk)
# result dict init by empty lists
res = {get_hessian_name(kk): [] for kk in hess_keys}
# loop over variable
for kk in hess_keys:
vdef = fdef[kk]
vshape = vdef.shape
vsize = 1
for ii in vshape:
vsize *= ii
# loop over frames
for ii in range(nf):
icoord = coord[ii]
@@ -367,7 +383,7 @@ def _cal_hessian_all(
ifparam = fparam[ii] if fparam is not None else None
iaparam = aparam[ii] if aparam is not None else None
# loop over all components
for idx in itertools.product(*[range(ii) for ii in vshape]):
for idx in range(vsize):
hess = self._cal_hessian_one_component(
idx, icoord, iatype, ibox, ifparam, iaparam
)
@@ -391,23 +407,44 @@ def _cal_hessian_one_component(
# box: Optional[torch.Tensor] = None, # 9
# fparam: Optional[torch.Tensor] = None, # nfp
# aparam: Optional[torch.Tensor] = None, # (nloc x nap)
def wrapped_forward_energy(xx):
res = self.forward_common_(
xx.unsqueeze(0),
atype.unsqueeze(0),
box.unsqueeze(0) if box is not None else None,
fparam.unsqueeze(0) if fparam is not None else None,
aparam.unsqueeze(0) if aparam is not None else None,
do_atomic_virial=False,
)
return res["energy_redu"][(0, *ci)]
wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam)

hess = torch.autograd.functional.hessian(
wrapped_forward_energy,
wc,
coord,
create_graph=False,
vectorize=True,
)
return hess

class wrapper_class_forward_energy:
def __init__(
self,
obj: CM,
ci: int,
atype: torch.Tensor,
box: Optional[torch.Tensor],
fparam: Optional[torch.Tensor],
aparam: Optional[torch.Tensor],
):
self.atype, self.box, self.fparam, self.aparam = atype, box, fparam, aparam
self.ci = ci
self.obj = obj

def __call__(
self,
xx,
):
ci = self.ci
atype, box, fparam, aparam = self.atype, self.box, self.fparam, self.aparam
res = self.obj.forward_common_(
xx.unsqueeze(0),
atype.unsqueeze(0),
box.unsqueeze(0) if box is not None else None,
fparam.unsqueeze(0) if fparam is not None else None,
aparam.unsqueeze(0) if aparam is not None else None,
do_atomic_virial=False,
)
er = res["energy_redu"][0].view([-1])[ci]
return er

return CM
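The commit title refers to this spot: `torch.autograd.functional.hessian` cannot be scripted (hence `@torch.jit.ignore` on `_cal_hessian_all`), and the inner closure was replaced by the `wrapper_class_forward_energy` callable, presumably for the same TorchScript limitation, since `hessian` only needs some callable mapping the flattened coordinates to a scalar. A toy sketch of the same pattern, with a quadratic stand-in energy so the expected Hessian is known exactly:

    import torch


    def toy_energy(xx: torch.Tensor) -> torch.Tensor:
        # stand-in for the reduced energy: sum of squares, whose Hessian is 2*I
        return (xx * xx).sum().view([1])


    class WrappedEnergy:
        """Callable class in place of a closure, as wrapper_class_forward_energy is."""

        def __init__(self, model, ci: int):
            self.model = model
            self.ci = ci  # index of the scalar output component

        def __call__(self, xx: torch.Tensor) -> torch.Tensor:
            return self.model(xx).view([-1])[self.ci]


    coord = torch.randn(6, dtype=torch.float64)  # one frame, two atoms, nloc * 3
    hess = torch.autograd.functional.hessian(
        WrappedEnergy(toy_energy, 0),
        coord,
        create_graph=False,
    )
    assert hess.shape == (6, 6)
    assert torch.allclose(hess, 2.0 * torch.eye(6, dtype=torch.float64))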
4 changes: 4 additions & 0 deletions deepmd/pt/model/task/ener.py
@@ -174,6 +174,10 @@ def output_def(self) -> FittingOutputDef:
]
)

def require_hessian(self, yes=False):
"""Set requirement for the calculation of Hessian."""
self.do_hessian = yes

def __setitem__(self, key, value):
if key in ["bias_atom_e"]:
# correct bias_atom_e shape. user may provide stupid shape
15 changes: 10 additions & 5 deletions source/tests/common/dpmodel/test_dp_model.py
@@ -21,10 +21,6 @@
class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(
self,
):
nf, nloc, nnei = self.nlist.shape
ds = DescrptSeA(
self.rcut,
@@ -39,7 +35,16 @@ def test_self_consistency(
distinguish_types=ds.distinguish_types(),
)
type_map = ["foo", "bar"]
md0 = DPModel(ds, ft, type_map=type_map)
self.md0 = DPModel(ds, ft, type_map=type_map)

def test_methods(self):
self.md0.require_hessian(yes=True)
self.assertTrue(self.md0.fitting_output_def()["energy"].r_hessian)

def test_self_consistency(
self,
):
md0 = self.md0
md1 = DPModel.deserialize(md0.serialize())

ret0 = md0.call_lower(self.coord_ext, self.atype_ext, self.nlist)
18 changes: 7 additions & 11 deletions source/tests/pt/model/test_auto_hessian.py
@@ -57,7 +57,7 @@ def test(
self,
):
places = 8
delta = 1e-4
delta = 1e-3
natoms = self.nloc
nf = self.nf
nv = self.nv
@@ -72,8 +72,8 @@ def test(
coord = coord.view([nf, natoms * 3])
atype = torch.stack(
[
torch.IntTensor([0, 0, 0, 1, 1]),
torch.IntTensor([0, 1, 1, 0, 1]),
torch.IntTensor([0, 0, 1]),
torch.IntTensor([1, 0, 1]),
]
).view([nf, natoms])
# assumes input to be numpy tensor
@@ -82,10 +82,6 @@ def test(
fparam = torch.rand([nf, nfp], dtype=dtype)
aparam = torch.rand([nf, natoms * nap], dtype=dtype)

coord = coord.view([nf, natoms, 3])
coord = coord[:, [0, 1, 2, 3, 4], :]
coord = coord.view([nf, natoms * 3])

ret_dict0 = self.model_hess.forward_common(
coord, atype, box=cell, fparam=fparam, aparam=aparam
)
@@ -129,12 +125,12 @@ class TestDPModel(unittest.TestCase, HessianTest):
def setUp(self):
torch.manual_seed(2)
self.nf = 2
self.nloc = 5
self.nloc = 3
self.rcut = 4.0
self.rcut_smth = 3.0
self.sel = [15, 15]
self.sel = [10, 10]
self.nt = 2
self.nv = 1
self.nv = 2
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
@@ -157,7 +153,7 @@ def setUp(self):
env.DEVICE
)
self.model_valu = DPModel.deserialize(self.model_hess.serialize())

self.model_valu.require_hessian(yes=False)
# args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
# ret0 = md0.forward_common(*args)

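The `HessianTest` body is not part of this diff, but the step `delta`, loosened above from 1e-4 to 1e-3, points to a finite-difference cross-check of the autograd Hessian. A hedged sketch of such a check, with a toy cubic energy and central differences; the helper below is hypothetical, not the test's actual code:

    import torch


    def finite_diff_hessian(f, x: torch.Tensor, delta: float = 1e-3) -> torch.Tensor:
        """Central-difference Hessian: hess[i, j] ~ d^2 f / dx_i dx_j."""
        n = x.numel()
        hess = torch.zeros((n, n), dtype=x.dtype)
        for i in range(n):
            for j in range(n):
                xpp = x.clone(); xpp[i] += delta; xpp[j] += delta
                xpm = x.clone(); xpm[i] += delta; xpm[j] -= delta
                xmp = x.clone(); xmp[i] -= delta; xmp[j] += delta
                xmm = x.clone(); xmm[i] -= delta; xmm[j] -= delta
                hess[i, j] = (f(xpp) - f(xpm) - f(xmp) + f(xmm)) / (4.0 * delta * delta)
        return hess


    def cubic(x: torch.Tensor) -> torch.Tensor:
        return (x**3).sum()


    x = torch.randn(6, dtype=torch.float64)
    h_ad = torch.autograd.functional.hessian(cubic, x)
    h_fd = finite_diff_hessian(cubic, x)
    assert torch.allclose(h_ad, h_fd, atol=1e-6)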
11 changes: 8 additions & 3 deletions source/tests/pt/model/test_dp_model.py
@@ -40,8 +40,6 @@
class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
def setUp(self):
TestCaseSingleFrameWithoutNlist.setUp(self)

def test_self_consistency(self):
nf, nloc = self.atype.shape
ds = DescrptSeA(
self.rcut,
@@ -57,7 +55,14 @@ def test_self_consistency(self):
).to(env.DEVICE)
type_map = ["foo", "bar"]
# TODO: dirty hack to avoid data stat!!!
md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
self.md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)

def test_methods(self):
self.md0.require_hessian(yes=True)
self.assertTrue(self.md0.fitting_output_def()["energy"].r_hessian)

def test_self_consistency(self):
md0 = self.md0
md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
args = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.forward_common(*args)
