-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add analysis module and useful transform for dividing property by the…
… number of atoms
- Loading branch information
1 parent
8e8d52b
commit dea4d52
Showing
10 changed files
with
5,298 additions
and
382 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
######## | ||
Analysis | ||
######## | ||
|
||
|
||
.. autofunction:: graph_pes.analysis.parity_plot |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,9 @@ | |
models | ||
data | ||
training | ||
analysis | ||
examples | ||
|
||
|
||
######## | ||
GraphPES | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,155 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
from functools import wraps | ||
|
||
import matplotlib.pyplot as plt | ||
import torch | ||
from graph_pes.core import GraphPESModel, energy_and_forces | ||
from graph_pes.data import AtomicGraph, AtomicGraphBatch | ||
from graph_pes.transform import ( | ||
Chain, | ||
PerSpeciesScale, | ||
PerSpeciesShift, | ||
Transform, | ||
) | ||
|
||
|
||
def parity_plots( | ||
from matplotlib.ticker import MaxNLocator | ||
|
||
from .core import GraphPESModel, get_predictions | ||
from .data.atomic_graph import AtomicGraph | ||
from .data.batching import AtomicGraphBatch | ||
from .transform import Identity, Transform | ||
from .util import Keys | ||
|
||
|
||
def my_style(func): | ||
""" | ||
Decorator to use my home-made plt style within a function | ||
""" | ||
|
||
style = { | ||
"figure.figsize": (3, 3), | ||
"axes.spines.right": False, | ||
"axes.spines.top": False, | ||
} | ||
|
||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
with plt.rc_context(style): | ||
return func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
||
|
||
def move_axes(ax): | ||
""" | ||
Move the axes to the center of the figure | ||
""" | ||
ax.spines["left"].set_position(("outward", 10)) | ||
ax.spines["bottom"].set_position(("outward", 10)) | ||
|
||
|
||
@my_style | ||
def parity_plot( | ||
model: GraphPESModel, | ||
graphs: list[AtomicGraph], | ||
E_transform: Transform | None = None, | ||
F_transform: Transform | None = None, | ||
axs: tuple[plt.Axes, plt.Axes] | None = None, | ||
E_kwargs: dict[str, Any] | None = None, | ||
F_kwargs: dict[str, Any] | None = None, | ||
**kwargs, | ||
) -> tuple[plt.Axes, plt.Axes]: | ||
if E_transform is None: | ||
E_transform = Chain([PerSpeciesScale(), PerSpeciesShift()]) | ||
if F_transform is None: | ||
F_transform = PerSpeciesScale() | ||
|
||
batch = AtomicGraphBatch.from_graphs(graphs) | ||
|
||
true_E, true_F = batch.labels["energy"], batch.labels["forces"] | ||
|
||
E_transform.fit_to_target(true_E, batch) | ||
F_transform.fit_to_target(true_F, batch) | ||
|
||
preds = energy_and_forces(model, batch) | ||
pred_E, pred_F = preds["energy"].detach(), preds["forces"].detach() | ||
|
||
with torch.no_grad(): | ||
scaled_pred_E = E_transform.inverse(pred_E, batch) | ||
scaled_true_E = E_transform.inverse(true_E, batch) | ||
scaled_pred_F = F_transform.inverse(pred_F, batch) | ||
scaled_true_F = F_transform.inverse(true_F, batch) | ||
|
||
if axs is None: | ||
_, axs = plt.subplots(1, 2, figsize=(6, 3)) # type: ignore | ||
|
||
E_ax, F_ax = axs # type: ignore | ||
|
||
E_defaults = dict(marker="+") | ||
E_kwargs = {**E_defaults, **kwargs, **(E_kwargs or {})} | ||
E_ax.scatter(scaled_true_E, scaled_pred_E, **E_kwargs) | ||
E_ax.axline((0, 0), slope=1, color="k", ls="--", lw=1) | ||
E_ax.set_aspect("equal", "datalim") | ||
E_ax.set_xlabel(r"$E$ (a.u.)") | ||
E_ax.set_ylabel(r"$\tilde{E}$ (a.u.)") | ||
|
||
F_defaults = dict(lw=0, s=3, alpha=0.2) | ||
F_kwargs = {**F_defaults, **kwargs, **(F_kwargs or {})} | ||
F_ax.scatter(scaled_true_F, scaled_pred_F, **F_kwargs) | ||
F_ax.axline((0, 0), slope=1, color="k", ls="--", lw=1) | ||
F_ax.set_aspect("equal", "datalim") | ||
F_ax.set_xlabel(r"$F$ (a.u.)") | ||
F_ax.set_ylabel(r"$\tilde{F}$ (a.u.)") | ||
|
||
return axs # type: ignore | ||
graphs: AtomicGraphBatch | list[AtomicGraph], | ||
property: Keys, | ||
property_label: str | None = None, | ||
transform: Transform | None = None, | ||
units: str | None = None, | ||
ax: plt.Axes | None = None, # type: ignore | ||
**scatter_kwargs, | ||
): | ||
r""" | ||
A nicely formatted parity plot of model predictions vs ground truth | ||
for the given :code:`property`. | ||
Parameters | ||
---------- | ||
model | ||
The model to for generating predictions. | ||
graphs | ||
The graphs to make predictions on. | ||
property | ||
The property to plot, e.g. :code:`Keys.ENERGY`. | ||
property_label | ||
The string that the property is indexed by on the graphs. If not | ||
provided, defaults to the value of :code:`property`, e.g. | ||
:code:`Keys.ENERGY` :math:`\rightarrow` :code:`"energy"`. | ||
transform | ||
The transform to apply to the predictions and labels before plotting. | ||
If not provided, no transform is applied. | ||
units | ||
The units of the property, for labelling the axes. If not provided, no | ||
units are used. | ||
ax | ||
The axes to plot on. If not provided, the current axes are used. | ||
scatter_kwargs | ||
Keyword arguments to pass to :code:`plt.scatter`. | ||
Examples | ||
-------- | ||
Default settings (no units, transforms or custom scatter keywords): | ||
.. code-block:: python | ||
parity_plot(model, train, Keys.ENERGY) | ||
.. image:: notebooks/Cu-LJ-default-parity.svg | ||
:align: center | ||
Custom settings, as seen in | ||
:doc:`this example notebook <notebooks/example>`: | ||
.. code-block:: python | ||
from graph_pes.transform import DividePerAtom | ||
from graph_pes.util import Keys | ||
parity_plot( | ||
model, | ||
train, | ||
Keys.ENERGY, | ||
transform=DividePerAtom(), | ||
units="eV/atom", | ||
c="royalblue", | ||
label="Train", | ||
) | ||
... | ||
.. image:: notebooks/Cu-LJ-parity.svg | ||
:align: center | ||
""" | ||
# deal with defaults | ||
transform = transform or Identity() | ||
if property_label is None: | ||
property_label = property.value | ||
|
||
# get the predictions and labels | ||
if isinstance(graphs, list): | ||
graphs = AtomicGraphBatch.from_graphs(graphs) | ||
|
||
ground_truth = transform(graphs[property_label], graphs).detach() | ||
predictions = transform( | ||
get_predictions(model, graphs, {property: property_label})[ | ||
property_label | ||
], | ||
graphs, | ||
).detach() | ||
|
||
# plot | ||
ax: plt.Axes = ax or plt.gca() | ||
|
||
default_kwargs = dict(lw=0, clip_on=False) | ||
scatter_kwargs = {**default_kwargs, **scatter_kwargs} | ||
ax.scatter(ground_truth, predictions, **scatter_kwargs) | ||
|
||
# get a point guaranteed to be on the plot | ||
z = ground_truth.view(-1)[0].item() | ||
ax.axline((z, z), slope=1, c="k", ls="--", lw=1) | ||
|
||
# aesthetics | ||
axis_label = f"{property_label} ({units})" if units else property_label | ||
ax.set_xlabel(f"True {axis_label}") | ||
ax.set_ylabel(f"Predicted {axis_label}") | ||
ax.set_aspect("equal", "datalim") | ||
x0, x1 = ax.get_xlim() | ||
y0, y1 = ax.get_ylim() | ||
ax.set_xlim(min(x0, y0), max(x1, y1)) | ||
ax.set_ylim(min(x0, y0), max(x1, y1)) | ||
move_axes(ax) | ||
|
||
# 5 ticks each | ||
ax.xaxis.set_major_locator(MaxNLocator(5)) | ||
ax.yaxis.set_major_locator(MaxNLocator(5)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters