Skip to content

Commit

Permalink
Merge branch 'main' into fix-name-dftdt
Browse files Browse the repository at this point in the history
  • Loading branch information
dft-dutoit committed Jul 11, 2024
2 parents a998814 + 97c355d commit 99ccd00
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 109 deletions.
20 changes: 0 additions & 20 deletions cgap17.py

This file was deleted.

40 changes: 0 additions & 40 deletions config.yaml

This file was deleted.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ classifiers = [
keywords = []
dependencies = [
"torch",
"torch-geometric",
"pytorch-lightning",
"ase==3.22.1", # TODO: unpin this? or pin other versions?
"numpy",
Expand Down
77 changes: 29 additions & 48 deletions src/graph_pes/models/schnet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import torch
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing

from graph_pes.graphs import AtomicGraph
from graph_pes.graphs.operations import neighbour_distances
from graph_pes.graphs.operations import (
neighbour_distances,
split_over_neighbours,
sum_over_neighbours,
)
from graph_pes.models.scaling import AutoScaledPESModel
from graph_pes.nn import MLP, PerElementEmbedding, ShiftedSoftplus

from .distances import DistanceExpansion, GaussianSmearing


class CFConv(MessagePassing):
class CFConv(torch.nn.Module):
r"""
CFConv: The Continous Filter Convolution
Expand Down Expand Up @@ -46,44 +48,24 @@ class CFConv(MessagePassing):
The filter function :math:`\mathbb{F}`.
"""

def __init__(self, filter_generator: nn.Module):
super().__init__(aggr="add")
def __init__(self, filter_generator: torch.nn.Module):
super().__init__()
self.filter_generator = filter_generator

def message(
self, x_j: torch.Tensor, neighbour_distances: torch.Tensor
) -> torch.Tensor:
"""
Apply the filter function to the distances and multiply by the
node features.
Parameters
----------
x_j : torch.Tensor
The node features of the neighbors.
neighbour_distances : torch.Tensor
The distances to the neighbors.
"""

return x_j * self.filter_generator(neighbour_distances)

def update(self, inputs: Tensor) -> Tensor:
"""
Identity update function.
"""
return inputs

def forward(
self,
neighbour_index: torch.Tensor,
neighbour_distances: torch.Tensor,
node_features: torch.Tensor,
) -> torch.Tensor:
return self.propagate(
neighbour_index,
neighbour_distances=neighbour_distances,
x=node_features,
)
node_features: torch.Tensor, # (n_atoms, F)
edge_distances: torch.Tensor, # (E,)
graph: AtomicGraph,
) -> torch.Tensor: # (n_atoms, F)
edge_features = self.filter_generator(edge_distances) # (E, F)
neighbour_features = split_over_neighbours(
node_features, graph
) # (E, F)

messages = neighbour_features * edge_features # (E, F)

return sum_over_neighbours(messages, graph)

def __repr__(self):
rep = f"CFConv({self.filter_generator})"
Expand All @@ -95,7 +77,7 @@ def __repr__(self):
return rep


class SchNetInteraction(nn.Module):
class SchNetInteraction(torch.nn.Module):
r"""
SchNet interaction block.
Expand Down Expand Up @@ -133,11 +115,11 @@ def __init__(
# schnet interaction block's are composed of 3 elements

# 1. linear transform to get new node features
self.linear = nn.Linear(n_features, n_features, bias=False)
self.linear = torch.nn.Linear(n_features, n_features, bias=False)

# 2. cfconv to mix these new features with distances information,
# and aggregate over neighbors to create completely new node features
filter_generator = nn.Sequential(
filter_generator = torch.nn.Sequential(
basis_type(expansion_features, cutoff),
MLP(
[expansion_features, n_features, n_features],
Expand All @@ -154,15 +136,15 @@ def __init__(

def forward(
self,
neighbour_index: torch.Tensor,
neighbour_distances: torch.Tensor,
node_features: torch.Tensor,
neighbour_distances: torch.Tensor,
graph: AtomicGraph,
):
# 1. linear transform to get new node features
h = self.linear(node_features)
# 2. cfconv to mix these new features with distances information,
# and aggregate over neighbors to create completely new node features
h = self.cfconv(neighbour_index, neighbour_distances, h)
h = self.cfconv(h, neighbour_distances, graph)
# 3. mlp to further embed these new node features
return self.mlp(h)

Expand Down Expand Up @@ -225,7 +207,7 @@ def __init__(

self.chemical_embedding = PerElementEmbedding(node_features)

self.interactions = nn.ModuleList(
self.interactions = torch.nn.ModuleList(
SchNetInteraction(
node_features, expansion_features, cutoff, expansion
)
Expand All @@ -239,10 +221,9 @@ def __init__(

def predict_unscaled_energies(self, graph: AtomicGraph) -> torch.Tensor:
h = self.chemical_embedding(graph["atomic_numbers"])
d = neighbour_distances(graph)

for interaction in self.interactions:
h = h + interaction(
graph["neighbour_index"], neighbour_distances(graph), h
)
h = h + interaction(h, d, graph)

return self.read_out(h)

0 comments on commit 99ccd00

Please sign in to comment.