Skip to content

Commit

Permalink
Added support for numpy>2.0, and spglib>2.0, fixed warnings with docs…
Browse files Browse the repository at this point in the history
…tring (#32)

* Added support for numpy>2.0, and spglib>2.0, fixed warnings with docstring.

* Remove unused imports.
  • Loading branch information
lauri-codes authored Feb 21, 2025
1 parent 08e7521 commit 642d3ca
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 35 deletions.
6 changes: 3 additions & 3 deletions docs/assets/conventional_cell_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def get_spglib_conventional(system):
dataset = spglib.get_symmetry_dataset(cell)

return Atoms(
symbols=dataset["std_types"],
scaled_positions=dataset["std_positions"],
cell=dataset["std_lattice"],
symbols=dataset.std_types,
scaled_positions=dataset.std_positions,
cell=dataset.std_lattice,
)

# Lets define two variants of NaCl in rocksalt structure
Expand Down
57 changes: 34 additions & 23 deletions matid/symmetry/symmetryanalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from ase import Atoms


class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


class SymmetryAnalyzer(object):
"""A base class for getting symmetry related properties of unit cells."""

Expand Down Expand Up @@ -144,7 +150,7 @@ def get_space_group_number(self):
int: The space group number.
"""
dataset = self.get_symmetry_dataset()
value = dataset["number"]
value = dataset.number

return value

Expand All @@ -154,7 +160,7 @@ def get_space_group_international_short(self):
str: The international space group short symbol.
"""
dataset = self.get_symmetry_dataset()
value = dataset["international"]
value = dataset.international

return value

Expand All @@ -164,7 +170,7 @@ def get_hall_symbol(self):
str: The Hall symbol.
"""
dataset = self.get_symmetry_dataset()
value = dataset["hall"]
value = dataset.hall

return value

Expand All @@ -174,7 +180,7 @@ def get_hall_number(self):
int: The Hall number.
"""
dataset = self.get_symmetry_dataset()
value = dataset["hall_number"]
value = dataset.hall_number

return value

Expand All @@ -186,7 +192,7 @@ def get_point_group(self):
str: point group symbol
"""
dataset = self.get_symmetry_dataset()
value = dataset["pointgroup"]
value = dataset.pointgroup

return value

Expand Down Expand Up @@ -367,7 +373,7 @@ def get_conventional_system(self):
# The index of the originally non-periodic dimension may not correspond
# to the one in the normalized system, because the normalized system
# may use a different coordinate system.
transformation_matrix = self.get_symmetry_dataset()["transformation_matrix"]
transformation_matrix = self.get_symmetry_dataset().transformation_matrix
nonperiodic_axis = None
prec = 1e-8
for i_axis, axis in enumerate(transformation_matrix):
Expand Down Expand Up @@ -448,7 +454,7 @@ def get_rotations(self):
np.ndarray: Rotation matrices.
"""
dataset = self.get_symmetry_dataset()
value = dataset["rotations"]
value = dataset.rotations

return value

Expand All @@ -461,7 +467,7 @@ def get_translations(self):
np.ndarray: Translation vectors.
"""
dataset = self.get_symmetry_dataset()
value = dataset["translations"]
value = dataset.translations

return value

Expand All @@ -472,7 +478,7 @@ def get_choice(self):
settings.
"""
dataset = self.get_symmetry_dataset()
value = dataset["choice"]
value = dataset.choice

return value

Expand Down Expand Up @@ -595,6 +601,11 @@ def get_symmetry_dataset(self):
if symmetry_dataset is None:
raise CellNormalizationError("Spglib error when finding symmetry dataset.")

# Prior to spglib 2.5.0 the dataset is returned as a dictionary: this
# provides backwards compatibility
if isinstance(symmetry_dataset, dict):
symmetry_dataset = AttrDict(symmetry_dataset)

self._symmetry_dataset = symmetry_dataset

return symmetry_dataset
Expand All @@ -610,9 +621,9 @@ def _get_spglib_conventional_system(self):
return self._spglib_conventional_system

dataset = self.get_symmetry_dataset()
cell = dataset["std_lattice"]
pos = dataset["std_positions"]
num = dataset["std_types"]
cell = dataset.std_lattice
pos = dataset.std_positions
num = dataset.std_types
spg_conv_sys = self._spglib_description_to_system((cell, pos, num))

self._spglib_conventional_system = spg_conv_sys
Expand All @@ -624,7 +635,7 @@ def _get_spglib_wyckoff_letters_original(self):
list of str: Wyckoff letters for the atoms in the original system.
"""
dataset = self.get_symmetry_dataset()
value = np.array(dataset["wyckoffs"])
value = np.array(dataset.wyckoffs)

return value

Expand All @@ -640,7 +651,7 @@ def _get_spglib_equivalent_atoms_original(self):
# equivalent atoms reported by spglib are based on the symmetry of the
# original cell. Equivalence in crystallographic_orbits is instead
# based on the primitive cell/conventional cell which is what we want.
value = dataset["crystallographic_orbits"]
value = dataset.crystallographic_orbits

return value

Expand All @@ -653,7 +664,7 @@ def _get_spglib_wyckoff_letters_conventional(self):
if self._spglib_wyckoff_letters_conventional is None:
wyckoff_letters_primitive = self._get_spglib_wyckoff_letters_primitive()
dataset = self.get_symmetry_dataset()
mapping = dataset["std_mapping_to_primitive"]
mapping = dataset.std_mapping_to_primitive
self._spglib_wyckoff_letters_conventional = wyckoff_letters_primitive[
mapping
]
Expand All @@ -668,7 +679,7 @@ def _get_spglib_equivalent_atoms_conventional(self):
if self._spglib_equivalent_atoms_conventional is None:
equivalent_atoms_primitive = self._get_spglib_equivalent_atoms_primitive()
dataset = self.get_symmetry_dataset()
mapping = dataset["std_mapping_to_primitive"]
mapping = dataset.std_mapping_to_primitive
self._spglib_equivalent_atoms_conventional = equivalent_atoms_primitive[
mapping
]
Expand All @@ -690,7 +701,7 @@ def _get_spglib_origin_shift(self):
3*1 np.ndarray: The shift of the origin as a vector.
"""
dataset = self.get_symmetry_dataset()
value = dataset["origin_shift"]
value = dataset.origin_shift

return value

Expand All @@ -707,8 +718,8 @@ def get_symmetry_operations(self):
"""
dataset = self.get_symmetry_dataset()
operations = {
"rotations": dataset["rotations"],
"translations": dataset["translations"],
"rotations": dataset.rotations,
"translations": dataset.translations,
}

return operations
Expand All @@ -728,7 +739,7 @@ def _get_spglib_transformation_matrix(self):
3x3 np.ndarray:
"""
dataset = self.get_symmetry_dataset()
value = dataset["transformation_matrix"]
value = dataset.transformation_matrix

return value

Expand Down Expand Up @@ -802,7 +813,7 @@ def _get_spglib_primitive_to_original_mapping(self):
"""
if self._spglib_primitive_to_original_mapping is None:
dataset = self.get_symmetry_dataset()
mapping = dataset["mapping_to_primitive"]
mapping = dataset.mapping_to_primitive
_, indices = np.unique(mapping, return_index=True)
self._spglib_primitive_to_original_mapping = indices

Expand Down Expand Up @@ -896,7 +907,7 @@ def _get_primitive_system(
# Keep one occurrence for each atom that should be within the cell and
# wrap it's position to tbe inside the primitive cell.
conv_num = conv_system.get_atomic_numbers()
conv_to_prim_map = self._symmetry_dataset["std_mapping_to_primitive"]
conv_to_prim_map = self._symmetry_dataset.std_mapping_to_primitive
_, inside_mask = np.unique(conv_to_prim_map, return_index=True)
prim_pos = prim_pos[inside_mask]
prim_num = conv_num[inside_mask]
Expand Down Expand Up @@ -1204,7 +1215,7 @@ def _get_wyckoff_sets(
precision,
return_parameters,
):
"""Used to get detailed information about about the sets of equivalent
r"""Used to get detailed information about about the sets of equivalent
atoms. The detected Wyckoff set variables (x, y, z) are reported
consistenly by selecting the variable sets that has lowest x value, then
lowest y and finally lowest z.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

for hall_number in range(1, 531):
dataset = spglib.get_spacegroup_type(hall_number)
number = dataset["number"]
number = dataset.number
space_hall_map[number].append(hall_number)

degenerate_spgs = []
Expand All @@ -26,7 +26,7 @@
degenerate_spgs.append(key)
first_hall = value[0]
dataset = spglib.get_spacegroup_type(first_hall)
choice = dataset["choice"]
choice = dataset.choice

# try:
# origin = int(choice)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

for hall_number in range(1, 531):
dataset = spglib.get_spacegroup_type(hall_number)
number = dataset["number"]
international_short = dataset["international_short"]
number = dataset.number
international_short = dataset.international_short

# Check that the spglib data has no two different international symbols for
# the same space group number
Expand All @@ -28,7 +28,7 @@
# Point group. There actually seeems to be a bug in spglib 1.9.4, where
# the Hermann-Mauguin point group symbol is in the plalce of Schonflies
# data and vice versa.
pointgroup = dataset["pointgroup_schoenflies"]
pointgroup = dataset.pointgroup_schoenflies
space_group_database[number]["pointgroup"] = pointgroup

# Crystal system
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
[build-system]
requires = ["setuptools", "wheel", "pybind11~=2.11.1"]
requires = ["setuptools", "wheel", "pybind11~=2.13.6"]
build-backend = "setuptools.build_meta"

[project]
name = 'matid'
version = '2.1.4'
version = '2.1.5'
description = 'MatID is a Python package for identifying and analyzing atomistic systems based on their structure.'
readme = "README.md"
authors = [{ name = "Lauri Himanen" }]
license = { file = "LICENSE" }
requires-python = ">=3.8"
dependencies = [
"numpy<2.0.0",
"numpy",
"ase",
"spglib>=1.15.0",
"spglib>=2.0.0",
"scikit-learn",
"networkx>=2.4",
]
Expand Down

0 comments on commit 642d3ca

Please sign in to comment.