Commit
add layernorm on g/h; make jit happy (#3)
* add layernorm on g/h; make jit happy

* Update repformer_layer.py
iProzd authored Oct 8, 2024
1 parent c05ed09 commit 5a354bc
Showing 5 changed files with 150 additions and 14 deletions.
12 changes: 12 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
@@ -206,6 +206,12 @@ def __init__(
use_sqrt_nnei: bool = True,
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
update_h1_has_g1: bool = False,
update_h2_has_g2: bool = False,
output_g1_ln: bool = False,
output_g2_ln: bool = False,
output_h1_ln: bool = False,
output_h2_ln: bool = False,
ln_eps: Optional[float] = 1e-5,
):
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
@@ -308,6 +314,12 @@ def __init__(
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
self.update_h1_has_g1 = update_h1_has_g1
self.update_h2_has_g2 = update_h2_has_g2
self.output_g1_ln = output_g1_ln
self.output_g2_ln = output_g2_ln
self.output_h1_ln = output_h1_ln
self.output_h2_ln = output_h2_ln
        # to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -237,6 +237,12 @@ def init_subclass_params(sub_data, sub_class):
use_sqrt_nnei=self.repformer_args.use_sqrt_nnei,
g1_out_conv=self.repformer_args.g1_out_conv,
g1_out_mlp=self.repformer_args.g1_out_mlp,
update_h1_has_g1=self.repformer_args.update_h1_has_g1,
update_h2_has_g2=self.repformer_args.update_h2_has_g2,
output_g1_ln=self.repformer_args.output_g1_ln,
output_g2_ln=self.repformer_args.output_g2_ln,
output_h1_ln=self.repformer_args.output_h1_ln,
output_h2_ln=self.repformer_args.output_h2_ln,
seed=child_seed(seed, 1),
old_impl=old_impl,
)
87 changes: 73 additions & 14 deletions deepmd/pt/model/descriptor/repformer_layer.py
@@ -3,6 +3,7 @@
List,
Optional,
Union,
Tuple,
)

import torch
@@ -587,8 +588,8 @@ def __init__(
update_g2_has_g1g1: bool = True,
update_g2_has_attn: bool = True,
update_h2: bool = False,
update_h2_has_g2: bool = True,
update_h1_has_g1: bool = True,
update_h1_has_g1: bool = False,
update_h2_has_g2: bool = False,
attn1_hidden: int = 64,
attn1_nhead: int = 4,
attn2_hidden: int = 16,
@@ -605,6 +606,10 @@ def __init__(
use_sqrt_nnei: bool = True,
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
        output_g1_ln: bool = False,
        output_g2_ln: bool = False,
        output_h1_ln: bool = False,
        output_h2_ln: bool = False,
seed: Optional[Union[int, List[int]]] = None,
):
super().__init__()
@@ -656,6 +661,10 @@ def __init__(
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
self.output_g1_ln = output_g1_ln
self.output_g2_ln = output_g2_ln
self.output_h1_ln = output_h1_ln
self.output_h2_ln = output_h2_ln

assert update_residual_init in [
"norm",
@@ -695,6 +704,10 @@ def __init__(
self.attn2_lm = None
self.attn2_ev_apply = None
self.loc_attn = None
self.g1_ln = None
self.g2_ln = None
self.h1_ln = None
self.h2_ln = None

if self.update_chnnl_2:
self.linear2 = MLPLayer(
@@ -919,6 +932,38 @@ def __init__(
self.g2_residual = nn.ParameterList(self.g2_residual)
self.h1_residual = nn.ParameterList(self.h1_residual)
self.h2_residual = nn.ParameterList(self.h2_residual)
if self.output_g1_ln:
self.g1_ln = LayerNorm(
g1_dim,
eps=ln_eps,
trainable=trainable_ln,
precision=precision,
seed=child_seed(seed, 18),
)
if self.output_g2_ln:
self.g2_ln = LayerNorm(
g2_dim,
eps=ln_eps,
trainable=trainable_ln,
precision=precision,
seed=child_seed(seed, 19),
)
if self.output_h1_ln:
self.h1_ln = LayerNorm(
g1_dim,
eps=ln_eps,
trainable=trainable_ln,
precision=precision,
seed=child_seed(seed, 20),
)
if self.output_h2_ln:
self.h2_ln = LayerNorm(
g2_dim,
eps=ln_eps,
trainable=trainable_ln,
precision=precision,
seed=child_seed(seed, 21),
)

def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int:
ret = g1d if not self.g1_out_mlp else 0
@@ -958,7 +1003,7 @@ def _update_g1_conv(
h2: torch.Tensor, # nb x nloc x nnei x 3 x ng2
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculate the convolution update for atomic invariant rep.
@@ -974,14 +1019,11 @@ def _update_g1_conv(
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nb x nloc x nnei.
"""
assert self.proj_g1g2_1 is not None
assert self.proj_h1h2_1 is not None
assert self.proj_g1g2_2 is not None
assert self.proj_h1h2_2 is not None
nb, nloc, nnei, _ = g2.shape
ng1 = gg1.shape[-1]
ng2 = g2.shape[-1]
if not self.g1_out_conv:
assert self.proj_g1g2 is not None
# gg1 : nb x nloc x nnei x ng2
gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2)
else:
@@ -1012,10 +1054,14 @@ def _update_g1_conv(
)
if not self.g1_out_conv:
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
g1_1 = g1_11
h1_1 = None
g1_1 = torch.sum(g2 * gg1, dim=2) * invnnei
g1_11 = g1_1
g1_12, h1_11, h1_12 = None, None, None
else:
assert self.proj_g1g2_1 is not None
assert self.proj_h1h2_1 is not None
assert self.proj_g1g2_2 is not None
assert self.proj_h1h2_2 is not None
# nb x nloc x ng1
g2_1 = self.proj_g1g2_1(g2).view(nb, nloc, nnei, ng1)
g2_2 = self.proj_g1g2_2(g2).view(nb, nloc, nnei, ng1)
@@ -1112,6 +1158,9 @@ def _update_g1_conv(
# torch.std(h1_11).detach().numpy(),
# torch.std(h1_12).detach().numpy(),
# )
assert g1_12 is not None
assert h1_11 is not None
assert h1_12 is not None
return g1_11, g1_12, h1_11, h1_12

@staticmethod
@@ -1353,6 +1402,7 @@ def forward(
)
else:
gg1 = None
hh1 = None

if self.update_chnnl_2:
# mlp(g2)
@@ -1389,6 +1439,7 @@ def forward(

if self.update_g1_has_conv:
assert gg1 is not None
assert hh1 is not None
g11_conv, g12_conv, h11_conv, h12_conv = self._update_g1_conv(
gg1, hh1, g2, h2, nlist_mask, sw
)
@@ -1452,6 +1503,18 @@ def forward(
g2_new, h2_new = g2, h2
g1_new = self.list_update(g1_update, "g1")
h1_new = self.list_update(h1_update, "h1")
if self.output_g1_ln:
assert self.g1_ln is not None
g1_new = self.g1_ln(g1_new)
if self.output_g2_ln:
assert self.g2_ln is not None
g2_new = self.g2_ln(g2_new)
if self.output_h1_ln:
assert self.h1_ln is not None
h1_new = self.h1_ln(h1_new)
if self.output_h2_ln:
assert self.h2_ln is not None
h2_new = self.h2_ln(h2_new)
return g1_new, g2_new, h1_new, h2_new

@torch.jit.export
@@ -1482,19 +1545,15 @@ def list_update_res_residual(
uu = update_list[0]
# make jit happy
if update_name == "g1":
assert nitem == len(self.g1_residual) + 1
for ii, vv in enumerate(self.g1_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "g2":
assert nitem == len(self.g2_residual) + 1
for ii, vv in enumerate(self.g2_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "h1":
assert nitem == len(self.h1_residual) + 1
for ii, vv in enumerate(self.h1_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "h2":
assert nitem == len(self.h2_residual) + 1
for ii, vv in enumerate(self.h2_residual):
uu = uu + vv * update_list[ii + 1]
else:
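The changes in repformer_layer.py pair the new output layer normalization with TorchScript-friendly bookkeeping: each LayerNorm is created only when its output_*_ln flag is set (otherwise the attribute stays None), and forward() narrows the Optional attributes with assert statements before calling them. The following is a minimal, self-contained sketch of that pattern, not repository code; the class name and dimensions are invented for illustration.

```python
# Minimal sketch (not repository code): optional LayerNorm submodules default
# to None, are built only when requested, and are narrowed with
# `assert ... is not None` so torch.jit.script accepts the calls.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class OutputLN(nn.Module):
    def __init__(
        self,
        g1_dim: int,
        g2_dim: int,
        output_g1_ln: bool = False,
        output_g2_ln: bool = False,
    ) -> None:
        super().__init__()
        self.output_g1_ln = output_g1_ln
        self.output_g2_ln = output_g2_ln
        # Always define the attributes so every instance exposes the same members.
        self.g1_ln: Optional[nn.LayerNorm] = None
        self.g2_ln: Optional[nn.LayerNorm] = None
        if self.output_g1_ln:
            self.g1_ln = nn.LayerNorm(g1_dim, eps=1e-5)
        if self.output_g2_ln:
            self.g2_ln = nn.LayerNorm(g2_dim, eps=1e-5)

    def forward(self, g1: torch.Tensor, g2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.output_g1_ln:
            assert self.g1_ln is not None  # refine the Optional for the compiler
            g1 = self.g1_ln(g1)
        if self.output_g2_ln:
            assert self.g2_ln is not None
            g2 = self.g2_ln(g2)
        return g1, g2


# The module scripts cleanly when the requested LayerNorms exist:
scripted = torch.jit.script(OutputLN(8, 4, output_g1_ln=True, output_g2_ln=True))
g1_out, g2_out = scripted(torch.randn(2, 3, 8), torch.randn(2, 3, 4))
```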
23 changes: 23 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
@@ -108,6 +108,12 @@ def __init__(
use_sqrt_nnei: bool = True,
g1_out_conv: bool = True,
g1_out_mlp: bool = True,
output_g1_ln: bool = False,
output_g2_ln: bool = False,
output_h1_ln: bool = False,
output_h2_ln: bool = False,
update_h1_has_g1: bool = False,
update_h2_has_g2: bool = False,
old_impl: bool = False,
):
r"""
@@ -234,6 +240,12 @@ def __init__(
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
self.update_h1_has_g1 = update_h1_has_g1
self.update_h2_has_g2 = update_h2_has_g2
self.output_g1_ln = output_g1_ln
self.output_g2_ln = output_g2_ln
self.output_h1_ln = output_h1_ln
self.output_h2_ln = output_h2_ln
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
@@ -325,6 +337,12 @@ def __init__(
use_sqrt_nnei=self.use_sqrt_nnei,
g1_out_conv=self.g1_out_conv,
g1_out_mlp=self.g1_out_mlp,
update_h1_has_g1=self.update_h1_has_g1,
update_h2_has_g2=self.update_h2_has_g2,
output_g1_ln=self.output_g1_ln,
output_g2_ln=self.output_g2_ln,
output_h1_ln=self.output_h1_ln,
output_h2_ln=self.output_h2_ln,
seed=child_seed(child_seed(seed, 1), ii),
)
)
@@ -496,13 +514,16 @@ def forward(
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
mapping3 = mapping.view(nframes, nall, 1, self.g1_dim).expand(-1, -1, 3, -1)
else:
mapping3 = None
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
# h1: nb x nloc x ng1 x 3
# h1_ext: nb x nall x ng1 x 3
if comm_dict is None:
assert mapping is not None
assert mapping3 is not None
g1_ext = torch.gather(g1, 1, mapping)
h1_ext = torch.gather(h1, 1, mapping3)
else:
@@ -528,6 +549,8 @@
torch.tensor(nall - nloc), # pylint: disable=no-explicit-dtype,no-explicit-device
)
g1_ext = ret[0].unsqueeze(0)
                h1_ext = ret[0].unsqueeze(0)  # placeholder
assert h1_ext is not None
g1, g2, h1, h2 = ll.forward(
g1_ext,
g2,
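In repformers.py the "make jit happy" part is the same idea at the level of local variables: mapping3 (and h1_ext) must be bound with one consistent type on every control-flow branch, so the else branch now sets them to None (or a placeholder) and an assert refines the Optional before use. A toy, self-contained illustration of why TorchScript insists on this, with made-up names and shapes:

```python
# Toy illustration (not repository code): TorchScript rejects a name that is
# only defined on some branches; initializing it to None and refining with
# `assert ... is not None` is the standard workaround.
from typing import Optional

import torch


@torch.jit.script
def gather_ext(g1: torch.Tensor, mapping: Optional[torch.Tensor]) -> torch.Tensor:
    # Without this default, `mapping3` would be "possibly undefined" below.
    mapping3: Optional[torch.Tensor] = None
    if mapping is not None:
        # nb x nall -> nb x nall x ng1, analogous to building mapping3 above
        mapping3 = mapping.unsqueeze(-1).expand(-1, -1, g1.shape[-1])
    assert mapping3 is not None  # refines Optional[Tensor] to Tensor
    return torch.gather(g1, 1, mapping3)


out = gather_ext(torch.randn(1, 4, 3), torch.arange(4).view(1, 4))
```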
36 changes: 36 additions & 0 deletions deepmd/utils/argcheck.py
@@ -1234,6 +1234,42 @@ def dpa2_repformer_args():
default=True,
doc=doc_g1_out_mlp,
),
Argument(
"update_h1_has_g1",
bool,
optional=True,
default=False,
),
Argument(
"update_h2_has_g2",
bool,
optional=True,
default=False,
),
Argument(
"output_g1_ln",
bool,
optional=True,
default=False,
),
Argument(
"output_g2_ln",
bool,
optional=True,
default=False,
),
Argument(
"output_h1_ln",
bool,
optional=True,
default=False,
),
Argument(
"output_h2_ln",
bool,
optional=True,
default=False,
),
Argument(
"update_h2",
bool,
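All six new options are registered in argcheck.py as optional booleans defaulting to False, so existing input scripts are unaffected unless they opt in. A hypothetical excerpt of the repformer section of a DPA-2 descriptor input enabling the new switches (other repformer keys omitted; key names follow the arguments registered above):

```python
# Hypothetical excerpt (not a complete input) of a DPA-2 "repformer" section,
# showing the six switches added above; all of them default to False.
repformer = {
    # ... existing repformer settings (cutoffs, dimensions, nlayers, ...) ...
    "update_h1_has_g1": False,
    "update_h2_has_g2": False,
    "output_g1_ln": True,   # apply LayerNorm to each layer's g1 output
    "output_g2_ln": True,   # apply LayerNorm to each layer's g2 output
    "output_h1_ln": True,   # apply LayerNorm to each layer's h1 output
    "output_h2_ln": True,   # apply LayerNorm to each layer's h2 output
}
```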
