Skip to content

Commit

Permalink
Add analysis module and useful transform for dividing property by the…
Browse files Browse the repository at this point in the history
… number of atoms
  • Loading branch information
jla-gardner committed Jan 16, 2024
1 parent 8e8d52b commit dea4d52
Show file tree
Hide file tree
Showing 10 changed files with 5,298 additions and 382 deletions.
6 changes: 6 additions & 0 deletions docs/source/analysis.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
########
Analysis
########


.. autofunction:: graph_pes.analysis.parity_plot
4 changes: 4 additions & 0 deletions docs/source/data/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ Available Transforms
:members:
Useful Other Transforms
=======================

.. autoclass :: graph_pes.transform.DividePerAtom()
.. autoclass :: graph_pes.transform.Chain
.. autoclass :: graph_pes.transform.Identity()
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
models
data
training
analysis
examples


########
GraphPES
Expand Down
763 changes: 763 additions & 0 deletions docs/source/notebooks/Cu-LJ-default-parity.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,066 changes: 1,066 additions & 0 deletions docs/source/notebooks/Cu-LJ-parity.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3,590 changes: 3,273 additions & 317 deletions docs/source/notebooks/example.ipynb

Large diffs are not rendered by default.

214 changes: 150 additions & 64 deletions src/graph_pes/analysis.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,155 @@
from __future__ import annotations

from typing import Any
from functools import wraps

import matplotlib.pyplot as plt
import torch
from graph_pes.core import GraphPESModel, energy_and_forces
from graph_pes.data import AtomicGraph, AtomicGraphBatch
from graph_pes.transform import (
Chain,
PerSpeciesScale,
PerSpeciesShift,
Transform,
)


def parity_plots(
from matplotlib.ticker import MaxNLocator

from .core import GraphPESModel, get_predictions
from .data.atomic_graph import AtomicGraph
from .data.batching import AtomicGraphBatch
from .transform import Identity, Transform
from .util import Keys


def my_style(func):
"""
Decorator to use my home-made plt style within a function
"""

style = {
"figure.figsize": (3, 3),
"axes.spines.right": False,
"axes.spines.top": False,
}

@wraps(func)
def wrapper(*args, **kwargs):
with plt.rc_context(style):
return func(*args, **kwargs)

return wrapper


def move_axes(ax):
"""
Move the axes to the center of the figure
"""
ax.spines["left"].set_position(("outward", 10))
ax.spines["bottom"].set_position(("outward", 10))


@my_style
def parity_plot(
model: GraphPESModel,
graphs: list[AtomicGraph],
E_transform: Transform | None = None,
F_transform: Transform | None = None,
axs: tuple[plt.Axes, plt.Axes] | None = None,
E_kwargs: dict[str, Any] | None = None,
F_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> tuple[plt.Axes, plt.Axes]:
if E_transform is None:
E_transform = Chain([PerSpeciesScale(), PerSpeciesShift()])
if F_transform is None:
F_transform = PerSpeciesScale()

batch = AtomicGraphBatch.from_graphs(graphs)

true_E, true_F = batch.labels["energy"], batch.labels["forces"]

E_transform.fit_to_target(true_E, batch)
F_transform.fit_to_target(true_F, batch)

preds = energy_and_forces(model, batch)
pred_E, pred_F = preds["energy"].detach(), preds["forces"].detach()

with torch.no_grad():
scaled_pred_E = E_transform.inverse(pred_E, batch)
scaled_true_E = E_transform.inverse(true_E, batch)
scaled_pred_F = F_transform.inverse(pred_F, batch)
scaled_true_F = F_transform.inverse(true_F, batch)

if axs is None:
_, axs = plt.subplots(1, 2, figsize=(6, 3)) # type: ignore

E_ax, F_ax = axs # type: ignore

E_defaults = dict(marker="+")
E_kwargs = {**E_defaults, **kwargs, **(E_kwargs or {})}
E_ax.scatter(scaled_true_E, scaled_pred_E, **E_kwargs)
E_ax.axline((0, 0), slope=1, color="k", ls="--", lw=1)
E_ax.set_aspect("equal", "datalim")
E_ax.set_xlabel(r"$E$ (a.u.)")
E_ax.set_ylabel(r"$\tilde{E}$ (a.u.)")

F_defaults = dict(lw=0, s=3, alpha=0.2)
F_kwargs = {**F_defaults, **kwargs, **(F_kwargs or {})}
F_ax.scatter(scaled_true_F, scaled_pred_F, **F_kwargs)
F_ax.axline((0, 0), slope=1, color="k", ls="--", lw=1)
F_ax.set_aspect("equal", "datalim")
F_ax.set_xlabel(r"$F$ (a.u.)")
F_ax.set_ylabel(r"$\tilde{F}$ (a.u.)")

return axs # type: ignore
graphs: AtomicGraphBatch | list[AtomicGraph],
property: Keys,
property_label: str | None = None,
transform: Transform | None = None,
units: str | None = None,
ax: plt.Axes | None = None, # type: ignore
**scatter_kwargs,
):
r"""
A nicely formatted parity plot of model predictions vs ground truth
for the given :code:`property`.
Parameters
----------
model
The model to for generating predictions.
graphs
The graphs to make predictions on.
property
The property to plot, e.g. :code:`Keys.ENERGY`.
property_label
The string that the property is indexed by on the graphs. If not
provided, defaults to the value of :code:`property`, e.g.
:code:`Keys.ENERGY` :math:`\rightarrow` :code:`"energy"`.
transform
The transform to apply to the predictions and labels before plotting.
If not provided, no transform is applied.
units
The units of the property, for labelling the axes. If not provided, no
units are used.
ax
The axes to plot on. If not provided, the current axes are used.
scatter_kwargs
Keyword arguments to pass to :code:`plt.scatter`.
Examples
--------
Default settings (no units, transforms or custom scatter keywords):
.. code-block:: python
parity_plot(model, train, Keys.ENERGY)
.. image:: notebooks/Cu-LJ-default-parity.svg
:align: center
Custom settings, as seen in
:doc:`this example notebook <notebooks/example>`:
.. code-block:: python
from graph_pes.transform import DividePerAtom
from graph_pes.util import Keys
parity_plot(
model,
train,
Keys.ENERGY,
transform=DividePerAtom(),
units="eV/atom",
c="royalblue",
label="Train",
)
...
.. image:: notebooks/Cu-LJ-parity.svg
:align: center
"""
# deal with defaults
transform = transform or Identity()
if property_label is None:
property_label = property.value

# get the predictions and labels
if isinstance(graphs, list):
graphs = AtomicGraphBatch.from_graphs(graphs)

ground_truth = transform(graphs[property_label], graphs).detach()
predictions = transform(
get_predictions(model, graphs, {property: property_label})[
property_label
],
graphs,
).detach()

# plot
ax: plt.Axes = ax or plt.gca()

default_kwargs = dict(lw=0, clip_on=False)
scatter_kwargs = {**default_kwargs, **scatter_kwargs}
ax.scatter(ground_truth, predictions, **scatter_kwargs)

# get a point guaranteed to be on the plot
z = ground_truth.view(-1)[0].item()
ax.axline((z, z), slope=1, c="k", ls="--", lw=1)

# aesthetics
axis_label = f"{property_label} ({units})" if units else property_label
ax.set_xlabel(f"True {axis_label}")
ax.set_ylabel(f"Predicted {axis_label}")
ax.set_aspect("equal", "datalim")
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
ax.set_xlim(min(x0, y0), max(x1, y1))
ax.set_ylim(min(x0, y0), max(x1, y1))
move_axes(ax)

# 5 ticks each
ax.xaxis.set_major_locator(MaxNLocator(5))
ax.yaxis.set_major_locator(MaxNLocator(5))
5 changes: 4 additions & 1 deletion src/graph_pes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __repr__(self):
# when in eval mode
def get_predictions(
pes: GraphPESModel,
structure: AtomicGraph,
structure: AtomicGraph | AtomicGraphBatch | list[AtomicGraph],
property_labels: dict[Keys, str] | None = None,
) -> dict[str, torch.Tensor]:
"""
Expand All @@ -216,6 +216,9 @@ def get_predictions(
"""

if isinstance(structure, list):
structure = AtomicGraphBatch.from_graphs(structure)

if property_labels is None:
property_labels = {
Keys.ENERGY: "energy",
Expand Down
1 change: 1 addition & 0 deletions src/graph_pes/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Loss(nn.Module):
and that the resulting gradients and parameter updates are well-behaved.
:class:`Loss`'s in `graph-pes` are thus lightweight wrappers around:
* an (optional) pre-transform, :math:`T`
* a loss metric, :math:`M`.
Expand Down
29 changes: 29 additions & 0 deletions src/graph_pes/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,32 @@ def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
def fit(self, x: Tensor, graphs: AtomicGraphBatch) -> Transform:
self.scale.data = x.var()
return self


class DividePerAtom(Transform):
"""
A convenience transform for dividing a property by the number of atoms
in the structure.
"""

def __init__(self):
super().__init__(trainable=False)

def fit(self, x: Tensor, graphs: AtomicGraphBatch) -> Transform:
return self

def forward(self, x: Tensor, graph: AtomicGraph) -> Tensor:
structure_sizes = (
graph.structure_sizes
if isinstance(graph, AtomicGraphBatch)
else graph.n_atoms
)
return x / structure_sizes

def inverse(self, x: Tensor, graph: AtomicGraph) -> Tensor:
structure_sizes = (
graph.structure_sizes
if isinstance(graph, AtomicGraphBatch)
else graph.n_atoms
)
return x * structure_sizes

0 comments on commit dea4d52

Please sign in to comment.