From 5a354bcd5b59ebaee92cb531969bd7245bb5dcec Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:31:02 +0800 Subject: [PATCH] add layernorm on g/h; make jit happy (#3) * add layernorm on g/h; make jit happy * Update repformer_layer.py --- deepmd/dpmodel/descriptor/dpa2.py | 12 +++ deepmd/pt/model/descriptor/dpa2.py | 6 ++ deepmd/pt/model/descriptor/repformer_layer.py | 87 ++++++++++++++++--- deepmd/pt/model/descriptor/repformers.py | 23 +++++ deepmd/utils/argcheck.py | 36 ++++++++ 5 files changed, 150 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 43c57f443f..943eb6c3a2 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -206,6 +206,12 @@ def __init__( use_sqrt_nnei: bool = True, g1_out_conv: bool = True, g1_out_mlp: bool = True, + update_h1_has_g1: bool = False, + update_h2_has_g2: bool = False, + output_g1_ln: bool = False, + output_g2_ln: bool = False, + output_h1_ln: bool = False, + output_h2_ln: bool = False, ln_eps: Optional[float] = 1e-5, ): r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor. @@ -308,6 +314,12 @@ def __init__( self.use_sqrt_nnei = use_sqrt_nnei self.g1_out_conv = g1_out_conv self.g1_out_mlp = g1_out_mlp + self.update_h1_has_g1 = update_h1_has_g1 + self.update_h2_has_g2 = update_h2_has_g2 + self.output_g1_ln = output_g1_ln + self.output_g2_ln = output_g2_ln + self.output_h1_ln = output_h1_ln + self.output_h2_ln = output_h2_ln # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 9fc4fc4a21..d8ac6a4650 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -237,6 +237,12 @@ def init_subclass_params(sub_data, sub_class): use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, g1_out_conv=self.repformer_args.g1_out_conv, g1_out_mlp=self.repformer_args.g1_out_mlp, + update_h1_has_g1=self.repformer_args.update_h1_has_g1, + update_h2_has_g2=self.repformer_args.update_h2_has_g2, + output_g1_ln=self.repformer_args.output_g1_ln, + output_g2_ln=self.repformer_args.output_g2_ln, + output_h1_ln=self.repformer_args.output_h1_ln, + output_h2_ln=self.repformer_args.output_h2_ln, seed=child_seed(seed, 1), old_impl=old_impl, ) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 5e6f53e43f..011e129ee4 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -3,6 +3,7 @@ List, Optional, Union, + Tuple, ) import torch @@ -587,8 +588,8 @@ def __init__( update_g2_has_g1g1: bool = True, update_g2_has_attn: bool = True, update_h2: bool = False, - update_h2_has_g2: bool = True, - update_h1_has_g1: bool = True, + update_h1_has_g1: bool = False, + update_h2_has_g2: bool = False, attn1_hidden: int = 64, attn1_nhead: int = 4, attn2_hidden: int = 16, @@ -605,6 +606,10 @@ def __init__( use_sqrt_nnei: bool = True, g1_out_conv: bool = True, g1_out_mlp: bool = True, + output_g1_ln=False, + output_g2_ln=False, + output_h1_ln=False, + output_h2_ln=False, seed: Optional[Union[int, List[int]]] = None, ): super().__init__() @@ -656,6 +661,10 @@ def __init__( self.use_sqrt_nnei = use_sqrt_nnei self.g1_out_conv = g1_out_conv self.g1_out_mlp = g1_out_mlp + self.output_g1_ln = output_g1_ln + self.output_g2_ln = output_g2_ln + self.output_h1_ln = output_h1_ln + self.output_h2_ln = output_h2_ln assert update_residual_init in [ "norm", @@ -695,6 +704,10 @@ def __init__( self.attn2_lm = None self.attn2_ev_apply = None self.loc_attn = None + self.g1_ln = None + self.g2_ln = None + self.h1_ln = None + self.h2_ln = None if self.update_chnnl_2: self.linear2 = MLPLayer( @@ -919,6 +932,38 @@ def __init__( self.g2_residual = nn.ParameterList(self.g2_residual) self.h1_residual = nn.ParameterList(self.h1_residual) self.h2_residual = nn.ParameterList(self.h2_residual) + if self.output_g1_ln: + self.g1_ln = LayerNorm( + g1_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 18), + ) + if self.output_g2_ln: + self.g2_ln = LayerNorm( + g2_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 19), + ) + if self.output_h1_ln: + self.h1_ln = LayerNorm( + g1_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 20), + ) + if self.output_h2_ln: + self.h2_ln = LayerNorm( + g2_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 21), + ) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret = g1d if not self.g1_out_mlp else 0 @@ -958,7 +1003,7 @@ def _update_g1_conv( h2: torch.Tensor, # nb x nloc x nnei x 3 x ng2 nlist_mask: torch.Tensor, sw: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate the convolution update for atomic invariant rep. @@ -974,14 +1019,11 @@ def _update_g1_conv( The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nb x nloc x nnei. """ - assert self.proj_g1g2_1 is not None - assert self.proj_h1h2_1 is not None - assert self.proj_g1g2_2 is not None - assert self.proj_h1h2_2 is not None nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] if not self.g1_out_conv: + assert self.proj_g1g2 is not None # gg1 : nb x nloc x nnei x ng2 gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) else: @@ -1012,10 +1054,14 @@ def _update_g1_conv( ) if not self.g1_out_conv: # nb x nloc x ng2 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei - g1_1 = g1_11 - h1_1 = None + g1_1 = torch.sum(g2 * gg1, dim=2) * invnnei + g1_11 = g1_1 + g1_12, h1_11, h1_12 = None, None, None else: + assert self.proj_g1g2_1 is not None + assert self.proj_h1h2_1 is not None + assert self.proj_g1g2_2 is not None + assert self.proj_h1h2_2 is not None # nb x nloc x ng1 g2_1 = self.proj_g1g2_1(g2).view(nb, nloc, nnei, ng1) g2_2 = self.proj_g1g2_2(g2).view(nb, nloc, nnei, ng1) @@ -1112,6 +1158,9 @@ def _update_g1_conv( # torch.std(h1_11).detach().numpy(), # torch.std(h1_12).detach().numpy(), # ) + assert g1_12 is not None + assert h1_11 is not None + assert h1_12 is not None return g1_11, g1_12, h1_11, h1_12 @staticmethod @@ -1353,6 +1402,7 @@ def forward( ) else: gg1 = None + hh1 = None if self.update_chnnl_2: # mlp(g2) @@ -1389,6 +1439,7 @@ def forward( if self.update_g1_has_conv: assert gg1 is not None + assert hh1 is not None g11_conv, g12_conv, h11_conv, h12_conv = self._update_g1_conv( gg1, hh1, g2, h2, nlist_mask, sw ) @@ -1452,6 +1503,18 @@ def forward( g2_new, h2_new = g2, h2 g1_new = self.list_update(g1_update, "g1") h1_new = self.list_update(h1_update, "h1") + if self.output_g1_ln: + assert self.g1_ln is not None + g1_new = self.g1_ln(g1_new) + if self.output_g2_ln: + assert self.g2_ln is not None + g2_new = self.g2_ln(g2_new) + if self.output_h1_ln: + assert self.h1_ln is not None + h1_new = self.h1_ln(h1_new) + if self.output_h2_ln: + assert self.h2_ln is not None + h2_new = self.h2_ln(h2_new) return g1_new, g2_new, h1_new, h2_new @torch.jit.export @@ -1482,19 +1545,15 @@ def list_update_res_residual( uu = update_list[0] # make jit happy if update_name == "g1": - assert nitem == len(self.g1_residual) + 1 for ii, vv in enumerate(self.g1_residual): uu = uu + vv * update_list[ii + 1] elif update_name == "g2": - assert nitem == len(self.g2_residual) + 1 for ii, vv in enumerate(self.g2_residual): uu = uu + vv * update_list[ii + 1] elif update_name == "h1": - assert nitem == len(self.h1_residual) + 1 for ii, vv in enumerate(self.h1_residual): uu = uu + vv * update_list[ii + 1] elif update_name == "h2": - assert nitem == len(self.h2_residual) + 1 for ii, vv in enumerate(self.h2_residual): uu = uu + vv * update_list[ii + 1] else: diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 386c9f4c0b..a131ce8b3f 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -108,6 +108,12 @@ def __init__( use_sqrt_nnei: bool = True, g1_out_conv: bool = True, g1_out_mlp: bool = True, + output_g1_ln: bool = False, + output_g2_ln: bool = False, + output_h1_ln: bool = False, + output_h2_ln: bool = False, + update_h1_has_g1: bool = False, + update_h2_has_g2: bool = False, old_impl: bool = False, ): r""" @@ -234,6 +240,12 @@ def __init__( self.use_sqrt_nnei = use_sqrt_nnei self.g1_out_conv = g1_out_conv self.g1_out_mlp = g1_out_mlp + self.update_h1_has_g1 = update_h1_has_g1 + self.update_h2_has_g2 = update_h2_has_g2 + self.output_g1_ln = output_g1_ln + self.output_g2_ln = output_g2_ln + self.output_h1_ln = output_h1_ln + self.output_h2_ln = output_h2_ln # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection @@ -325,6 +337,12 @@ def __init__( use_sqrt_nnei=self.use_sqrt_nnei, g1_out_conv=self.g1_out_conv, g1_out_mlp=self.g1_out_mlp, + update_h1_has_g1=self.update_h1_has_g1, + update_h2_has_g2=self.update_h2_has_g2, + output_g1_ln=self.output_g1_ln, + output_g2_ln=self.output_g2_ln, + output_h1_ln=self.output_h1_ln, + output_h2_ln=self.output_h2_ln, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -496,6 +514,8 @@ def forward( mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) ) mapping3 = mapping.view(nframes, nall, 1, self.g1_dim).expand(-1, -1, 3, -1) + else: + mapping3 = None for idx, ll in enumerate(self.layers): # g1: nb x nloc x ng1 # g1_ext: nb x nall x ng1 @@ -503,6 +523,7 @@ def forward( # h1_ext: nb x nall x ng1 x 3 if comm_dict is None: assert mapping is not None + assert mapping3 is not None g1_ext = torch.gather(g1, 1, mapping) h1_ext = torch.gather(h1, 1, mapping3) else: @@ -528,6 +549,8 @@ def forward( torch.tensor(nall - nloc), # pylint: disable=no-explicit-dtype,no-explicit-device ) g1_ext = ret[0].unsqueeze(0) + h1_ext = ret[0].unsqueeze(0) # place holder + assert h1_ext is not None g1, g2, h1, h2 = ll.forward( g1_ext, g2, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a799b6b0c4..161054c37d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1234,6 +1234,42 @@ def dpa2_repformer_args(): default=True, doc=doc_g1_out_mlp, ), + Argument( + "update_h1_has_g1", + bool, + optional=True, + default=False, + ), + Argument( + "update_h2_has_g2", + bool, + optional=True, + default=False, + ), + Argument( + "output_g1_ln", + bool, + optional=True, + default=False, + ), + Argument( + "output_g2_ln", + bool, + optional=True, + default=False, + ), + Argument( + "output_h1_ln", + bool, + optional=True, + default=False, + ), + Argument( + "output_h2_ln", + bool, + optional=True, + default=False, + ), Argument( "update_h2", bool,