Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
carriepl-mila committed Nov 24, 2023
1 parent 69e97fc commit 9896a9d
Showing 1 changed file with 73 additions and 17 deletions.
90 changes: 73 additions & 17 deletions gflownet/proxy/crystals/pyxtal_gnn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from importlib.metadata import PackageNotFoundError, version

import pymatgen
import pymatgen.core
import pyxtal
import torch
from pymatgen.core.surface import (
SlabGenerator,
get_symmetrically_distinct_miller_indices,
)
from pyxtal.lattice import Lattice
from pyxtal.msg import Comp_CompatibilityError
from pyxtal.symmetry import Group
Expand All @@ -10,10 +16,14 @@
from gflownet.proxy.base import Proxy
from gflownet.utils.common import tfloat

# This is to revert a non-backwards-compatible change that was made since
# the moment the ocdata package was developped. Unfortunately, we can't use
# an older version of pymatgen from before this change because such an old
# version would not be compatible with PyXtal.
pymatgen.Composition = pymatgen.core.Composition

# URLs to the code repositories used by this proxy.
DAVE_REPO_URL = "https://github.com/sh-divya/ActiveLearningMaterials.git"
OCP_REPO_URL = "https://github.com/RolnickLab/ocp"


def ensure_library_version(library_name, repo_url, release=None):
"""Ensure that a library is available for import with a given version
Expand All @@ -33,13 +43,6 @@ def ensure_library_version(library_name, repo_url, release=None):

raise PackageNotFoundError(f"Library `{library_name}` not found")

if release is not None and lib_version != release:
print(f" 💥 `{library_name}` version mismatch: ")
print(f" current ({lib_version}) != requested ({release})")
print(" Install the requested version with:")
print(f" $ pip install --upgrade git+{pip_url}\n")
raise ImportError(f"Wrong version for library `{library_name}`")


def is_valid_crystal(
space_group, elements, num_ions, a, b, c, alpha, beta, gamma
Expand Down Expand Up @@ -81,6 +84,9 @@ class PyxtalGNN(Proxy):
This proxy assumes that the samples do not contain atom coordinates and
therefore uses PyXtal to sample atom coordinates before feeding the samples
to the GNN model.
Requirements :
- OCD github repo cloned locally accessible from the pythonpath
"""

ENERGY_INVALID_SAMPLE = 10
Expand All @@ -91,9 +97,6 @@ def __init__(self, ckpt_path=None, dave_release=None, n_pyxtal_samples=1, **kwar
# Import the necessary util function from the DAVE repository
ensure_library_version("dave", DAVE_REPO_URL, dave_release)

# Import that the OCP package is available for the surface creation functions
ensure_library_version("ocp-models", OCP_REPO_URL)

self.n_pyxtal_samples = n_pyxtal_samples

@torch.no_grad()
Expand Down Expand Up @@ -144,7 +147,7 @@ def evaluate_state(self, state) -> float:
pyxtal_sample_scores = []
for sample in pyxtal_samples:
# Convert the PyXtal crystal to the graph format expected by the model
sample_graph = self.graph_from_pyxtal(sample)
sample_graph = self.graph_from_pyxtal_crystal(sample)

# Score the sample using the model
sample_score = self.energy_from_graph(sample_graph)
Expand All @@ -159,9 +162,43 @@ def evaluate_state(self, state) -> float:

return global_sample_score

def graph_from_pyxtal(self, pyxtal_crystal):
def graph_from_pyxtal_crystal(self, pyxtal_crystal):
# Obtain util function from the DAVE repository
from dave.utils.atoms_to_graph import AtomsToGraphs, pymatgen_structure_to_graph

# Obtain util classes from OCD repository
from ocdata.adsorbates import Adsorbate
from ocdata.bulk_obj import Bulk
#from ocdata.surfaces import Surface
#from ocdata.combined import Combined

# Obtain list of possible adsorbates
pass # TODO

# Obtain a list of all symmetrically distinct surfaces of the bulk structure
# created from the crystal
crystal_bulk = pyxtal_crystal.to_pymatgen() # How to do this step?
crystal_slabs = self.enumerate_bulk_surfaces(crystal_bulk)
"""
crystal_slabs = []
for miller_indices in get_symmetrically_distinct_miller_indices(
crystal_bulk, MAX_MILLER
):
slab_gen = SlabGenerator(
initial_structure=bulk_struct,
miller_index=millers,
min_slab_size=7.0,
min_vacuum_size=20.0,
lll_reduce=False,
center_slab=True,
primitive=True,
max_normal_search=1,
)
slabs = slab_gen.get_slabs(
tol=0.3, bonds=None, max_broken_bonds=0, symmetrize=False
)
crystal_slabs.extend(slabs)
"""

# Convert the PyXtal crystal to pymatgen and then to a graph format
a2g = AtomsToGraphs(
Expand Down Expand Up @@ -212,6 +249,25 @@ def generate_pyxtal_samples(

return pyxtal_samples

def pyxtal2proxy(self, pyxtal_crystal):
# TODO
return None
def enumerate_bulk_surfaces(self, bulk_structure):
"""
This method wraps the method Bulk.enumerate_surfaces found at
https://github.com/RolnickLab/ocp/blob/sample-adslab/ocdata/bulk_obj.py#L173C38-L173C38.
This method implements the functionallity that we want but it is
implemented in the context of the Bulk class which wasn't made to be used on a
user-provided bulk structure.
"""
# Obtain util classes from OCD repository
from ocdata.adsorbates import Adsorbate
from ocdata.bulk_obj import Bulk

# Instantiate Bulk object.
# The standard __init__ requires providing a database of bulk structures so
# we override it to allow instantiation without having a database.
Bulk.__init__ = lambda x: None
bulk_instance = Bulk()
bulk_instance.bulk_atoms = bulk_structure.to_ase_atoms()
bulk_instance.mpid = None

return bulk_instance.enumerate_surfaces()

0 comments on commit 9896a9d

Please sign in to comment.