Skip to content

Commit

Permalink
Keys -> Property
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Feb 7, 2024
1 parent fdfe5e5 commit a21d9ea
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 56 deletions.
12 changes: 6 additions & 6 deletions src/graph_pes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .data.atomic_graph import AtomicGraph, convert_to_atomic_graphs
from .data.batching import AtomicGraphBatch
from .transform import Identity, Transform
from .util import Keys
from .util import Property

_my_style = {
"figure.figsize": (3.5, 3),
Expand Down Expand Up @@ -75,7 +75,7 @@ def move_axes(ax: plt.Axes | None = None): # type: ignore
def parity_plot(
model: GraphPESModel,
graphs: AtomicGraphBatch | list[AtomicGraph],
property: Keys,
property: Property,
property_label: str | None = None,
transform: Transform | None = None,
units: str | None = None,
Expand All @@ -93,11 +93,11 @@ def parity_plot(
graphs
The graphs to make predictions on.
property
The property to plot, e.g. :code:`Keys.ENERGY`.
The property to plot, e.g. :code:`Property.ENERGY`.
property_label
The string that the property is indexed by on the graphs. If not
provided, defaults to the value of :code:`property`, e.g.
:code:`Keys.ENERGY` :math:`\rightarrow` :code:`"energy"`.
:code:`Property.ENERGY` :math:`\rightarrow` :code:`"energy"`.
transform
The transform to apply to the predictions and labels before plotting.
If not provided, no transform is applied.
Expand All @@ -115,7 +115,7 @@ def parity_plot(
.. code-block:: python
parity_plot(model, train, Keys.ENERGY)
parity_plot(model, train, Property.ENERGY)
.. image:: notebooks/Cu-LJ-default-parity.svg
:align: center
Expand All @@ -136,7 +136,7 @@ def parity_plot(
parity_plot(
model,
data,
Keys.ENERGY,
Property.ENERGY,
transform=DividePerAtom(),
units="eV / atom",
label=name,
Expand Down
26 changes: 13 additions & 13 deletions src/graph_pes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
PerAtomShift,
Transform,
)
from graph_pes.util import Keys, differentiate, require_grad
from graph_pes.util import Property, differentiate, require_grad
from jaxtyping import Float
from torch import Tensor, nn

Expand Down Expand Up @@ -164,7 +164,7 @@ def __repr__(self):
def get_predictions(
pes: GraphPESModel,
structure: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
property_labels: dict[Keys, str] | None = None,
property_labels: dict[Property, str] | None = None,
) -> dict[str, torch.Tensor]:
"""
Evaluate the `pes` on `structure` to get the labels requested.
Expand Down Expand Up @@ -195,20 +195,20 @@ def get_predictions(

if property_labels is None:
property_labels = {
Keys.ENERGY: "energy",
Keys.FORCES: "forces",
Property.ENERGY: "energy",
Property.FORCES: "forces",
}
if structure.has_cell:
property_labels[Keys.STRESS] = "stress"
property_labels[Property.STRESS] = "stress"

else:
if Keys.STRESS in property_labels and not structure.has_cell:
if Property.STRESS in property_labels and not structure.has_cell:
raise ValueError("Can't predict stress without cell information.")

predictions = {}

# setup for calculating stress:
if Keys.STRESS in property_labels:
if Property.STRESS in property_labels:
# The virial stress tensor is the gradient of the total energy wrt
# an infinitesimal change in the cell parameters.
# We therefore add this change to the cell, such that
Expand All @@ -229,15 +229,15 @@ def get_predictions(
with require_grad(structure._positions), require_grad(change_to_cell):
energy = pes(structure)

if Keys.ENERGY in property_labels:
predictions[property_labels[Keys.ENERGY]] = energy
if Property.ENERGY in property_labels:
predictions[property_labels[Property.ENERGY]] = energy

if Keys.FORCES in property_labels:
if Property.FORCES in property_labels:
dE_dR = differentiate(energy, structure._positions)
predictions[property_labels[Keys.FORCES]] = -dE_dR
predictions[property_labels[Property.FORCES]] = -dE_dR

if Keys.STRESS in property_labels:
if Property.STRESS in property_labels:
stress = differentiate(energy, change_to_cell)
predictions[property_labels[Keys.STRESS]] = stress
predictions[property_labels[Property.STRESS]] = stress

return predictions
38 changes: 19 additions & 19 deletions src/graph_pes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .data.batching import AtomicDataLoader, AtomicGraphBatch
from .loss import RMSE, Loss, WeightedLoss
from .transform import PerAtomScale, PerAtomStandardScaler, Scale
from .util import Keys
from .util import Property


def train_model(
Expand All @@ -24,7 +24,7 @@ def train_model(
optimizer: Callable[[], torch.optim.Optimizer | OptimizerLRSchedulerConfig]
| None = None,
loss: WeightedLoss | Loss | None = None,
property_labels: dict[Keys, str] | None = None,
property_labels: dict[Property, str] | None = None,
*,
batch_size: int = 32,
pre_fit_model: bool = True,
Expand All @@ -46,7 +46,7 @@ def train_model(
if property_labels is None:
property_labels = get_existing_keys(batch)
if not property_labels:
expected = [key.value for key in Keys.__members__.values()]
expected = [key.value for key in Property.__members__.values()]
raise ValueError(
"No property_keys were provided, and none were found in "
f"the data. Expected at least one of: {expected}"
Expand All @@ -62,17 +62,17 @@ def train_model(
)

expected_shapes = {
Keys.ENERGY: (batch.n_structures,),
Keys.FORCES: (batch.n_atoms, 3),
Keys.STRESS: (batch.n_structures, 3, 3),
Property.ENERGY: (batch.n_structures,),
Property.FORCES: (batch.n_atoms, 3),
Property.STRESS: (batch.n_structures, 3, 3),
}
for key, label in property_labels.items():
if batch[label].shape != expected_shapes[key]:
raise ValueError(
f"Expected {label} to have shape {expected_shapes[key]}, "
f"but found {batch[label].shape}"
)
if Keys.STRESS in property_labels and not batch.has_cell:
if Property.STRESS in property_labels and not batch.has_cell:
raise ValueError("Can't train on stress without cell information.")

# create the data loaders
Expand All @@ -85,8 +85,8 @@ def train_model(

# deal with fitting transforms
# TODO: what if not training on energy?
if pre_fit_model and Keys.ENERGY in property_labels:
model.pre_fit(batch, property_labels[Keys.ENERGY])
if pre_fit_model and Property.ENERGY in property_labels:
model.pre_fit(batch, property_labels[Property.ENERGY])

actual_loss = get_loss(loss, property_labels)
actual_loss.fit_transform(batch)
Expand Down Expand Up @@ -120,10 +120,10 @@ def train_model(
return task.load_best_weights(model, trainer)


def get_existing_keys(batch: AtomicGraphBatch) -> dict[Keys, str]:
def get_existing_keys(batch: AtomicGraphBatch) -> dict[Property, str]:
return {
key: key.value
for key in Keys.__members__.values()
for key in Property.__members__.values()
if key.value in batch
}

Expand All @@ -134,7 +134,7 @@ def __init__(
model: GraphPESModel,
optimizer: torch.optim.Optimizer | OptimizerLRSchedulerConfig,
loss: WeightedLoss,
property_labels: dict[Keys, str],
property_labels: dict[Property, str],
):
super().__init__()
self.model = model
Expand Down Expand Up @@ -222,18 +222,18 @@ def load_best_weights(


def get_loss(
loss: WeightedLoss | Loss | None, property_labels: dict[Keys, str]
loss: WeightedLoss | Loss | None, property_labels: dict[Property, str]
) -> WeightedLoss:
if loss is None:
default_transforms = {
Keys.ENERGY: PerAtomStandardScaler(), # TODO is this right?
Keys.FORCES: PerAtomScale(),
Keys.STRESS: Scale(),
Property.ENERGY: PerAtomStandardScaler(), # TODO is this right?
Property.FORCES: PerAtomScale(),
Property.STRESS: Scale(),
}
default_weights = {
Keys.ENERGY: 1.0,
Keys.FORCES: 1.0,
Keys.STRESS: 1.0,
Property.ENERGY: 1.0,
Property.FORCES: 1.0,
Property.STRESS: 1.0,
}
return WeightedLoss(
[
Expand Down
4 changes: 1 addition & 3 deletions src/graph_pes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
"""The maximum atomic number in the periodic table."""




@overload
def pairs(a: Sequence[T]) -> Iterator[tuple[T, T]]:
...
Expand Down Expand Up @@ -122,7 +120,7 @@ def as_possible_tensor(value: object) -> Tensor | None:
return None


class Keys(Enum):
class Property(Enum):
ENERGY = "energy"
FORCES = "forces"
STRESS = "stress"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from graph_pes.data.batching import AtomicGraphBatch
from graph_pes.models.pairwise import LennardJones
from graph_pes.training import Loss, train_model
from graph_pes.util import Keys
from graph_pes.util import Property


def test_integration():
Expand All @@ -15,7 +15,7 @@ def test_integration():
model = LennardJones()

loss = Loss("energy")
property_labels = {Keys.ENERGY: "energy"}
property_labels = {Property.ENERGY: "energy"}
before = loss(
get_predictions(model, _batch, property_labels),
_batch,
Expand Down
28 changes: 15 additions & 13 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from graph_pes.core import get_predictions
from graph_pes.data import AtomicGraphBatch, convert_to_atomic_graph
from graph_pes.models.pairwise import LennardJones
from graph_pes.util import Keys
from graph_pes.util import Property

no_pbc = convert_to_atomic_graph(
Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=False),
Expand All @@ -18,9 +18,9 @@

def test_predictions():
expected_shapes = {
Keys.ENERGY: (),
Keys.FORCES: (2, 3),
Keys.STRESS: (3, 3),
Property.ENERGY: (),
Property.FORCES: (2, 3),
Property.STRESS: (3, 3),
}

model = LennardJones()
Expand All @@ -30,39 +30,41 @@ def test_predictions():
predictions = get_predictions(model, no_pbc)
assert set(predictions.keys()) == {"energy", "forces"}

for key in Keys.ENERGY, Keys.FORCES:
for key in Property.ENERGY, Property.FORCES:
assert predictions[key.value].shape == expected_shapes[key]

# if we ask for stress, we get an error:
with pytest.raises(ValueError):
get_predictions(model, no_pbc, {Keys.STRESS: "stress"})
get_predictions(model, no_pbc, {Property.STRESS: "stress"})

# with pbc structures, we should get all three predictions
predictions = get_predictions(model, pbc)
assert set(predictions.keys()) == {"energy", "forces", "stress"}

for key in Keys.ENERGY, Keys.FORCES, Keys.STRESS:
for key in Property.ENERGY, Property.FORCES, Property.STRESS:
assert predictions[key.value].shape == expected_shapes[key]

# check that requesting a subset of predictions works, and that
# the names are correctly mapped:
predictions = get_predictions(model, no_pbc, {Keys.ENERGY: "total_energy"})
predictions = get_predictions(
model, no_pbc, {Property.ENERGY: "total_energy"}
)
assert set(predictions.keys()) == {"total_energy"}
assert predictions["total_energy"].shape == expected_shapes[Keys.ENERGY]
assert predictions["total_energy"].shape == expected_shapes[Property.ENERGY]


def test_batched_prediction():
batch = AtomicGraphBatch.from_graphs([pbc, pbc])

expected_shapes = {
Keys.ENERGY: (2,), # two structures
Keys.FORCES: (4, 3), # four atoms
Keys.STRESS: (2, 3, 3), # two structures
Property.ENERGY: (2,), # two structures
Property.FORCES: (4, 3), # four atoms
Property.STRESS: (2, 3, 3), # two structures
}

predictions = get_predictions(LennardJones(), batch)

for key in Keys.ENERGY, Keys.FORCES, Keys.STRESS:
for key in Property.ENERGY, Property.FORCES, Property.STRESS:
assert predictions[key.value].shape == expected_shapes[key]


Expand Down

0 comments on commit a21d9ea

Please sign in to comment.