From 51ea7def242ac308d3ca419ea60679bf61656096 Mon Sep 17 00:00:00 2001 From: Karlo Berket Date: Fri, 6 Sep 2024 08:56:51 -0700 Subject: [PATCH] add LRU cache to structure matcher (#4036) * add LRU cache to _get_reduced_structure computations * pre-commit auto-fixes * make structure hashable via as_dict * make structure hash recursive as_dict, change structure test to check hashability * precommit * moved computation using lru_cache out of class method to avoid memory leakage issue * pre-commit auto-fixes * fix structure matcher caching, fix a few tests (mcsqs wrong file destination and missing pytest approx in TestBSPlot) * precommit * add suggested SiteOrderedIStructure from @kbuma * pre-commit auto-fixes * add cast in eq for SiteOrderedIStructure to make mypy happy * pre-commit auto-fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: esoteric-ephemera --- src/pymatgen/analysis/structure_matcher.py | 46 +++++++++++++++++-- tests/electronic_structure/test_plotter.py | 2 +- .../test_advanced_transformations.py | 4 +- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/pymatgen/analysis/structure_matcher.py b/src/pymatgen/analysis/structure_matcher.py index ee1a73fcbae..39d2b5d030d 100644 --- a/src/pymatgen/analysis/structure_matcher.py +++ b/src/pymatgen/analysis/structure_matcher.py @@ -4,12 +4,13 @@ import abc import itertools -from typing import TYPE_CHECKING +from functools import lru_cache +from typing import TYPE_CHECKING, cast import numpy as np from monty.json import MSONable -from pymatgen.core import Composition, Lattice, Structure, get_el_sp +from pymatgen.core import SETTINGS, Composition, IStructure, Lattice, Structure, get_el_sp from pymatgen.optimization.linear_assignment import LinearAssignment from pymatgen.util.coord import lattice_points_in_supercell from pymatgen.util.coord_cython import is_coord_subset_pbc, pbc_shortest_vectors @@ -29,6 +30,31 @@ __email__ = "wrichard@mit.edu" __status__ = "Production" __date__ = "Dec 3, 2012" +LRU_CACHE_SIZE = SETTINGS.get("STRUCTURE_MATCHER_CACHE_SIZE", 300) + + +class SiteOrderedIStructure(IStructure): + """ + Imutable structure where the order of sites matters. + + In caching reduced structures (see `StructureMatcher._get_reduced_structure`) + the order of input sites can be important. + In general, the order of sites in a structure does not matter, but when + a method like `StructureMatcher.get_s2_like_s1` tries to put s2's sites in + the same order as s1, the site order matters. + """ + + def __eq__(self, other: object) -> bool: + """Check for IStructure equality and same site order.""" + if not super().__eq__(other): + return False + other = cast(SiteOrderedIStructure, other) # make mypy happy + + return list(self.sites) == list(other.sites) + + def __hash__(self) -> int: + """Use the composition hash for now.""" + return super().__hash__() class AbstractComparator(MSONable, abc.ABC): @@ -939,9 +965,12 @@ def _anonymous_match( break return matches - @classmethod - def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, niggli: bool = True) -> Structure: - """Helper method to find a reduced structure.""" + @staticmethod + @lru_cache(maxsize=LRU_CACHE_SIZE) + def _get_reduced_istructure( + struct: SiteOrderedIStructure, primitive_cell: bool = True, niggli: bool = True + ) -> SiteOrderedIStructure: + """Helper method to find a reduced imutable structure.""" reduced = struct.copy() if niggli: reduced = reduced.get_reduced_structure(reduction_algo="niggli") @@ -949,6 +978,13 @@ def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, reduced = reduced.get_primitive_structure() return reduced + @classmethod + def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, niggli: bool = True) -> Structure: + """Helper method to find a reduced structure.""" + return Structure.from_sites( + cls._get_reduced_istructure(SiteOrderedIStructure.from_sites(struct), primitive_cell, niggli) + ) + def get_rms_anonymous(self, struct1, struct2): """ Performs an anonymous fitting, which allows distinct species in one diff --git a/tests/electronic_structure/test_plotter.py b/tests/electronic_structure/test_plotter.py index 464efdbe63f..0214bb7e7df 100644 --- a/tests/electronic_structure/test_plotter.py +++ b/tests/electronic_structure/test_plotter.py @@ -153,7 +153,7 @@ def test_bs_plot_data(self): def test_get_ticks(self): assert self.plotter.get_ticks()["label"][5] == "K", "wrong tick label" - assert self.plotter.get_ticks()["distance"][5] == 2.406607625322699, "wrong tick distance" + assert self.plotter.get_ticks()["distance"][5] == pytest.approx(2.406607625322699), "wrong tick distance" # Minimal baseline testing for get_plot. not a true test. Just checks that # it can actually execute. diff --git a/tests/transformations/test_advanced_transformations.py b/tests/transformations/test_advanced_transformations.py index dd872c06f43..f1964cc4f2c 100644 --- a/tests/transformations/test_advanced_transformations.py +++ b/tests/transformations/test_advanced_transformations.py @@ -594,7 +594,7 @@ def test_apply_transformation(self): @pytest.mark.skipif(not mcsqs_cmd, reason="mcsqs not present.") class TestSQSTransformation(PymatgenTest): def test_apply_transformation(self): - pzt_structs = loadfn(f"{TEST_FILES_DIR}/mcsqs/pzt-structs.json") + pzt_structs = loadfn(f"{TEST_FILES_DIR}/io/atat/mcsqs/pzt-structs.json") trans = SQSTransformation(scaling=[2, 1, 1], search_time=0.01, instances=1, wd=0) # nonsensical example just for testing purposes struct = self.get_structure("Pb2TiZrO6").copy() @@ -605,7 +605,7 @@ def test_apply_transformation(self): def test_return_ranked_list(self): # list of structures - pzt_structs_2 = loadfn(f"{TEST_FILES_DIR}/mcsqs/pzt-structs-2.json") + pzt_structs_2 = loadfn(f"{TEST_FILES_DIR}/io/atat/mcsqs/pzt-structs-2.json") n_structs_expected = 1 sqs_kwargs = {"scaling": 2, "search_time": 0.01, "instances": 8, "wd": 0}