add h1 channel and the convolution updates
Han Wang committed Oct 4, 2024
1 parent 7ce5b03 commit 7b065bf
Showing 2 changed files with 185 additions and 23 deletions.
182 changes: 162 additions & 20 deletions deepmd/pt/model/descriptor/repformer_layer.py
@@ -653,6 +653,7 @@ def __init__(
self.update_residual_init = update_residual_init
self.g1_residual = []
self.g2_residual = []
self.h1_residual = []
self.h2_residual = []

if self.update_style == "res_residual":
@@ -675,6 +676,7 @@ def __init__(
)
self.linear2 = None
self.proj_g1g2 = None
self.proj_h1h2 = None
self.proj_g1g1g2 = None
self.attn2g_map = None
self.attn2_mh_apply = None
@@ -733,7 +735,14 @@ def __init__(
g1_dim,
bias=False,
precision=precision,
seed=child_seed(seed, 4),
seed=child_seed(child_seed(seed, 4), 0),
)
self.proj_h1h2 = MLPLayer(
g2_dim,
g1_dim,
bias=False,
precision=precision,
seed=child_seed(child_seed(seed, 4), 1),
)
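# nesting child_seed here splits slot 4 into two sub-seeds, giving proj_g1g2
# and proj_h1h2 distinct init streams without renumbering later seed slots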
if self.update_style == "res_residual":
self.g1_residual.append(
@@ -742,7 +751,16 @@
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 17),
seed=child_seed(child_seed(seed, 17), 0),
)
)
self.h1_residual.append(
get_residual(
g1_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(child_seed(seed, 17), 1),
)
)
if self.update_g2_has_g1g1:
@@ -831,6 +849,7 @@ def __init__(

self.g1_residual = nn.ParameterList(self.g1_residual)
self.g2_residual = nn.ParameterList(self.g2_residual)
self.h1_residual = nn.ParameterList(self.h1_residual)
self.h2_residual = nn.ParameterList(self.h2_residual)

def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int:
@@ -865,8 +884,10 @@ def _update_h2(

def _update_g1_conv(
self,
gg1: torch.Tensor,
g2: torch.Tensor,
gg1: torch.Tensor, # nb x nloc x nnei x ng1
hh1: torch.Tensor, # nb x nloc x nnei x 3 x ng1
g2: torch.Tensor, # nb x nloc x nnei x ng2
h2: torch.Tensor, # nb x nloc x nnei x 3 x ng2
nlist_mask: torch.Tensor,
sw: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -886,16 +907,24 @@ def _update_g1_conv(
and remains 0 beyond rcut, with shape nb x nloc x nnei.
"""
assert self.proj_g1g2 is not None
assert self.proj_h1h2 is not None
nb, nloc, nnei, _ = g2.shape
ng1 = gg1.shape[-1]
ng2 = g2.shape[-1]
if not self.g1_out_conv:
# gg1 : nb x nloc x nnei x ng2
gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2)
else:
# gg1 : nb x nloc x nnei x ng1
gg1 = gg1.view(nb, nloc, nnei, ng1)
# hh1 : nb x nloc x nnei x 3 x ng1
hh1 = hh1.view(nb, nloc, nnei, 3, ng1)
# nb x nloc x nnei x ng2/ng1
gg1 = _apply_nlist_mask(gg1, nlist_mask)
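# flatten the vector channel to nnei x (3 * ng1) so the same neighbor mask
# applies to every Cartesian component, then reshape back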
hh1 = _apply_nlist_mask(
hh1.view(nb, nloc, nnei, 3 * ng1),
nlist_mask,
).view(nb, nloc, nnei, 3, ng1)
if not self.smooth:
# normalized by number of neighbors, not smooth
# nb x nloc x 1
@@ -905,17 +934,111 @@ def _update_g1_conv(
).unsqueeze(-1)
else:
gg1 = _apply_switch(gg1, sw)
hh1 = _apply_switch(hh1.view(nb, nloc, nnei, 3 * ng1), sw).view(
nb, nloc, nnei, 3, ng1
)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1), dtype=gg1.dtype, device=gg1.device
(nb, nloc), dtype=gg1.dtype, device=gg1.device
)
if not self.g1_out_conv:
# nb x nloc x ng2
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei[:, :, None]
g1_1 = g1_11
h1_1 = None
else:
g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1)
# nb x nloc x ng1
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
return g1_11
g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1)
# nb x nloc x 3 x ng1
h2 = self.proj_h1h2(h2).view(nb, nloc, nnei, 3, ng1)
##
# gg1 nb x nloc x nnei x ng1
# hh1: nb x nloc x nnei x 3 x ng1
# g2: nb x nloc x nnei x ng1
# h2: nb x nloc x nnei x 3 x ng1
if False:
## low efficiency implementation, but easy to read
## the g part: 0.0 + 1.1
g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei[:, :, None]
g1_12 = (
torch.einsum("ijkdc,ijkdc->ijc", h2, hh1)
* invnnei[:, :, None]
/ 3.0
)
# nb x nloc x ng1
g1_1 = g1_11 + g1_12
## the h part: 0.1 + 1.0
h1_11 = (
torch.einsum("ijkc,ijkdc->ijdc", g2, hh1)
* invnnei[:, :, None, None]
)
h1_12 = (
torch.einsum("ijkdc,ijkc->ijdc", h2, gg1)
* invnnei[:, :, None, None]
)
# nb x nloc x 3 x ng1
h1_1 = h1_11 + h1_12
else:
## implementation via batch matmul
g1_11 = (
torch.bmm(
torch.reshape(
torch.permute(g2, (0, 1, 3, 2)), (nb * nloc * ng1, 1, nnei)
),
torch.reshape(
torch.permute(gg1, (0, 1, 3, 2)), (nb * nloc * ng1, nnei, 1)
),
).view(nb, nloc, ng1)
* invnnei[:, :, None]
)
g1_12 = (
torch.bmm(
torch.reshape(
torch.permute(h2, (0, 1, 4, 2, 3)),
(nb * nloc * ng1, 1, nnei * 3),
),
torch.reshape(
torch.permute(hh1, (0, 1, 4, 2, 3)),
(nb * nloc * ng1, nnei * 3, 1),
),
).view(nb, nloc, ng1)
* invnnei[:, :, None]
/ 3.0
)
# nb x nloc x ng1
g1_1 = g1_11 + g1_12
## the h part: 0.1 + 1.0
tg2 = g2.unsqueeze(-2).expand(-1, -1, -1, 3, -1)
h1_11 = (
torch.bmm(
torch.reshape(
torch.permute(tg2, (0, 1, 3, 4, 2)),
(nb * nloc * 3 * ng1, 1, nnei),
),
torch.reshape(
torch.permute(hh1, (0, 1, 3, 4, 2)),
(nb * nloc * 3 * ng1, nnei, 1),
),
).view(nb, nloc, 3, ng1)
* invnnei[:, :, None, None]
)
tg1 = gg1.unsqueeze(-2).expand(-1, -1, -1, 3, -1)
h1_12 = (
torch.bmm(
torch.reshape(
torch.permute(h2, (0, 1, 3, 4, 2)),
(nb * nloc * 3 * ng1, 1, nnei),
),
torch.reshape(
torch.permute(tg1, (0, 1, 3, 4, 2)),
(nb * nloc * 3 * ng1, nnei, 1),
),
).view(nb, nloc, 3, ng1)
* invnnei[:, :, None, None]
)
# nb x nloc x 3 x ng1
h1_1 = h1_11 + h1_12
# print(torch.std(g1_11), torch.std(g1_12), torch.std(h1_11), torch.std(h1_12))
return g1_1, h1_1
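
A standalone sketch (added for illustration, not part of the commit) checking that the batch-matmul path above reproduces the easy-to-read einsum path; the tensor names mirror the diff, and the sizes are assumptions:

import torch

nb, nloc, nnei, ng1 = 2, 5, 7, 4
gg1 = torch.randn(nb, nloc, nnei, ng1)      # scalar neighbor channel
hh1 = torch.randn(nb, nloc, nnei, 3, ng1)   # vector neighbor channel
g2 = torch.randn(nb, nloc, nnei, ng1)       # pair channel after proj_g1g2
h2 = torch.randn(nb, nloc, nnei, 3, ng1)    # pair channel after proj_h1h2
invnnei = torch.full((nb, nloc), 1.0 / nnei)

# scalar output, vector-vector term (the "1.1" part)
g1_12_ein = torch.einsum("ijkdc,ijkdc->ijc", h2, hh1) * invnnei[:, :, None] / 3.0
g1_12_bmm = (
    torch.bmm(
        h2.permute(0, 1, 4, 2, 3).reshape(nb * nloc * ng1, 1, nnei * 3),
        hh1.permute(0, 1, 4, 2, 3).reshape(nb * nloc * ng1, nnei * 3, 1),
    ).view(nb, nloc, ng1)
    * invnnei[:, :, None]
    / 3.0
)

# vector output, scalar-vector term (the "0.1" part)
h1_11_ein = torch.einsum("ijkc,ijkdc->ijdc", g2, hh1) * invnnei[:, :, None, None]
tg2 = g2.unsqueeze(-2).expand(-1, -1, -1, 3, -1)
h1_11_bmm = (
    torch.bmm(
        tg2.permute(0, 1, 3, 4, 2).reshape(nb * nloc * 3 * ng1, 1, nnei),
        hh1.permute(0, 1, 3, 4, 2).reshape(nb * nloc * 3 * ng1, nnei, 1),
    ).view(nb, nloc, 3, ng1)
    * invnnei[:, :, None, None]
)

torch.testing.assert_close(g1_12_ein, g1_12_bmm)
torch.testing.assert_close(h1_11_ein, h1_11_bmm)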

@staticmethod
def _cal_hg(
@@ -1098,26 +1221,29 @@ def forward(
self,
g1_ext: torch.Tensor, # nf x nall x ng1
g2: torch.Tensor, # nf x nloc x nnei x ng2
h2: torch.Tensor, # nf x nloc x nnei x 3
h1_ext: torch.Tensor, # nf x nall x 3 x ng1
h2: torch.Tensor, # nf x nloc x nnei x 3 x ng2
nlist: torch.Tensor, # nf x nloc x nnei
nlist_mask: torch.Tensor, # nf x nloc x nnei
sw: torch.Tensor, # switch func, nf x nloc x nnei
):
"""
Parameters
----------
g1_ext : nf x nall x ng1 extended single-atom channel
g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant
h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant
g1_ext : nf x nall x ng1 extended single-atom channel, scalar
g2 : nf x nloc x nnei x ng2 pair-atom channel, scalar
h1_ext : nf x nall x 3 x ng1 extended single-atom channel, vector
h2 : nf x nloc x nnei x 3 x ng2 pair-atom channel, vector
nlist : nf x nloc x nnei neighbor list (padded neis are set to 0)
nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0
sw : nf x nloc x nnei switch function
Returns
-------
g1: nf x nloc x ng1 updated single-atom channel
g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant
h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant
g1: nf x nloc x ng1 updated single-atom channel, scalar
g2: nf x nloc x nnei x ng2 updated pair-atom channel, scalar
h1: nf x nloc x 3 x ng1 updated single-atom channel, vector
h2: nf x nloc x nnei x 3 x ng2 updated pair-atom channel, vector
"""
cal_gg1 = (
self.update_g1_has_drrd
@@ -1128,21 +1254,29 @@ def forward(

nb, nloc, nnei, _ = g2.shape
nall = g1_ext.shape[1]
ng1 = g1_ext.shape[-1]
g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1)
h1, _ = torch.split(h1_ext, [nloc, nall - nloc], dim=1)
assert (nb, nloc) == g1.shape[:2]
assert (nb, nloc) == h1.shape[:2]
assert (nb, nloc, nnei) == h2.shape[:3]

g2_update: List[torch.Tensor] = [g2]
h2_update: List[torch.Tensor] = [h2]
g1_update: List[torch.Tensor] = [g1]
h1_update: List[torch.Tensor] = [h1]
g1_mlp: List[torch.Tensor] = [g1] if not self.g1_out_mlp else []

if self.g1_out_mlp:
assert self.g1_self_mlp is not None
g1_self_mlp = self.act(self.g1_self_mlp(g1))
g1_update.append(g1_self_mlp)

if cal_gg1:
gg1 = _make_nei_g1(g1_ext, nlist)
hh1 = _make_nei_g1(h1_ext.view([nb, nall, 3 * ng1]), nlist).view(
[nb, nloc, nnei, 3, ng1]
)
else:
gg1 = None

@@ -1178,20 +1312,20 @@ def forward(
if self.update_h2:
# linear_head(attention_weights * h2)
h2_update.append(self._update_h2(h2, AAg))

if self.update_g1_has_conv:
assert gg1 is not None
g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw)
g1_conv, h1_conv = self._update_g1_conv(gg1, hh1, g2, h2, nlist_mask, sw)
if not self.g1_out_conv:
g1_mlp.append(g1_conv)
else:
g1_update.append(g1_conv)
h1_update.append(h1_conv)

if self.update_g1_has_grrg:
g1_mlp.append(
self.symmetrization_op(
g2,
h2,
h2[..., 0],
nlist_mask,
sw,
self.axis_neuron,
@@ -1205,7 +1339,7 @@
g1_mlp.append(
self.symmetrization_op(
gg1,
h2,
h2[..., 0],
nlist_mask,
sw,
self.axis_neuron,
@@ -1231,7 +1365,8 @@
else:
g2_new, h2_new = g2, h2
g1_new = self.list_update(g1_update, "g1")
return g1_new, g2_new, h2_new
h1_new = self.list_update(h1_update, "h1")
return g1_new, g2_new, h1_new, h2_new
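
A minimal sketch of how the extended channels feed this forward signature; make_nei below is a hypothetical stand-in for this file's _make_nei_g1 (whose body is not shown in the diff), and all sizes are illustrative:

import torch

def make_nei(g_ext: torch.Tensor, nlist: torch.Tensor) -> torch.Tensor:
    # g_ext: nf x nall x ng, nlist: nf x nloc x nnei (padded neighbors are 0)
    nf, nloc, nnei = nlist.shape
    ng = g_ext.shape[-1]
    index = nlist.reshape(nf, nloc * nnei, 1).expand(-1, -1, ng)
    return torch.gather(g_ext, dim=1, index=index).view(nf, nloc, nnei, ng)

nf, nall, nloc, nnei, ng1 = 1, 6, 4, 3, 5
g1_ext = torch.randn(nf, nall, ng1)       # scalar single-atom channel
h1_ext = torch.randn(nf, nall, 3, ng1)    # new vector single-atom channel
nlist = torch.randint(0, nall, (nf, nloc, nnei))

gg1 = make_nei(g1_ext, nlist)             # nf x nloc x nnei x ng1
# as in forward(), the vector channel is flattened to 3 * ng1 for the gather
hh1 = make_nei(h1_ext.view(nf, nall, 3 * ng1), nlist).view(nf, nloc, nnei, 3, ng1)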

@torch.jit.export
def list_update_res_avg(
@@ -1261,12 +1396,19 @@ def list_update_res_residual(
uu = update_list[0]
# make jit happy
if update_name == "g1":
assert nitem == len(self.g1_residual) + 1
for ii, vv in enumerate(self.g1_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "g2":
assert nitem == len(self.g2_residual) + 1
for ii, vv in enumerate(self.g2_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "h1":
assert nitem == len(self.h1_residual) + 1
for ii, vv in enumerate(self.h1_residual):
uu = uu + vv * update_list[ii + 1]
elif update_name == "h2":
assert nitem == len(self.h2_residual) + 1
for ii, vv in enumerate(self.h2_residual):
uu = uu + vv * update_list[ii + 1]
else:
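
A minimal sketch of the "res_residual" accumulation handled above, with illustrative names rather than the class's API: the first list entry is the current channel value, and every later update is scaled by its own learned residual parameter.

import torch

def res_residual(update_list, residuals):
    uu = update_list[0]
    for vv, upd in zip(residuals, update_list[1:]):
        uu = uu + vv * upd
    return uu

g1 = torch.zeros(2, 4, 8)
updates = [g1, torch.randn(2, 4, 8), torch.randn(2, 4, 8)]
residuals = [torch.full((8,), 0.1), torch.full((8,), 0.1)]
g1_new = res_residual(updates, residuals)  # g1 + 0.1 * u1 + 0.1 * u2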