diff --git a/gflownet/proxy/crystals/pyxtal_gnn.py b/gflownet/proxy/crystals/pyxtal_gnn.py index c2869c3af..e035a2c92 100644 --- a/gflownet/proxy/crystals/pyxtal_gnn.py +++ b/gflownet/proxy/crystals/pyxtal_gnn.py @@ -1,7 +1,13 @@ from importlib.metadata import PackageNotFoundError, version +import pymatgen +import pymatgen.core import pyxtal import torch +from pymatgen.core.surface import ( + SlabGenerator, + get_symmetrically_distinct_miller_indices, +) from pyxtal.lattice import Lattice from pyxtal.msg import Comp_CompatibilityError from pyxtal.symmetry import Group @@ -10,10 +16,14 @@ from gflownet.proxy.base import Proxy from gflownet.utils.common import tfloat +# This is to revert a non-backwards-compatible change that was made since +# the moment the ocdata package was developed. Unfortunately, we can't use +# an older version of pymatgen from before this change because such an old +# version would not be compatible with PyXtal. +pymatgen.Composition = pymatgen.core.Composition + # URLs to the code repositories used by this proxy. 
DAVE_REPO_URL = "https://github.com/sh-divya/ActiveLearningMaterials.git" -OCP_REPO_URL = "https://github.com/RolnickLab/ocp" - def ensure_library_version(library_name, repo_url, release=None): """Ensure that a library is available for import with a given version @@ -33,13 +43,6 @@ def ensure_library_version(library_name, repo_url, release=None): raise PackageNotFoundError(f"Library `{library_name}` not found") - if release is not None and lib_version != release: - print(f" 💥 `{library_name}` version mismatch: ") - print(f" current ({lib_version}) != requested ({release})") - print(" Install the requested version with:") - print(f" $ pip install --upgrade git+{pip_url}\n") - raise ImportError(f"Wrong version for library `{library_name}`") - def is_valid_crystal( space_group, elements, num_ions, a, b, c, alpha, beta, gamma @@ -81,6 +84,9 @@ class PyxtalGNN(Proxy): This proxy assumes that the samples do not contain atom coordinates and therefore uses PyXtal to sample atom coordinates before feeding the samples to the GNN model. 
+ + Requirements : + - OCD github repo cloned locally accessible from the pythonpath """ ENERGY_INVALID_SAMPLE = 10 @@ -91,9 +97,6 @@ def __init__(self, ckpt_path=None, dave_release=None, n_pyxtal_samples=1, **kwar # Import the necessary util function from the DAVE repository ensure_library_version("dave", DAVE_REPO_URL, dave_release) - # Import that the OCP package is available for the surface creation functions - ensure_library_version("ocp-models", OCP_REPO_URL) - self.n_pyxtal_samples = n_pyxtal_samples @torch.no_grad() @@ -144,7 +147,7 @@ def evaluate_state(self, state) -> float: pyxtal_sample_scores = [] for sample in pyxtal_samples: # Convert the PyXtal crystal to the graph format expected by the model - sample_graph = self.graph_from_pyxtal(sample) + sample_graph = self.graph_from_pyxtal_crystal(sample) # Score the sample using the model sample_score = self.energy_from_graph(sample_graph) @@ -159,9 +162,43 @@ def evaluate_state(self, state) -> float: return global_sample_score - def graph_from_pyxtal(self, pyxtal_crystal): + def graph_from_pyxtal_crystal(self, pyxtal_crystal): # Obtain util function from the DAVE repository from dave.utils.atoms_to_graph import AtomsToGraphs, pymatgen_structure_to_graph + + # Obtain util classes from OCD repository + from ocdata.adsorbates import Adsorbate + from ocdata.bulk_obj import Bulk + #from ocdata.surfaces import Surface + #from ocdata.combined import Combined + + # Obtain list of possible adsorbates + pass # TODO + + # Obtain a list of all symmetrically distinct surfaces of the bulk structure + # created from the crystal + crystal_bulk = pyxtal_crystal.to_pymatgen() # How to do this step? 
+ crystal_slabs = self.enumerate_bulk_surfaces(crystal_bulk) + """ + crystal_slabs = [] + for miller_indices in get_symmetrically_distinct_miller_indices( + crystal_bulk, MAX_MILLER + ): + slab_gen = SlabGenerator( + initial_structure=bulk_struct, + miller_index=millers, + min_slab_size=7.0, + min_vacuum_size=20.0, + lll_reduce=False, + center_slab=True, + primitive=True, + max_normal_search=1, + ) + slabs = slab_gen.get_slabs( + tol=0.3, bonds=None, max_broken_bonds=0, symmetrize=False + ) + crystal_slabs.extend(slabs) + """ # Convert the PyXtal crystal to pymatgen and then to a graph format a2g = AtomsToGraphs( @@ -212,6 +249,32 @@ def generate_pyxtal_samples( return pyxtal_samples - def pyxtal2proxy(self, pyxtal_crystal): - # TODO - return None + def enumerate_bulk_surfaces(self, bulk_structure): + """ + Enumerate the surfaces of a user-provided bulk structure. + + This method wraps the method Bulk.enumerate_surfaces found at + https://github.com/RolnickLab/ocp/blob/sample-adslab/ocdata/bulk_obj.py#L173C38-L173C38. + That method implements the functionality that we want but it is + implemented in the context of the Bulk class which wasn't made to be used on a + user-provided bulk structure. + + Args: + bulk_structure: Bulk structure to process; only its to_ase_atoms() + method is called here. + + Returns: + The value returned by Bulk.enumerate_surfaces() for this structure. + """ + # Obtain util class from OCD repository + from ocdata.bulk_obj import Bulk + + # Instantiate a Bulk object without running Bulk.__init__: the standard + # __init__ requires providing a database of bulk structures. Bulk.__new__ is + # used instead of overriding Bulk.__init__ so that the Bulk class is left + # untouched for any other code that instantiates it normally. + bulk_instance = Bulk.__new__(Bulk) + bulk_instance.bulk_atoms = bulk_structure.to_ase_atoms() + bulk_instance.mpid = None + + return bulk_instance.enumerate_surfaces()