Skip to content

Commit

Permalink
painn fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Feb 6, 2024
1 parent 03da8cb commit cdd0cf0
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/graph_pes/models/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,27 @@ def forward(
Float[Tensor, "graph.n_atoms self.internal_dim"],
]:
central_atoms, neighbours = graph.neighbour_index
d = graph.neighbour_distances
unit_vectors = graph.neighbour_vectors / d.unsqueeze(-1)
d = graph.neighbour_distances.unsqueeze(-1)
unit_vectors = graph.neighbour_vectors / d

# continous filter message creation
x_ij = self.filter_generator(d) * self.φ(scalar_embeddings)
x_ij = self.filter_generator(d) * self.φ(scalar_embeddings)[neighbours]
a, b, c = torch.split(x_ij, self.internal_dim, dim=-1)

# simple sum over neighbours to get scalar messages
Δs = torch.zeros_like(scalar_embeddings)
Δs.scatter_add_(0, neighbours, a)
Δs.scatter_add_(0, neighbours.unsqueeze(-1).expand_as(a), a)

# create vector messages
v_ij = b * unit_vectors + c * vector_embeddings[neighbours]
v_ij = (
b.unsqueeze(-1) * unit_vectors.unsqueeze(1)
+ c.unsqueeze(-1) * vector_embeddings[neighbours]
)

Δv = torch.zeros_like(vector_embeddings)
Δv.scatter_add_(0, neighbours, v_ij)
Δv.scatter_add_(
0, neighbours.unsqueeze(-1).unsqueeze(-1).expand_as(v_ij), v_ij
)

return Δv, Δs

Expand Down Expand Up @@ -108,7 +113,7 @@ def forward(
Δv = u * a.unsqueeze(-1)

# scalar update:
dot = torch.sum(u * v, dim=1, keepdim=True) # u . v
dot = torch.sum(u * v, dim=-1)
Δs = b + c * dot

return Δv, Δs
Expand Down Expand Up @@ -151,14 +156,14 @@ def predict_local_energies(

for interaction, update in zip(self.interactions, self.updates):
Δv, Δs = interaction(vector_embeddings, scalar_embeddings, graph)
vector_embeddings += Δv
scalar_embeddings += Δs
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs

Δv, Δs = update(
vector_embeddings,
scalar_embeddings,
)
vector_embeddings += Δv
scalar_embeddings += Δs
vector_embeddings = vector_embeddings + Δv
scalar_embeddings = scalar_embeddings + Δs

return self.read_out(scalar_embeddings)

0 comments on commit cdd0cf0

Please sign in to comment.