From 03da8cb2b1c7788c7d6c488a890d28fc63261e50 Mon Sep 17 00:00:00 2001 From: John Gardner Date: Tue, 6 Feb 2024 13:12:10 +0000 Subject: [PATCH] gpu compatibility --- src/graph_pes/data/batching.py | 4 +++- src/graph_pes/loss.py | 22 ++++++++++++++++++++++ src/graph_pes/training.py | 9 ++++++--- src/graph_pes/util.py | 1 - 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/graph_pes/data/batching.py b/src/graph_pes/data/batching.py index 6c7bbc39..b148429f 100644 --- a/src/graph_pes/data/batching.py +++ b/src/graph_pes/data/batching.py @@ -156,7 +156,9 @@ def neighbour_vectors(self) -> torch.Tensor: return self._positions[j] - self._positions[i] # otherwise calculate offsets on a per-structure basis - actual_offsets = torch.zeros((self.neighbour_index.shape[1], 3)) + actual_offsets = torch.zeros( + (self.neighbour_index.shape[1], 3), device=i.device + ) # TODO: parallelise this loop for batch, (start, end) in enumerate(pairs(self.ptr)): mask = (i >= start) & (i < end) diff --git a/src/graph_pes/loss.py b/src/graph_pes/loss.py index 28781a67..d5a17702 100644 --- a/src/graph_pes/loss.py +++ b/src/graph_pes/loss.py @@ -237,3 +237,25 @@ class MAE(torch.nn.L1Loss): def __init__(self): super().__init__() + + +class MeanVectorPercentageError(torch.nn.Module): + r""" + Mean vector percentage error metric: + + .. math:: + \frac{1}{N} \sum_i^N \frac{\left{||} \hat{v}_i - v_i \right{||}} + {||v_i|| + \varepsilon} + """ + + def __init__(self, epsilon: float = 1e-6): + super().__init__() + self.epsilon = epsilon + + def forward( + self, input: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: + return ( + (input - target).norm(dim=-1) + / (target.norm(dim=-1) + self.epsilon) + ).mean() diff --git a/src/graph_pes/training.py b/src/graph_pes/training.py index 0d87d534..2058197e 100644 --- a/src/graph_pes/training.py +++ b/src/graph_pes/training.py @@ -13,7 +13,7 @@ from .data import AtomicGraph from .data.batching import AtomicDataLoader, AtomicGraphBatch from .loss import RMSE, Loss, WeightedLoss -from .transform import Chain, PerAtomScale, Scale +from .transform import PerAtomScale, PerAtomStandardScaler, Scale from .util import Keys @@ -226,7 +226,7 @@ def get_loss( ) -> WeightedLoss: if loss is None: default_transforms = { - Keys.ENERGY: Chain([PerAtomScale(), PerAtomScale()]), + Keys.ENERGY: PerAtomStandardScaler(), # TODO is this right? Keys.FORCES: PerAtomScale(), Keys.STRESS: Scale(), } @@ -271,7 +271,10 @@ def default_trainer_kwargs() -> dict: def device_info_filter(record): - return "PU available: " not in record.getMessage() + return ( + "PU available: " not in record.getMessage() + and "LOCAL_RANK" not in record.getMessage() + ) # disable verbose logging from pytorch lightning diff --git a/src/graph_pes/util.py b/src/graph_pes/util.py index f0738033..48374b46 100644 --- a/src/graph_pes/util.py +++ b/src/graph_pes/util.py @@ -15,7 +15,6 @@ """The maximum atomic number in the periodic table.""" -T = TypeVar("T") @overload