Skip to content

Commit

Permalink
offline wandb, add gaussian prior draft and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Nov 17, 2024
1 parent c4c4030 commit 247a72f
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 92 deletions.
68 changes: 41 additions & 27 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ def __init__(
node_feats_irreps: o3.Irreps,
target_irreps: o3.Irreps,
correlation: int,
learned_radials_dim: int,
use_sc: bool = True,
num_elements: Optional[int] = None,
tensor_format = "symmetric_cp",
flexible_feats_L = False,
gaussian_prior = False,
) -> None:
super().__init__()

Expand All @@ -269,25 +272,21 @@ def __init__(
irreps_out=target_irreps,
correlation=correlation,
num_elements=num_elements,
tensor_format=tensor_format
tensor_format=tensor_format,
flexible_feats_L=flexible_feats_L,
gaussian_prior=gaussian_prior,
)
# Update linear
if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
self.linear = o3.Linear(
target_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
)
elif tensor_format in ["flexible_symmetric_tucker", "symmetric_tucker", "non_symmetric_tucker"]:
if tensor_format == "flexible_symmetric_tucker":
tucker_irreps = make_tucker_irreps_flexible(target_irreps, correlation)
else:
tucker_irreps = make_tucker_irreps(target_irreps, correlation)
print("tucker irreps:", tucker_irreps)
print("target irreps:", target_irreps)
self.linear = o3.Linear(
tucker_irreps,
mid_irreps = target_irreps
elif tensor_format in ["flexible_non_symmetric_tucker", "flexible_symmetric_tucker",]:
mid_irreps = make_tucker_irreps_flexible(target_irreps, correlation)
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
mid_irreps = make_tucker_irreps(target_irreps, correlation)
else:
print("Tensor formatting not supported. Check your input")
self.linear = o3.Linear(
mid_irreps,
target_irreps,
internal_weights=True,
shared_weights=True,
Expand All @@ -298,8 +297,9 @@ def forward(
node_feats: torch.Tensor,
sc: Optional[torch.Tensor],
node_attrs: torch.Tensor,
learned_radials: torch.Tensor,
) -> torch.Tensor:
node_feats = self.symmetric_contractions(node_feats, node_attrs)
node_feats = self.symmetric_contractions(node_feats, node_attrs, learned_radials)
if self.use_sc and sc is not None:
return self.linear(node_feats) + sc
return self.linear(node_feats)
Expand Down Expand Up @@ -686,6 +686,7 @@ def _setup(self) -> None:
)

# Linear
    # TODO: clean up unused reshape layer for flexible tucker formats later
irreps_mid = irreps_mid.simplify()
self.irreps_out = self.target_irreps

Expand All @@ -699,7 +700,7 @@ def _setup(self) -> None:
)
self.reshape = reshape_irreps(self.irreps_out)

elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker", "flexible_non_symmetric_tucker"]:
self.linear = torch.nn.ModuleList([])
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
Expand Down Expand Up @@ -736,15 +737,18 @@ def forward(
src=mji, index=receiver, dim=0, dim_size=num_nodes
) # [n_nodes, irreps]

if self.tensor_format in ["symmetric_cp", "symmetric_tucker", "flexible_symmetric_tucker"]:

if self.tensor_format in ["symmetric_cp", "symmetric_tucker",]:
message = self.linear(original_message) / self.avg_num_neighbors
if self.tensor_format in ["flexible_symmetric_tucker", ]:
return (message, sc)
else:
return (
self.tensor_format_layer(self.reshape(message)),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
return (
self.tensor_format_layer(self.reshape(message)),
sc,
) # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
elif self.tensor_format in ["flexible_symmetric_tucker"]:
message = self.linear(original_message) / self.avg_num_neighbors
# requires format contraction in SymmetricContraction - no reshape
# to [n_nodes, channels, (lmax + 1) ** 2 ] yet
return (message, sc)
elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
message = self.reshape[0](self.linear[0](original_message))
message = message.unsqueeze(-1)
Expand All @@ -753,7 +757,17 @@ def forward(
message = torch.cat((message, _message), dim = -1)
return (
message / self.avg_num_neighbors,
sc
sc,
)
elif self.tensor_format in ["flexible_non_symmetric_tucker"]:
message = self.linear[0](original_message) # [n_nodes, klm]
            message = message.unsqueeze(-1) # [n_nodes, klm, 1]
for idx in range(1, self.correlation):
_message = self.linear[idx](original_message).unsqueeze(-1)
message = torch.cat((message, _message), dim = -1)
return (
message / self.avg_num_neighbors,
sc,
)


Expand Down
2 changes: 1 addition & 1 deletion mace/modules/irreps_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_tucker_irreps_flexible(target_irreps, correlation):
tmp_irreps = o3.Irreps(str(ir))
num_feats = 0
for nu in range(1, correlation + 1):
num_feats += (math.ceil(ir.mul ** (1 / nu))) ** nu
num_feats += (int(ir.mul ** (1 / nu))) ** nu
tp_irreps += o3.Irreps(f"{num_feats}x{tmp_irreps[0].ir}")
return tp_irreps

Expand Down
27 changes: 17 additions & 10 deletions mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(
radial_type: Optional[str] = "bessel",
heads: Optional[List[str]] = None,
tensor_format = "symmetric_cp",
flexible_feats_L = False,
gaussian_prior = False,
):
super().__init__()
self.register_buffer(
Expand Down Expand Up @@ -138,21 +140,18 @@ def __init__(
correlation=correlation[0],
num_elements=num_elements,
use_sc=use_sc_first,
learned_radials_dim=inter.conv_tp.weight_numel,
#
tensor_format=tensor_format,
flexible_feats_L=flexible_feats_L,
gaussian_prior=gaussian_prior,
)
self.products = torch.nn.ModuleList([prod])

self.readouts = torch.nn.ModuleList()
if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker", "flexible_symmetric_tucker"]:
self.readouts.append(
#LinearReadoutBlock(make_tp_irreps(hidden_irreps, correlation[0]), o3.Irreps(f"{len(heads)}x0e"))
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)
self.readouts.append(
LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
)

for i in range(num_interactions - 1):
if i == num_interactions - 2:
Expand All @@ -174,7 +173,9 @@ def __init__(
hidden_irreps=hidden_irreps_out,
avg_num_neighbors=avg_num_neighbors,
radial_MLP=radial_MLP,
correlation=correlation[i + 1]
correlation=correlation[i + 1],
#
#tensor_format=tensor_format,
)
self.interactions.append(inter)
prod = EquivariantProductBasisBlock(
Expand All @@ -183,6 +184,12 @@ def __init__(
correlation=correlation[i + 1],
num_elements=num_elements,
use_sc=True,
learned_radials_dim=inter.conv_tp.weight_numel,
##
# tensor_format=tensor_format,
# flexible_feats_L=flexible_feats_L,
# gaussian_prior=gaussian_prior,
# learned_radials_dim=inter.conv_tp.weight_numel
)
self.products.append(prod)
if i == num_interactions - 2:
Expand Down
Loading

0 comments on commit 247a72f

Please sign in to comment.