Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Jul 10, 2024
1 parent 774a225 commit c3157d4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/graph_pes/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ALL_MODELS: list[type[GraphPESModel]] = [
globals()[model]
for model in __all__
if model not in ["FixedOffset", "LearnableOffset"]
if model not in ["FixedOffset", "LearnableOffset", "load_model"]
]


Expand Down
2 changes: 2 additions & 0 deletions src/graph_pes/models/e3nn/nequip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings

import e3nn
import e3nn.util.jit
import torch
from e3nn import o3
from graph_pes.core import GraphPESModel
Expand Down Expand Up @@ -321,6 +322,7 @@ def __call__(
)


@e3nn.util.jit.compile_mode("script")
class NequIP(GraphPESModel):
def __init__(
self,
Expand Down
8 changes: 7 additions & 1 deletion src/graph_pes/models/pairwise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

import torch
from jaxtyping import Float
Expand Down Expand Up @@ -141,7 +142,12 @@ def sigma(self):

# don't use Z_i and Z_j, but include them for consistency with the
# abstract method
def interaction(self, r: torch.Tensor, Z_i=None, Z_j=None):
def interaction(
self,
r: torch.Tensor,
Z_i: Optional[torch.Tensor] = None,
Z_j: Optional[torch.Tensor] = None,
):
"""
Evaluate the pair potential.
Expand Down
20 changes: 11 additions & 9 deletions src/graph_pes/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,15 +471,17 @@ def register_Zs(self, Zs: list[int]):

def forward(self, Z: Tensor) -> Tensor:
internal_idx = self.Z_to_idx[Z]
try:
return torch.nn.functional.one_hot(internal_idx, self.n_elements)
except IndexError:
raise ValueError(
f"Unknown atomic number: {sorted(set(Z.tolist()))}. "
f"Expected {self.n_elements} elements. "
"Did you forget to call `register_Zs`?"
) from None

return torch.nn.functional.one_hot(internal_idx, self.n_elements)

Check warning on line 474 in src/graph_pes/nn.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/nn.py#L473-L474

Added lines #L473 - L474 were not covered by tests
# try:
# return torch.nn.functional.one_hot(internal_idx, self.n_elements)
# except IndexError:
# raise ValueError(
# f"Unknown atomic number: {sorted(set(Z.tolist()))}. "
# f"Expected {self.n_elements} elements. "
# "Did you forget to call `register_Zs`?"
# ) from None

@torch.jit.unused
@property
def registered_elements(self) -> list[str]:
return [chemical_symbols[Z] for Z in self.Z_to_idx if Z <= MAX_Z]

Check warning on line 487 in src/graph_pes/nn.py

View check run for this annotation

Codecov / codecov/patch

src/graph_pes/nn.py#L487

Added line #L487 was not covered by tests

0 comments on commit c3157d4

Please sign in to comment.