gpu compatibility
jla-gardner committed Feb 6, 2024
1 parent 56fbe1a commit 03da8cb
Showing 4 changed files with 31 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/graph_pes/data/batching.py
@@ -156,7 +156,9 @@ def neighbour_vectors(self) -> torch.Tensor:
             return self._positions[j] - self._positions[i]
 
         # otherwise calculate offsets on a per-structure basis
-        actual_offsets = torch.zeros((self.neighbour_index.shape[1], 3))
+        actual_offsets = torch.zeros(
+            (self.neighbour_index.shape[1], 3), device=i.device
+        )
         # TODO: parallelise this loop
         for batch, (start, end) in enumerate(pairs(self.ptr)):
             mask = (i >= start) & (i < end)
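
The fix above allocates the offsets tensor on the same device as the neighbour indices, rather than defaulting to the CPU. A minimal sketch of the failure mode this prevents (the names below are illustrative, not the graph-pes API):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# indices and positions live on the GPU, e.g. after batch.to("cuda")
i = torch.tensor([0, 1, 2], device=device)
positions = torch.rand(4, 3, device=device)

# before the fix: torch.zeros defaults to the CPU, so combining the
# result with GPU tensors raises "Expected all tensors to be on the
# same device" when running on CUDA
offsets_on_cpu = torch.zeros((i.shape[0], 3))

# after the fix: allocate on whichever device the indices are on
offsets = torch.zeros((i.shape[0], 3), device=i.device)
vectors = positions[i] + offsets  # works on CPU and GPU alike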
22 changes: 22 additions & 0 deletions src/graph_pes/loss.py
@@ -237,3 +237,25 @@ class MAE(torch.nn.L1Loss):
 
     def __init__(self):
         super().__init__()
+
+
+class MeanVectorPercentageError(torch.nn.Module):
+    r"""
+    Mean vector percentage error metric:
+
+    .. math::
+        \frac{1}{N} \sum_i^N \frac{\left\| \hat{v}_i - v_i \right\|}
+        {\left\| v_i \right\| + \varepsilon}
+    """
+
+    def __init__(self, epsilon: float = 1e-6):
+        super().__init__()
+        self.epsilon = epsilon
+
+    def forward(
+        self, input: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
+        return (
+            (input - target).norm(dim=-1)
+            / (target.norm(dim=-1) + self.epsilon)
+        ).mean()
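
For reference, a quick usage sketch of the new metric on dummy force vectors (shape (n_atoms, 3) assumed):

import torch

metric = MeanVectorPercentageError()

target = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
pred = torch.tensor([[0.9, 0.0, 0.0], [0.0, 2.2, 0.0]])

# per vector: ||pred - target|| / (||target|| + eps), then averaged:
# (0.1 / 1.0 + 0.2 / 2.0) / 2 = 0.1
print(metric(pred, target))  # tensor(0.1000)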
9 changes: 6 additions & 3 deletions src/graph_pes/training.py
@@ -13,7 +13,7 @@
 from .data import AtomicGraph
 from .data.batching import AtomicDataLoader, AtomicGraphBatch
 from .loss import RMSE, Loss, WeightedLoss
-from .transform import Chain, PerAtomScale, Scale
+from .transform import PerAtomScale, PerAtomStandardScaler, Scale
 from .util import Keys


@@ -226,7 +226,7 @@ def get_loss(
 ) -> WeightedLoss:
     if loss is None:
         default_transforms = {
-            Keys.ENERGY: Chain([PerAtomScale(), PerAtomScale()]),
+            Keys.ENERGY: PerAtomStandardScaler(),  # TODO is this right?
             Keys.FORCES: PerAtomScale(),
             Keys.STRESS: Scale(),
         }
@@ -271,7 +271,10 @@ def default_trainer_kwargs() -> dict:
 
 
 def device_info_filter(record):
-    return "PU available: " not in record.getMessage()
+    return (
+        "PU available: " not in record.getMessage()
+        and "LOCAL_RANK" not in record.getMessage()
+    )
 
 
 # disable verbose logging from pytorch lightning
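
The filter now also drops PyTorch Lightning's "LOCAL_RANK: ..." start-up messages alongside the "GPU available: ..." / "TPU available: ..." lines, both of which match the "PU available: " substring. For reference, such a callable filter is attached with the standard library's logging API (the logger name here is an assumption):

import logging

# logging.Filter instances and bare callables returning a bool are
# both accepted by addFilter since Python 3.2
logging.getLogger("pytorch_lightning").addFilter(device_info_filter)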
1 change: 0 additions & 1 deletion src/graph_pes/util.py
@@ -15,7 +15,6 @@
"""The maximum atomic number in the periodic table."""


T = TypeVar("T")


@overload
