diff --git a/.github/workflows/issue-metrics.yml b/.github/workflows/issue-metrics.yml index 26e498a56ae..484dd0effcd 100644 --- a/.github/workflows/issue-metrics.yml +++ b/.github/workflows/issue-metrics.yml @@ -9,6 +9,8 @@ permissions: jobs: build: + # prevent this action from running on forks + if: github.repository == 'materialsproject/pymatgen' name: issue metrics runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b6d5930dd17..dc2b8238851 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,11 +35,11 @@ jobs: resolution: highest extras: ci,optional - os: ubuntu-latest - python: '>3.9' + python: ">3.9" resolution: lowest-direct extras: ci,optional - os: macos-latest - python: '3.10' + python: "3.10" resolution: lowest-direct extras: ci # test with only required dependencies installed @@ -70,18 +70,29 @@ jobs: - name: Install ubuntu-only conda dependencies if: matrix.config.os == 'ubuntu-latest' run: | - micromamba install -n pmg -c conda-forge enumlib packmol bader openbabel openff-toolkit --yes + micromamba install -n pmg -c conda-forge enumlib packmol bader openbabel openff-toolkit pygraphviz --yes - name: Install pymatgen and dependencies run: | micromamba activate pmg + # TODO remove temporary fix. added since uv install torch is flaky. # track https://github.com/astral-sh/uv/issues/1921 for resolution pip install torch --upgrade - uv pip install numpy cython + uv pip install cython setuptools wheel + uv pip install --editable '.[${{ matrix.config.extras }}]' --resolution=${{ matrix.config.resolution }} + - name: Install optional Ubuntu dependencies + if: matrix.config.os == 'ubuntu-latest' + run: | + micromamba activate pmg + + # TODO: uv cannot install BoltzTraP2 (#3786), + # suggesting no NumPy when there is + pip install BoltzTraP2 + - name: pytest split ${{ matrix.split }} run: | micromamba activate pmg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23276773f85..a911872a807 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,16 +2,16 @@ exclude: ^(docs|tests/files|tasks.py) ci: autoupdate_schedule: monthly - skip: [ mypy, pyright ] + skip: [mypy, pyright] autofix_commit_msg: pre-commit auto-fixes autoupdate_commit_msg: pre-commit autoupdate repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.4 + rev: v0.5.6 hooks: - id: ruff - args: [ --fix, --unsafe-fixes ] + args: [--fix, --unsafe-fixes] - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks @@ -22,7 +22,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.0 + rev: v1.11.1 hooks: - id: mypy @@ -30,16 +30,16 @@ repos: rev: v2.3.0 hooks: - id: codespell - stages: [ commit, commit-msg ] - exclude_types: [ html ] - additional_dependencies: [ tomli ] # needed to read pyproject.toml below py3.11 + stages: [commit, commit-msg] + exclude_types: [html] + additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11 exclude: src/pymatgen/analysis/aflow_prototypes.json - repo: https://github.com/MarcoGorelli/cython-lint rev: v0.16.2 hooks: - id: cython-lint - args: [ --no-pycodestyle ] + args: [--no-pycodestyle] - id: double-quote-cython-strings - repo: https://github.com/adamchainz/blacken-docs @@ -56,15 +56,15 @@ repos: # MD033: no inline HTML # MD041: first line in a file should be a top-level heading # MD025: single title - args: [ --disable, MD013, MD024, MD025, MD033, MD041, "--" ] + args: [--disable, MD013, MD024, MD025, MD033, MD041, "--"] - repo: https://github.com/kynan/nbstripout rev: 0.7.1 hooks: - id: nbstripout - args: [ --drop-empty-cells, --keep-output ] + args: [--drop-empty-cells, --keep-output] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.373 + rev: v1.1.374 hooks: - id: pyright diff --git a/dev_scripts/chemenv/explicit_permutations.py b/dev_scripts/chemenv/explicit_permutations.py index 097e80b25e0..f3225fa5293 100644 --- a/dev_scripts/chemenv/explicit_permutations.py +++ b/dev_scripts/chemenv/explicit_permutations.py @@ -10,6 +10,7 @@ import os import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import ( AllCoordinationGeometries, ExplicitPermutationsAlgorithm, diff --git a/dev_scripts/chemenv/explicit_permutations_plane_algorithm.py b/dev_scripts/chemenv/explicit_permutations_plane_algorithm.py index 689ae58d86a..bd71080c300 100644 --- a/dev_scripts/chemenv/explicit_permutations_plane_algorithm.py +++ b/dev_scripts/chemenv/explicit_permutations_plane_algorithm.py @@ -9,6 +9,7 @@ import json import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import ( AbstractGeometry, diff --git a/dev_scripts/chemenv/get_plane_permutations_optimized.py b/dev_scripts/chemenv/get_plane_permutations_optimized.py index 127ca009f19..1244d13e487 100644 --- a/dev_scripts/chemenv/get_plane_permutations_optimized.py +++ b/dev_scripts/chemenv/get_plane_permutations_optimized.py @@ -15,6 +15,7 @@ import numpy as np import tabulate + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import ( AbstractGeometry, diff --git a/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py b/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py index 70449d924f8..5709df6317a 100644 --- a/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py +++ b/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( AngleNbSetWeight, CNBiasNbSetWeight, @@ -50,7 +51,9 @@ def __init__(self, initial_environment_symbol, expected_final_environment_symbol self.abstract_geometry = AbstractGeometry.from_cg(self.coordination_geometry) @classmethod - def simple_expansion(cls, initial_environment_symbol, expected_final_environment_symbol, neighbors_indices): + def simple_expansion( + cls, initial_environment_symbol, expected_final_environment_symbol, neighbors_indices + ) -> CoordinationEnvironmentMorphing: """Simple expansion of a coordination environment. Args: diff --git a/dev_scripts/chemenv/test_algos.py b/dev_scripts/chemenv/test_algos.py index 17933a067a3..dbffdce27e5 100644 --- a/dev_scripts/chemenv/test_algos.py +++ b/dev_scripts/chemenv/test_algos.py @@ -8,6 +8,7 @@ from random import shuffle import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import ( AbstractGeometry, diff --git a/dev_scripts/chemenv/view_environment.py b/dev_scripts/chemenv/view_environment.py index 69aa6adc42e..2caa22e9f34 100644 --- a/dev_scripts/chemenv/view_environment.py +++ b/dev_scripts/chemenv/view_environment.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import ( SEPARATION_PLANE, AllCoordinationGeometries, diff --git a/dev_scripts/potcar_scrambler.py b/dev_scripts/potcar_scrambler.py index c0793c713d9..0dd2b0190fa 100644 --- a/dev_scripts/potcar_scrambler.py +++ b/dev_scripts/potcar_scrambler.py @@ -9,6 +9,7 @@ import numpy as np from monty.os.path import zpath from monty.serialization import zopen + from pymatgen.core import SETTINGS from pymatgen.io.vasp import Potcar, PotcarSingle from pymatgen.io.vasp.sets import _load_yaml_config diff --git a/dev_scripts/update_pt_data.py b/dev_scripts/update_pt_data.py index ee7889c75ea..88f321ed712 100644 --- a/dev_scripts/update_pt_data.py +++ b/dev_scripts/update_pt_data.py @@ -11,9 +11,10 @@ import requests from monty.dev import requires from monty.serialization import dumpfn, loadfn -from pymatgen.core import Element, get_el_sp from ruamel import yaml +from pymatgen.core import Element, get_el_sp + try: from bs4 import BeautifulSoup except ImportError: diff --git a/dev_scripts/update_spacegroup_data.py b/dev_scripts/update_spacegroup_data.py index a4607cdbee7..07cd5d66344 100644 --- a/dev_scripts/update_spacegroup_data.py +++ b/dev_scripts/update_spacegroup_data.py @@ -12,6 +12,7 @@ import sys from monty.serialization import dumpfn, loadfn + from pymatgen.symmetry.groups import PointGroup __author__ = "Katharina Ueltzen @kaueltzen" @@ -29,7 +30,7 @@ def convert_symmops_to_sg_encoding(symbol: str) -> str: Args: symbol (str): "hermann_mauguin" or "universal_h_m" key of symmops.json Returns: - symbol in the format of SYMM_DATA["space_group_encoding"] keys + str: symbol in the format of SYMM_DATA["space_group_encoding"] keys """ symbol_representation = symbol.split(":") representation = ":" + "".join(symbol_representation[1].split(" ")) if len(symbol_representation) > 1 else "" @@ -50,7 +51,7 @@ def remove_identity_from_full_hermann_mauguin(symbol: str) -> str: Args: symbol (str): "hermann_mauguin" key of symmops.json Returns: - short "hermann_mauguin" key + str: short "hermann_mauguin" key """ if symbol in ("P 1", "C 1", "P 1 "): return symbol diff --git a/pyproject.toml b/pyproject.toml index 0546b4b7ac0..f56ecb39b41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,11 @@ build-backend = "setuptools.build_meta" [project] name = "pymatgen" -authors = [ - { name = "Pymatgen Development Team", email = "ongsp@ucsd.edu" }, -] +authors = [{ name = "Pymatgen Development Team", email = "ongsp@ucsd.edu" }] maintainers = [ - { name = "Shyue Ping Ong", email = "ongsp@ucsd.edu" }, - { name = "Matthew Horton", email = "m.k.horton@gmail.com" }, { name = "Janosh Riebesell", email = "janosh.riebesell@gmail.com" }, + { name = "Matthew Horton", email = "m.k.horton@gmail.com" }, + { name = "Shyue Ping Ong", email = "ongsp@ucsd.edu" }, ] description = """ Python Materials Genomics is a robust materials analysis code that defines core object representations for structures @@ -26,6 +24,7 @@ readme = "README.md" requires-python = ">=3.9" keywords = [ "ABINIT", + "VASP", "analysis", "crystal", "diagrams", @@ -38,7 +37,6 @@ keywords = [ "qchem", "science", "structure", - "VASP", ] license = { text = "MIT" } classifiers = [ @@ -46,23 +44,24 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Chemistry", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Scientific/Engineering :: Physics", "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ + "joblib>=1", "matplotlib>=3.8", - "monty>=2024.5.24", + "monty>=2024.7.29", "networkx>=2.2", - 'numpy>=1.25.0,<2.0 ; platform_system == "Windows"', - 'numpy>=1.25.0 ; platform_system != "Windows"', - "palettable>=3.1.1", + "numpy>=1.25.0 ; platform_system != 'Windows'", + "numpy>=1.25.0,<2.0 ; platform_system == 'Windows'", + "palettable>=3.3.3", "pandas>=2", "plotly>=4.5.0", "pybtex>=0.24.0", @@ -74,7 +73,6 @@ dependencies = [ "tabulate>=0.9", "tqdm>=4.60", "uncertainties>=3.1.4", - "joblib>=1", ] version = "2024.7.18" @@ -91,23 +89,14 @@ ase = ["ase>=3.23.0"] tblite = ["tblite[ase]>=0.3.0; python_version<'3.12'"] vis = ["vtk>=6.0.0"] abinit = ["netcdf4>=1.6.5"] -mlp = ["matgl>=1.1.1", "chgnet>=0.3.8"] +mlp = ["chgnet>=0.3.8", "matgl>=1.1.1"] electronic_structure = ["fdint>=2.0.2"] -ci = [ - "pytest>=8", - "pytest-cov>=4", - "pytest-split>=0.8", -] -docs = [ - "sphinx", - "sphinx_rtd_theme", -] +ci = ["pytest-cov>=4", "pytest-split>=0.8", "pytest>=8"] +docs = ["sphinx", "sphinx_rtd_theme"] optional = [ "ase>=3.23.0", - # TODO restore BoltzTraP2 when install fixed, hopefully following merge of - # https://gitlab.com/sousaw/BoltzTraP2/-/merge_requests/18 - # caused CI failure due to ModuleNotFoundError: No module named 'packaging' - # "BoltzTraP2>=22.3.2; platform_system!='Windows'", + # TODO: uv cannot install BoltzTraP2 + # "BoltzTraP2>=24.7.2 ; platform_system != 'Windows'", "chemview>=0.6", "chgnet>=0.3.8", "f90nml>=1.1.2", @@ -115,6 +104,8 @@ optional = [ "h5py>=3.11.0", "jarvis-tools>=2020.7.14", "matgl>=1.1.1", + # TODO: track https://github.com/matplotlib/matplotlib/issues/28551 + "matplotlib>=3.8,!=3.9.1", "netCDF4>=1.6.5", "phonopy>=2.23", "seekpath>=2.0.1", @@ -139,7 +130,7 @@ where = ["src"] include = ["pymatgen", "pymatgen.*"] [tool.setuptools.package-data] -"pymatgen.analysis" = ["*.yaml", "*.json", "*.csv"] +"pymatgen.analysis" = ["*.csv", "*.json", "*.yaml"] "pymatgen.analysis.chemenv" = [ "coordination_environments/coordination_geometries_files/*.json", "coordination_environments/coordination_geometries_files/*.txt", @@ -152,27 +143,19 @@ include = ["pymatgen", "pymatgen.*"] "pymatgen.entries" = ["*.json.gz", "*.yaml", "data/*.json"] "pymatgen.core" = ["*.json"] "pymatgen" = ["py.typed"] -"pymatgen.io.vasp" = ["*.yaml", "*.json", "*.json.gz", "*.json.bz2"] +"pymatgen.io.vasp" = ["*.json", "*.json.bz2", "*.json.gz", "*.yaml"] "pymatgen.io.feff" = ["*.yaml"] "pymatgen.io.cp2k" = ["*.yaml"] "pymatgen.io.lobster" = ["lobster_basis/*.yaml"] "pymatgen.command_line" = ["*"] -"pymatgen.util" = ["structures/*.json", "*.json"] +"pymatgen.util" = ["*.json", "structures/*.json"] "pymatgen.vis" = ["*.yaml"] "pymatgen.io.lammps" = ["CoeffsDataType.yaml", "templates/*.template"] -"pymatgen.symmetry" = ["*.yaml", "*.json", "*.sqlite"] +"pymatgen.symmetry" = ["*.json", "*.sqlite", "*.yaml"] [tool.pdm.dev-dependencies] -lint = [ - "mypy>=1.10.0", - "ruff>=0.4.9", - "pre-commit>=3.7.1", -] -test = [ - "pytest>=8.2.2", - "pytest-cov>=5.0.0", - "pytest-split>=0.9.0", -] +lint = ["mypy>=1.10.0", "pre-commit>=3.7.1", "ruff>=0.4.9"] +test = ["pytest-cov>=5.0.0", "pytest-split>=0.9.0", "pytest>=8.2.2"] [tool.versioningit.vcs] method = "git" @@ -193,63 +176,65 @@ line-length = 120 [tool.ruff.lint] select = ["ALL"] ignore = [ - # Rule families - "ANN", # flake8-annotations (not ready, require types for ALL args) - "ARG", # Check for unused function arguments - "BLE", # General catch of Exception - "C90", # Check for functions with a high McCabe complexity - "COM", # flake8-commas (conflict with line wrapper) - "CPY", # Missing copyright notice at top of file (need preview mode) - "EM", # Format nice error messages - "ERA", # Check for commented-out code - "FIX", # Check for FIXME, TODO and other developer notes - "FURB", # refurb (need preview mode, too many preview errors) - "G", # validate logging format strings - "INP", # Ban PEP-420 implicit namespace packages - "N", # pep8-naming (many var/arg names are intended) - "NPY", # NumPy-specific rules (TODO: enable this) - "PTH", # Prefer pathlib over os.path - "S", # flake8-bandit (TODO: enable this) - "SLF", # Access "private" class members - "T20", # Check for print/pprint - "TD", # TODO tags related + # Rule families + "ANN", # flake8-annotations (not ready, require types for ALL args) + "ARG", # Check for unused function arguments + "BLE", # General catch of Exception + "C90", # Check for functions with a high McCabe complexity + "COM", # flake8-commas (conflict with line wrapper) + "CPY", # Missing copyright notice at top of file (need preview mode) + "EM", # Format nice error messages + "ERA", # Check for commented-out code + "FIX", # Check for FIXME, TODO and other developer notes + "FURB", # refurb (need preview mode, too many preview errors) + "G", # validate logging format strings + "INP", # Ban PEP-420 implicit namespace packages + "N", # pep8-naming (many var/arg names are intended) + "NPY", # NumPy-specific rules (TODO: enable this) + "PTH", # Prefer pathlib over os.path + "S", # flake8-bandit (TODO: enable this) + "SLF", # Access "private" class members + "T20", # Check for print/pprint + "TD", # TODO tags related - # Single rules - "B023", # Function definition does not bind loop variable - "B028", # No explicit stacklevel keyword argument found - "B904", # Within an except clause, raise exceptions with ... - "C408", # unnecessary-collection-call - "D105", # Missing docstring in magic method - "D205", # 1 blank line required between summary line and description - "D212", # Multi-line docstring summary should start at the first line - "DTZ003", # TODO: fix this (issue #3791) - "FBT001", # Boolean-typed positional argument in function definition - "FBT002", # Boolean default positional argument in function - "PD901", # pandas-df-variable-name - "PERF203", # try-except-in-loop - "PERF401", # manual-list-comprehension - "PLR0911", # too many return statements - "PLR0912", # too many branches - "PLR0913", # too many arguments - "PLR0915", # too many statements - "PLR2004", # magic values in comparison - "PLW2901", # Outer for loop variable overwritten by inner assignment target - "PT013", # pytest-incorrect-pytest-import - "SIM105", # Use contextlib.suppress() instead of try-except-pass - "TRY003", # Avoid specifying long messages outside the exception class - "TRY300", # Checks for return statements in try blocks - "TRY301", # Checks for raise statements within try blocks + # Single rules + "B023", # Function definition does not bind loop variable + "B028", # No explicit stacklevel keyword argument found + "B904", # Within an except clause, raise exceptions with ... + "C408", # unnecessary-collection-call + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "FBT001", # Boolean-typed positional argument in function definition + "FBT002", # Boolean default positional argument in function + "PD901", # pandas-df-variable-name + "PERF203", # try-except-in-loop + "PERF401", # manual-list-comprehension + "PLR0911", # too many return statements + "PLR0912", # too many branches + "PLR0913", # too many arguments + "PLR0915", # too many statements + "PLR2004", # magic-value-comparison TODO fix these + "PLW2901", # Outer for loop variable overwritten by inner assignment target + "PT013", # pytest-incorrect-pytest-import + "SIM105", # Use contextlib.suppress() instead of try-except-pass + "TRY003", # Avoid specifying long messages outside the exception class + "TRY300", # Checks for return statements in try blocks + "TRY301", # Checks for raise statements within try blocks ] pydocstyle.convention = "google" isort.required-imports = ["from __future__ import annotations"] isort.split-on-trailing-comma = false +isort.known-first-party = ["pymatgen"] [tool.ruff.format] docstring-code-format = true [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] -"tests/**" = ["ANN201", "D", "PLR0124"] +# PLR2004: magic-value-comparison +# PLR6301: no-self-use +"tests/**" = ["ANN201", "D", "PLR0124", "PLR2004", "PLR6301"] "src/pymatgen/analysis/*" = ["D"] "src/pymatgen/io/*" = ["D"] "dev_scripts/*" = ["D"] diff --git a/src/pymatgen/alchemy/filters.py b/src/pymatgen/alchemy/filters.py index c16c05bf634..923ca1e8056 100644 --- a/src/pymatgen/alchemy/filters.py +++ b/src/pymatgen/alchemy/filters.py @@ -7,14 +7,16 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher from pymatgen.core import get_el_sp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + class AbstractStructureFilter(MSONable, abc.ABC): """Structures that return True when passed to the test() method are retained during @@ -41,7 +43,7 @@ class ContainsSpecieFilter(AbstractStructureFilter): def __init__(self, species, strict_compare=False, AND=True, exclude=False): """ Args: - species ([Species/Element]): list of species to look for + species (list[SpeciesLike]): species to look for AND: whether all species must be present to pass (or fail) filter. strict_compare: if true, compares objects by specie or element object if false, compares atomic number @@ -155,7 +157,7 @@ def from_dict(cls, dct: dict) -> Self: dct (dict): Dict representation. Returns: - Filter + SpecieProximityFilter """ return cls(**dct["init_args"]) diff --git a/src/pymatgen/alchemy/materials.py b/src/pymatgen/alchemy/materials.py index 701133215e3..c5bded7c26b 100644 --- a/src/pymatgen/alchemy/materials.py +++ b/src/pymatgen/alchemy/materials.py @@ -5,13 +5,14 @@ from __future__ import annotations -import datetime import json import re +from datetime import datetime, timezone from typing import TYPE_CHECKING from warnings import warn from monty.json import MSONable, jsanitize + from pymatgen.core.structure import Structure from pymatgen.io.cif import CifParser from pymatgen.io.vasp.inputs import Poscar @@ -23,9 +24,10 @@ from collections.abc import Sequence from typing import Any - from pymatgen.alchemy.filters import AbstractStructureFilter from typing_extensions import Self + from pymatgen.alchemy.filters import AbstractStructureFilter + class TransformedStructure(MSONable): """Container for new structures that include history of transformations. @@ -300,7 +302,7 @@ def from_cif_str( source = "uploaded cif" source_info = { "source": source, - "datetime": str(datetime.datetime.now(tz=datetime.timezone.utc)), + "datetime": str(datetime.now(tz=timezone.utc)), "original_file": raw_str, "cif_data": cif_dict[cif_keys[0]], } @@ -328,7 +330,7 @@ def from_poscar_str( struct = poscar.structure source_info = { "source": "POSCAR", - "datetime": str(datetime.datetime.now(tz=datetime.timezone.utc)), + "datetime": str(datetime.now(tz=timezone.utc)), "original_file": raw_str, } return cls(struct, transformations, history=[source_info]) @@ -339,7 +341,7 @@ def as_dict(self) -> dict[str, Any]: dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ dct["history"] = jsanitize(self.history) - dct["last_modified"] = str(datetime.datetime.now(datetime.timezone.utc)) + dct["last_modified"] = str(datetime.now(timezone.utc)) dct["other_parameters"] = jsanitize(self.other_parameters) return dct @@ -364,13 +366,13 @@ def to_snl(self, authors: list[str], **kwargs) -> StructureNL: history = [] for hist in self.history: snl_metadata = hist.pop("_snl", {}) - history.append( + history += [ { "name": snl_metadata.pop("name", "pymatgen"), "url": snl_metadata.pop("url", "http://pypi.python.org/pypi/pymatgen"), "description": hist, } - ) + ] return StructureNL(self.final_structure, authors, history=history, **kwargs) diff --git a/src/pymatgen/alchemy/transmuters.py b/src/pymatgen/alchemy/transmuters.py index d51ea47c247..207a1bfac35 100644 --- a/src/pymatgen/alchemy/transmuters.py +++ b/src/pymatgen/alchemy/transmuters.py @@ -21,9 +21,10 @@ from collections.abc import Sequence from typing import Callable - from pymatgen.alchemy.filters import AbstractStructureFilter from typing_extensions import Self + from pymatgen.alchemy.filters import AbstractStructureFilter + __author__ = "Shyue Ping Ong, Will Richards" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -102,7 +103,7 @@ def redo_next_change(self) -> None: for ts in self.transformed_structures: ts.redo_next_change() - def append_transformation(self, transformation, extend_collection=False, clear_redo=True): + def append_transformation(self, transformation, extend_collection=False, clear_redo=True) -> list[bool]: """Append a transformation to all TransformedStructures. Args: @@ -116,8 +117,8 @@ def append_transformation(self, transformation, extend_collection=False, clear_r redo, the redo list should not be cleared to allow multiple redos. Returns: - list[bool]: corresponding to initial transformed structures each boolean - describes whether the transformation altered the structure + list[bool]: Each list item is True if the transformation altered the structure + with the corresponding index. """ if self.ncores and transformation.use_multiprocessing: with Pool(self.ncores) as pool: @@ -135,6 +136,9 @@ def append_transformation(self, transformation, extend_collection=False, clear_r new_structures += new self.transformed_structures += new_structures + # len(ts) > 1 checks if the structure has history + return [len(ts) > 1 for ts in self.transformed_structures] + def extend_transformations(self, transformations): """Extend a sequence of transformations to the TransformedStructure. @@ -330,13 +334,10 @@ def batch_write_vasp_input( output_dir: Directory to output files create_directory (bool): Create the directory if not present. Defaults to True. - subfolder: Function to create subdirectory name from - transformed_structure. - e.g. lambda x: x.other_parameters["tags"][0] to use the first - tag. - include_cif (bool): Boolean indication whether to output a CIF as - well. CIF files are generally better supported in visualization - programs. + subfolder: Function to create subdirectory name from transformed_structure. + E.g. lambda x: x.other_parameters["tags"][0] to use the first tag. + include_cif (bool): Pass True to output a CIF as well. CIF files are generally + better supported in visualization programs. **kwargs: Any kwargs supported by vasp_input_set. """ for idx, struct in enumerate(transformed_structures): diff --git a/src/pymatgen/analysis/adsorption.py b/src/pymatgen/analysis/adsorption.py index 88f9f98dc3e..5fe6b7a3742 100644 --- a/src/pymatgen/analysis/adsorption.py +++ b/src/pymatgen/analysis/adsorption.py @@ -12,6 +12,8 @@ from matplotlib import patches from matplotlib.path import Path from monty.serialization import loadfn +from scipy.spatial import Delaunay + from pymatgen import vis from pymatgen.analysis.local_env import VoronoiNN from pymatgen.analysis.structure_matcher import StructureMatcher @@ -20,14 +22,14 @@ from pymatgen.core.surface import generate_all_slabs from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.coord import in_coord_list_pbc -from scipy.spatial import Delaunay if TYPE_CHECKING: import matplotlib.pyplot as plt from numpy.typing import ArrayLike - from pymatgen.core.surface import Slab from typing_extensions import Self + from pymatgen.core.surface import Slab + __author__ = "Joseph Montoya" __copyright__ = "Copyright 2016, The Materials Project" __version__ = "0.1" diff --git a/src/pymatgen/analysis/bond_dissociation.py b/src/pymatgen/analysis/bond_dissociation.py index cbf6ed83845..73b9f2d4fe2 100644 --- a/src/pymatgen/analysis/bond_dissociation.py +++ b/src/pymatgen/analysis/bond_dissociation.py @@ -8,6 +8,7 @@ import networkx as nx from monty.json import MSONable + from pymatgen.analysis.fragmenter import open_ring from pymatgen.analysis.graphs import MoleculeGraph, MolGraphSplitError from pymatgen.analysis.local_env import OpenBabelNN diff --git a/src/pymatgen/analysis/bond_valence.py b/src/pymatgen/analysis/bond_valence.py index a81a71f5462..a4ed8389120 100644 --- a/src/pymatgen/analysis/bond_valence.py +++ b/src/pymatgen/analysis/bond_valence.py @@ -11,6 +11,7 @@ import numpy as np from monty.serialization import loadfn + from pymatgen.core import Element, Species, get_el_sp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -420,7 +421,7 @@ def _recurse(assigned=None): return [[int(frac_site) for frac_site in assigned[site]] for site in structure] raise ValueError("Valences cannot be assigned!") - def get_oxi_state_decorated_structure(self, structure: Structure): + def get_oxi_state_decorated_structure(self, structure: Structure) -> Structure: """Get an oxidation state decorated structure. This currently works only for ordered structures only. @@ -428,7 +429,7 @@ def get_oxi_state_decorated_structure(self, structure: Structure): structure: Structure to analyze Returns: - A modified structure that is oxidation state decorated. + Structure: modified with oxidation state decorations. Raises: ValueError if the valences cannot be determined. diff --git a/src/pymatgen/analysis/chemenv/connectivity/connected_components.py b/src/pymatgen/analysis/chemenv/connectivity/connected_components.py index a7b361eda31..6acb98df505 100644 --- a/src/pymatgen/analysis/chemenv/connectivity/connected_components.py +++ b/src/pymatgen/analysis/chemenv/connectivity/connected_components.py @@ -13,6 +13,7 @@ from monty.json import MSONable, jsanitize from networkx.algorithms.components import is_connected from networkx.algorithms.traversal import bfs_tree + from pymatgen.analysis.chemenv.connectivity.environment_nodes import EnvironmentNode from pymatgen.analysis.chemenv.utils.chemenv_errors import ChemenvError from pymatgen.analysis.chemenv.utils.graph_utils import get_delta @@ -372,14 +373,15 @@ def __len__(self): def compute_periodicity(self, algorithm="all_simple_paths") -> None: """ Args: - algorithm (): + algorithm (str): Algorithm to use to compute the periodicity vectors. Can be + either "all_simple_paths" or "cycle_basis". """ if algorithm == "all_simple_paths": self.compute_periodicity_all_simple_paths_algorithm() elif algorithm == "cycle_basis": self.compute_periodicity_cycle_basis() else: - raise ValueError(f"Algorithm {algorithm!r} is not allowed to compute periodicity") + raise ValueError(f"{algorithm=} is not allowed to compute periodicity") self._order_periodicity_vectors() def compute_periodicity_all_simple_paths_algorithm(self): @@ -512,7 +514,7 @@ def compute_periodicity_cycle_basis(self) -> None: def make_supergraph(self, multiplicity): """ Args: - multiplicity (): + multiplicity (int): Multiplicity of the super graph. Returns: nx.MultiGraph: Super graph of the connected component. @@ -635,7 +637,8 @@ def periodicity(self): def elastic_centered_graph(self, start_node=None): """ Args: - start_node (): + start_node (Node, optional): Node to start the elastic centering from. + If not provided, the first node in the graph is used. Returns: nx.MultiGraph: Elastic centered subgraph. diff --git a/src/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py b/src/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py index 26495f4a437..99cb354f89c 100644 --- a/src/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py +++ b/src/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py @@ -5,6 +5,7 @@ import logging import numpy as np + from pymatgen.analysis.chemenv.connectivity.structure_connectivity import StructureConnectivity __author__ = "David Waroquiers" diff --git a/src/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py b/src/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py index 4fb4c54cc50..8159a6095fa 100644 --- a/src/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py +++ b/src/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py @@ -9,6 +9,7 @@ import networkx as nx import numpy as np from monty.json import MSONable, jsanitize + from pymatgen.analysis.chemenv.connectivity.connected_components import ConnectedComponent from pymatgen.analysis.chemenv.connectivity.environment_nodes import get_environment_node from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments @@ -75,8 +76,8 @@ def __init__( def environment_subgraph(self, environments_symbols=None, only_atoms=None): """ Args: - environments_symbols (): - only_atoms (): + environments_symbols (list[str]): symbols of the environments to consider. + only_atoms (list[str]): atoms to consider. Returns: nx.MultiGraph: The subgraph of the structure connectivity graph @@ -186,8 +187,8 @@ def setup_environment_subgraph(self, environments_symbols, only_atoms=None): ) self._environment_subgraph.add_node(env_node) else: - # TODO: add the possibility of a "constraint" on the minimum percentage - # of the atoms on the site + # TODO add the possibility of a "constraint" on the minimum percentage + # of the atoms on the site this_site_elements = [ sp.symbol for sp in self.light_structure_environments.structure[isite].species_and_occu ] diff --git a/src/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py b/src/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py index abc7c2b4778..3f814ac3ade 100644 --- a/src/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py +++ b/src/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py @@ -13,6 +13,8 @@ import numpy as np from monty.json import MSONable +from scipy.stats import gmean + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.voronoi import DetailedVoronoiContainer from pymatgen.analysis.chemenv.utils.chemenv_errors import EquivalentSiteSearchError @@ -27,7 +29,6 @@ from pymatgen.core.operations import SymmOp from pymatgen.core.sites import PeriodicSite from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from scipy.stats import gmean if TYPE_CHECKING: from typing import ClassVar diff --git a/src/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py b/src/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py index 523448b5bde..20aa39158fc 100644 --- a/src/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py +++ b/src/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py @@ -23,6 +23,7 @@ import numpy as np from numpy.linalg import norm, svd + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import MultiWeightsChemenvStrategy from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import ( @@ -612,8 +613,8 @@ def compute_structure_environments( optimization: optimization algorithm Returns: - The StructureEnvironments object containing all the information about the coordination - environments in the structure. + StructureEnvironments: contains all the information about the coordination + environments in the structure. """ time_init = time.process_time() if info is None: diff --git a/src/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py b/src/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py index 5b81229ced8..65f02d2c9a7 100644 --- a/src/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py +++ b/src/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py @@ -17,6 +17,7 @@ from matplotlib.gridspec import GridSpec from matplotlib.patches import Polygon from monty.json import MontyDecoder, MSONable, jsanitize + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.voronoi import DetailedVoronoiContainer from pymatgen.analysis.chemenv.utils.chemenv_errors import ChemenvError @@ -372,7 +373,7 @@ def from_dict(cls, dct, structure: Structure, detailed_voronoi) -> Self: the structure and the DetailedVoronoiContainer. As an inner (nested) class, the NeighborsSet is not supposed to be used anywhere else that inside the - StructureEnvironments. The from_dict method is thus using the structure and detailed_voronoi when + StructureEnvironments. The from_dict method is thus using the structure and detailed_voronoi when reconstructing itself. These two are both in the StructureEnvironments object. Args: diff --git a/src/pymatgen/analysis/chemenv/coordination_environments/voronoi.py b/src/pymatgen/analysis/chemenv/coordination_environments/voronoi.py index 08f257d3ae2..77fc1ffcbd4 100644 --- a/src/pymatgen/analysis/chemenv/coordination_environments/voronoi.py +++ b/src/pymatgen/analysis/chemenv/coordination_environments/voronoi.py @@ -9,6 +9,8 @@ import matplotlib.pyplot as plt import numpy as np from monty.json import MSONable +from scipy.spatial import Voronoi + from pymatgen.analysis.chemenv.utils.coordination_geometry_utils import ( get_lower_and_upper_f, rectangle_surface_intersection, @@ -18,7 +20,6 @@ from pymatgen.analysis.chemenv.utils.math_utils import normal_cdf_step from pymatgen.core.sites import PeriodicSite from pymatgen.core.structure import Structure -from scipy.spatial import Voronoi if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py b/src/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py index 358d03d39a0..edf391b023b 100644 --- a/src/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py +++ b/src/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py @@ -7,11 +7,12 @@ import numpy as np from numpy.linalg import norm -from pymatgen.analysis.chemenv.utils.chemenv_errors import SolidAngleError from scipy.integrate import quad from scipy.interpolate import UnivariateSpline from scipy.spatial import ConvexHull +from pymatgen.analysis.chemenv.utils.chemenv_errors import SolidAngleError + if TYPE_CHECKING: from typing import Callable diff --git a/src/pymatgen/analysis/chemenv/utils/func_utils.py b/src/pymatgen/analysis/chemenv/utils/func_utils.py index d8437cde208..5bae04c94a9 100644 --- a/src/pymatgen/analysis/chemenv/utils/func_utils.py +++ b/src/pymatgen/analysis/chemenv/utils/func_utils.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.analysis.chemenv.utils.math_utils import ( power2_decreasing_exp, power2_inverse_decreasing, diff --git a/src/pymatgen/analysis/chemenv/utils/scripts_utils.py b/src/pymatgen/analysis/chemenv/utils/scripts_utils.py index e153a7fcb63..ca0fe90c08c 100644 --- a/src/pymatgen/analysis/chemenv/utils/scripts_utils.py +++ b/src/pymatgen/analysis/chemenv/utils/scripts_utils.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( SimpleAbundanceChemenvStrategy, SimplestChemenvStrategy, diff --git a/src/pymatgen/analysis/chempot_diagram.py b/src/pymatgen/analysis/chempot_diagram.py index 769c3a69b8f..775775cf8b9 100644 --- a/src/pymatgen/analysis/chempot_diagram.py +++ b/src/pymatgen/analysis/chempot_diagram.py @@ -33,12 +33,13 @@ import plotly.express as px from monty.json import MSONable from plotly.graph_objects import Figure, Mesh3d, Scatter, Scatter3d +from scipy.spatial import ConvexHull, HalfspaceIntersection + from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram from pymatgen.core.composition import Composition, Element from pymatgen.util.coord import Simplex from pymatgen.util.due import Doi, due from pymatgen.util.string import htmlify -from scipy.spatial import ConvexHull, HalfspaceIntersection if TYPE_CHECKING: from pymatgen.entries.computed_entries import ComputedEntry @@ -167,7 +168,7 @@ def get_plot( (in eV/atom), helping provide visual clarity. Defaults to 1.0. Returns: - A Plotly Figure object + plotly.graph_objects.Figure """ if elements: elems = [Element(str(e)) for e in elements] @@ -294,7 +295,7 @@ def _get_2d_plot(self, elements: list[Element], label_stable: bool | None, eleme draw_domains[formula] = pts_2d layout = plotly_layouts["default_layout_2d"].copy() - layout.update(self._get_axis_layout_dict(elements)) + layout |= self._get_axis_layout_dict(elements) if label_stable: layout["annotations"] = annotations @@ -365,7 +366,7 @@ def _get_3d_plot( domain_simplexes[formula] = simplexes layout = plotly_layouts["default_layout_3d"].copy() - layout["scene"].update(self._get_axis_layout_dict(elements)) + layout["scene"] |= self._get_axis_layout_dict(elements) layout["scene"]["annotations"] = None if label_stable: @@ -558,7 +559,7 @@ def _get_annotation(ann_loc: np.ndarray, formula: str) -> dict[str, str | float] """Get a Plotly annotation dict given a formula and location.""" formula = htmlify(formula) annotation = plotly_layouts["default_annotation_layout"].copy() - annotation.update({"x": ann_loc[0], "y": ann_loc[1], "text": formula}) + annotation |= {"x": ann_loc[0], "y": ann_loc[1], "text": formula} if len(ann_loc) == 3: annotation["z"] = ann_loc[2] return annotation diff --git a/src/pymatgen/analysis/cost.py b/src/pymatgen/analysis/cost.py index 8ed4689301c..e5b34ed623c 100644 --- a/src/pymatgen/analysis/cost.py +++ b/src/pymatgen/analysis/cost.py @@ -17,6 +17,7 @@ import scipy.constants as const from monty.design_patterns import singleton + from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram from pymatgen.core import Composition, Element from pymatgen.util.provenance import is_valid_bibtex @@ -117,10 +118,10 @@ def __init__(self): class CostAnalyzer: """Given a CostDB, figures out the minimum cost solutions via convex hull.""" - def __init__(self, costdb): + def __init__(self, costdb: CostDB) -> None: """ Args: - costdb (): Cost database. + costdb (CostDB): Cost database to use. """ self.costdb = costdb diff --git a/src/pymatgen/analysis/diffraction/core.py b/src/pymatgen/analysis/diffraction/core.py index 4a4dce2623a..4e852b5e8f8 100644 --- a/src/pymatgen/analysis/diffraction/core.py +++ b/src/pymatgen/analysis/diffraction/core.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np + from pymatgen.core.spectrum import Spectrum from pymatgen.util.plotting import add_fig_kwargs, pretty_plot diff --git a/src/pymatgen/analysis/diffraction/neutron.py b/src/pymatgen/analysis/diffraction/neutron.py index d9ef7328935..59d36759ba0 100644 --- a/src/pymatgen/analysis/diffraction/neutron.py +++ b/src/pymatgen/analysis/diffraction/neutron.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.analysis.diffraction.core import ( AbstractDiffractionPatternCalculator, DiffractionPattern, diff --git a/src/pymatgen/analysis/diffraction/tem.py b/src/pymatgen/analysis/diffraction/tem.py index 98103a22945..d20699f7699 100644 --- a/src/pymatgen/analysis/diffraction/tem.py +++ b/src/pymatgen/analysis/diffraction/tem.py @@ -11,6 +11,7 @@ import pandas as pd import plotly.graph_objects as go import scipy.constants as sc + from pymatgen.analysis.diffraction.core import AbstractDiffractionPatternCalculator from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.string import latexify_spacegroup, unicodeify_spacegroup @@ -18,6 +19,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from pymatgen.core import Structure __author__ = "Frank Wan, Jason Liang" @@ -336,8 +338,7 @@ def is_parallel( plane: Tuple3Ints, other_plane: Tuple3Ints, ) -> bool: - """ - Checks if two hkl planes are parallel in reciprocal space. + """Checks if two hkl planes are parallel in reciprocal space. Args: structure (Structure): The input structure. @@ -345,7 +346,7 @@ def is_parallel( other_plane (3-tuple): The other plane to be compared. Returns: - boolean + bool: True if the planes are parallel, False otherwise. """ phi = self.get_interplanar_angle(structure, plane, other_plane) return phi in (180, 0) or np.isnan(phi) diff --git a/src/pymatgen/analysis/diffraction/xrd.py b/src/pymatgen/analysis/diffraction/xrd.py index 654cab6d3fe..31ad7e98017 100644 --- a/src/pymatgen/analysis/diffraction/xrd.py +++ b/src/pymatgen/analysis/diffraction/xrd.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.analysis.diffraction.core import ( AbstractDiffractionPatternCalculator, DiffractionPattern, @@ -100,7 +101,7 @@ def __init__(self, wavelength="CuKa", symprec: float = 0, debye_waller_factors=N """Initialize the XRD calculator with a given radiation. Args: - wavelength (str/float): The wavelength can be specified as either a + wavelength (str | float): The wavelength can be specified as either a float or a string. If it is a string, it must be one of the supported definitions in the AVAILABLE_RADIATION class variable, which provides useful commonly used wavelengths. diff --git a/src/pymatgen/analysis/dimensionality.py b/src/pymatgen/analysis/dimensionality.py index 93746d4f5db..7ce19444e03 100644 --- a/src/pymatgen/analysis/dimensionality.py +++ b/src/pymatgen/analysis/dimensionality.py @@ -29,6 +29,7 @@ import networkx as nx import numpy as np from networkx.readwrite import json_graph + from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph from pymatgen.analysis.local_env import JmolNN from pymatgen.analysis.structure_analyzer import get_max_bond_lengths diff --git a/src/pymatgen/analysis/elasticity/elastic.py b/src/pymatgen/analysis/elasticity/elastic.py index 805b6017013..ef931a6dabf 100644 --- a/src/pymatgen/analysis/elasticity/elastic.py +++ b/src/pymatgen/analysis/elasticity/elastic.py @@ -13,23 +13,25 @@ import numpy as np import sympy as sp +from scipy.integrate import quad +from scipy.optimize import root +from scipy.special import factorial + from pymatgen.analysis.elasticity.strain import Strain from pymatgen.analysis.elasticity.stress import Stress from pymatgen.core.tensors import DEFAULT_QUAD, SquareTensor, Tensor, TensorCollection, get_uvec from pymatgen.core.units import Unit from pymatgen.util.due import Doi, due -from scipy.integrate import quad -from scipy.optimize import root -from scipy.special import factorial if TYPE_CHECKING: from collections.abc import Sequence from typing import Literal from numpy.typing import ArrayLike - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + __author__ = "Joseph Montoya" __copyright__ = "Copyright 2012, The Materials Project" @@ -53,9 +55,9 @@ class NthOrderElasticTensor(Tensor): def __new__(cls, input_array, check_rank=None, tol: float = 1e-4) -> Self: """ Args: - input_array (): - check_rank (): - tol (): + input_array (np.ndarray): input array for tensor + check_rank (int): rank of tensor to check + tol (float): tolerance for initial symmetry test of tensor """ obj = super().__new__(cls, input_array, check_rank=check_rank) if obj.rank % 2 != 0: @@ -522,7 +524,7 @@ class ComplianceTensor(Tensor): def __new__(cls, s_array) -> Self: """ Args: - s_array (): + s_array (np.ndarray): input array for tensor """ vscale = np.ones((6, 6)) vscale[3:] *= 2 diff --git a/src/pymatgen/analysis/elasticity/strain.py b/src/pymatgen/analysis/elasticity/strain.py index 13092b5daf2..99fe716a249 100644 --- a/src/pymatgen/analysis/elasticity/strain.py +++ b/src/pymatgen/analysis/elasticity/strain.py @@ -12,6 +12,7 @@ import numpy as np import scipy + from pymatgen.core.lattice import Lattice from pymatgen.core.tensors import SquareTensor, symmetry_reduce @@ -20,9 +21,10 @@ from typing import Literal from numpy.typing import ArrayLike - from pymatgen.core.structure import Structure from typing_extensions import Self + from pymatgen.core.structure import Structure + __author__ = "Joseph Montoya" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Maarten de Jong, Mark Asta, Anubhav Jain" diff --git a/src/pymatgen/analysis/elasticity/stress.py b/src/pymatgen/analysis/elasticity/stress.py index 720eb00dc4f..3e5765130ca 100644 --- a/src/pymatgen/analysis/elasticity/stress.py +++ b/src/pymatgen/analysis/elasticity/stress.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.core.tensors import SquareTensor if TYPE_CHECKING: diff --git a/src/pymatgen/analysis/energy_models.py b/src/pymatgen/analysis/energy_models.py index 501287f7f01..fa3c6a1d091 100644 --- a/src/pymatgen/analysis/energy_models.py +++ b/src/pymatgen/analysis/energy_models.py @@ -10,13 +10,15 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.analysis.ewald import EwaldSummation from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + __version__ = "0.1" diff --git a/src/pymatgen/analysis/eos.py b/src/pymatgen/analysis/eos.py index fc526884537..ebb5bcbc2db 100644 --- a/src/pymatgen/analysis/eos.py +++ b/src/pymatgen/analysis/eos.py @@ -13,9 +13,10 @@ from typing import TYPE_CHECKING import numpy as np +from scipy.optimize import leastsq, minimize + from pymatgen.core.units import FloatWithUnit from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig, pretty_plot -from scipy.optimize import leastsq, minimize if TYPE_CHECKING: from typing import ClassVar @@ -37,8 +38,8 @@ class EOSBase(ABC): def __init__(self, volumes, energies): """ Args: - volumes (list/numpy.array): volumes in Ang^3 - energies (list/numpy.array): energy in eV. + volumes (Sequence[float]): in Ang^3. + energies (Sequence[float]): in eV. """ self.volumes = np.array(volumes) self.energies = np.array(energies) @@ -560,8 +561,8 @@ def fit(self, volumes, energies): """Fit energies as function of volumes. Args: - volumes (list/np.array) - energies (list/np.array) + volumes (Sequence[float]): in Ang^3 + energies (Sequence[float]): in eV Returns: EOSBase: EOSBase object diff --git a/src/pymatgen/analysis/ewald.py b/src/pymatgen/analysis/ewald.py index 46fb6c523ee..4cf0964ef90 100644 --- a/src/pymatgen/analysis/ewald.py +++ b/src/pymatgen/analysis/ewald.py @@ -5,17 +5,18 @@ import bisect import math from copy import copy, deepcopy -from datetime import datetime +from datetime import datetime, timezone from typing import TYPE_CHECKING from warnings import warn import numpy as np from monty.json import MSONable -from pymatgen.core.structure import Structure -from pymatgen.util.due import Doi, due from scipy import constants from scipy.special import comb, erfc +from pymatgen.core.structure import Structure +from pymatgen.util.due import Doi, due + if TYPE_CHECKING: from typing import Any @@ -534,7 +535,7 @@ def __init__(self, matrix, m_list, num_to_return=1, algo=ALGO_FAST): # sets this to true it breaks the recursion and stops the search. self._finished = False - self._start_time = datetime.utcnow() + self._start_time = datetime.now(tz=timezone.utc) self.minimize_matrix() @@ -604,7 +605,7 @@ def best_case(self, matrix, m_list, indices_left): interaction_correction = np.sum(step3) if self._algo == self.ALGO_TIME_LIMIT: - elapsed_time = datetime.utcnow() - self._start_time + elapsed_time = datetime.now(tz=timezone.utc) - self._start_time speedup_parameter = elapsed_time.total_seconds() / 1800 avg_int = np.sum(interaction_matrix, axis=None) avg_frac = np.mean(np.outer(1 - fractions, 1 - fractions)) diff --git a/src/pymatgen/analysis/ferroelectricity/polarization.py b/src/pymatgen/analysis/ferroelectricity/polarization.py index 2e0175fdf3b..7ce1cd2e4d8 100644 --- a/src/pymatgen/analysis/ferroelectricity/polarization.py +++ b/src/pymatgen/analysis/ferroelectricity/polarization.py @@ -48,16 +48,18 @@ from typing import TYPE_CHECKING import numpy as np +from scipy.interpolate import UnivariateSpline + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure -from scipy.interpolate import UnivariateSpline if TYPE_CHECKING: from collections.abc import Sequence - from pymatgen.core.sites import PeriodicSite from typing_extensions import Self + from pymatgen.core.sites import PeriodicSite + __author__ = "Tess Smidt" __copyright__ = "Copyright 2017, The Materials Project" @@ -73,10 +75,7 @@ def zval_dict_from_potcar(potcar) -> dict[str, float]: potcar: Potcar object """ - zval_dict = {} - for p in potcar: - zval_dict.update({p.element: p.ZVAL}) - return zval_dict + return {p.element: p.ZVAL for p in potcar} def calc_ionic(site: PeriodicSite, structure: Structure, zval: float) -> np.ndarray: diff --git a/src/pymatgen/analysis/fragmenter.py b/src/pymatgen/analysis/fragmenter.py index 98ad92a2b24..7458a46dfeb 100644 --- a/src/pymatgen/analysis/fragmenter.py +++ b/src/pymatgen/analysis/fragmenter.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.analysis.graphs import MoleculeGraph, MolGraphSplitError from pymatgen.analysis.local_env import OpenBabelNN, metal_edge_extender from pymatgen.io.babel import BabelMolAdaptor diff --git a/src/pymatgen/analysis/graphs.py b/src/pymatgen/analysis/graphs.py index 72c4c98bcbb..a23ce49ef36 100644 --- a/src/pymatgen/analysis/graphs.py +++ b/src/pymatgen/analysis/graphs.py @@ -20,12 +20,13 @@ from monty.json import MSONable from networkx.drawing.nx_agraph import write_dot from networkx.readwrite import json_graph +from scipy.spatial import KDTree +from scipy.stats import describe + from pymatgen.core import Lattice, Molecule, PeriodicSite, Structure from pymatgen.core.structure import FunctionalGroups from pymatgen.util.coord import lattice_points_in_supercell from pymatgen.vis.structure_vtk import EL_COLORS -from scipy.spatial import KDTree -from scipy.stats import describe try: import igraph @@ -38,10 +39,11 @@ from igraph import Graph from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.analysis.local_env import NearNeighbors from pymatgen.core import Species from pymatgen.util.typing import Tuple3Ints - from typing_extensions import Self logger = logging.getLogger(__name__) diff --git a/src/pymatgen/analysis/hhi.py b/src/pymatgen/analysis/hhi.py index b893fe3a131..85708e78366 100644 --- a/src/pymatgen/analysis/hhi.py +++ b/src/pymatgen/analysis/hhi.py @@ -14,6 +14,7 @@ import os from monty.design_patterns import singleton + from pymatgen.core import Composition, Element __author__ = "Anubhav Jain" diff --git a/src/pymatgen/analysis/interface_reactions.py b/src/pymatgen/analysis/interface_reactions.py index 908cd538ff0..71a481e8c80 100644 --- a/src/pymatgen/analysis/interface_reactions.py +++ b/src/pymatgen/analysis/interface_reactions.py @@ -15,6 +15,7 @@ from monty.json import MSONable from pandas import DataFrame from plotly.graph_objects import Figure, Scatter + from pymatgen.analysis.phase_diagram import GrandPotentialPhaseDiagram, PhaseDiagram from pymatgen.analysis.reaction_calculator import Reaction from pymatgen.core.composition import Composition @@ -110,7 +111,7 @@ def __init__( # Factor is the compositional ratio between composition self.c1 and # processed composition self.comp1. For example, the factor for - # Composition('SiO2') and Composition('O') is 2.0. This factor will be used + # Composition('SiO2') and Composition('O') is 2.0. This factor will be used # to convert mixing ratio in self.comp1 - self.comp2 tie line to that in # self.c1 - self.c2 tie line. self.factor1 = 1.0 diff --git a/src/pymatgen/analysis/interfaces/coherent_interfaces.py b/src/pymatgen/analysis/interfaces/coherent_interfaces.py index e52dab38e04..3e145fc5618 100644 --- a/src/pymatgen/analysis/interfaces/coherent_interfaces.py +++ b/src/pymatgen/analysis/interfaces/coherent_interfaces.py @@ -7,11 +7,12 @@ import numpy as np from numpy.testing import assert_allclose +from scipy.linalg import polar + from pymatgen.analysis.elasticity.strain import Deformation from pymatgen.analysis.interfaces.zsl import ZSLGenerator, fast_norm from pymatgen.core.interface import Interface, label_termination from pymatgen.core.surface import SlabGenerator -from scipy.linalg import polar if TYPE_CHECKING: from collections.abc import Iterator, Sequence diff --git a/src/pymatgen/analysis/interfaces/substrate_analyzer.py b/src/pymatgen/analysis/interfaces/substrate_analyzer.py index 62b2a4a5152..41746fdd24d 100644 --- a/src/pymatgen/analysis/interfaces/substrate_analyzer.py +++ b/src/pymatgen/analysis/interfaces/substrate_analyzer.py @@ -11,9 +11,10 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.core import Structure from pymatgen.util.typing import Tuple3Ints - from typing_extensions import Self @dataclass diff --git a/src/pymatgen/analysis/interfaces/zsl.py b/src/pymatgen/analysis/interfaces/zsl.py index fdd17dbcdb0..8aae62a46c1 100644 --- a/src/pymatgen/analysis/interfaces/zsl.py +++ b/src/pymatgen/analysis/interfaces/zsl.py @@ -8,6 +8,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.util.due import Doi, due from pymatgen.util.numba import njit diff --git a/src/pymatgen/analysis/local_env.py b/src/pymatgen/analysis/local_env.py index e5a202623d9..ac27b579472 100644 --- a/src/pymatgen/analysis/local_env.py +++ b/src/pymatgen/analysis/local_env.py @@ -19,12 +19,13 @@ import numpy as np from monty.dev import deprecated, requires from monty.serialization import loadfn +from ruamel.yaml import YAML +from scipy.spatial import Voronoi + from pymatgen.analysis.bond_valence import BV_PARAMS, BVAnalyzer from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.core import Element, IStructure, PeriodicNeighbor, PeriodicSite, Site, Species, Structure -from ruamel.yaml import YAML -from scipy.spatial import Voronoi try: from openbabel import openbabel @@ -34,9 +35,10 @@ if TYPE_CHECKING: from typing import Any + from typing_extensions import Self + from pymatgen.core.composition import SpeciesLike from pymatgen.util.typing import Tuple3Ints - from typing_extensions import Self __author__ = "Shyue Ping Ong, Geoffroy Hautier, Sai Jayaraman, " @@ -1159,7 +1161,7 @@ def _is_in_targets(site, targets): targets ([Element]) List of elements Returns: - boolean: Whether this site contains a certain list of elements + bool: Whether this site contains a certain list of elements """ elems = _get_elements(site) return all(elem in targets for elem in elems) @@ -1216,7 +1218,7 @@ def __init__( # Update any user preference elemental radii if el_radius_updates: - self.el_radius.update(el_radius_updates) + self.el_radius |= el_radius_updates @property def structures_allowed(self) -> bool: @@ -1982,7 +1984,7 @@ def get_okeeffe_distance_prediction(el1, el2): """Get an estimate of the bond valence parameter (bond length) using the derived parameters from 'Atoms Sizes and Bond Lengths in Molecules and Crystals' (O'Keeffe & Brese, 1991). The estimate is based on two - experimental parameters: r and c. The value for r is based off radius, + experimental parameters: r and c. The value for r is based off radius, while c is (usually) the Allred-Rochow electronegativity. Values used are *not* generated from pymatgen, and are found in 'okeeffe_params.json'. @@ -2753,7 +2755,7 @@ def get_type(self, index): raise ValueError("Index for getting order parameter type out-of-bounds!") return self._types[index] - def get_parameters(self, index): + def get_parameters(self, index: int) -> list[float]: """Get list of floats that represents the parameters associated with calculation of the order @@ -2762,12 +2764,10 @@ def get_parameters(self, index): inputted because of processing out of efficiency reasons. Args: - index (int): - index of order parameter for which associated parameters - are to be returned. + index (int): index of order parameter for which to return associated params. Returns: - [float]: parameters of a given OP. + list[float]: parameters of a given OP. """ if index < 0 or index >= len(self._types): raise ValueError("Index for getting parameters associated with order parameter calculation out-of-bounds!") diff --git a/src/pymatgen/analysis/magnetism/analyzer.py b/src/pymatgen/analysis/magnetism/analyzer.py index 59359e773fd..720a427a0d5 100644 --- a/src/pymatgen/analysis/magnetism/analyzer.py +++ b/src/pymatgen/analysis/magnetism/analyzer.py @@ -13,6 +13,10 @@ import numpy as np from monty.serialization import loadfn +from ruamel.yaml.error import MarkedYAMLError +from scipy.signal import argrelextrema +from scipy.stats import gaussian_kde + from pymatgen.core.structure import DummySpecies, Element, Species, Structure from pymatgen.electronic_structure.core import Magmom from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -20,9 +24,6 @@ from pymatgen.transformations.advanced_transformations import MagOrderingTransformation, MagOrderParameterConstraint from pymatgen.transformations.standard_transformations import AutoOxiStateDecorationTransformation from pymatgen.util.due import Doi, due -from ruamel.yaml.error import MarkedYAMLError -from scipy.signal import argrelextrema -from scipy.stats import gaussian_kde if TYPE_CHECKING: from typing import Any @@ -475,7 +476,7 @@ def ordering(self) -> Ordering: ferro/ferrimagnetic is self.threshold_ordering and defaults to 1e-8. Returns: - Ordering: Enum with values FM: ferromagnetic, FiM: ferrimagnetic, + Ordering: Enum with values FM: ferromagnetic, FiM: ferrimagnetic, AFM: antiferromagnetic, NM: non-magnetic or Unknown. Unknown is returned if magnetic moments are not defined or structure is not collinear (in which case a warning is issued). diff --git a/src/pymatgen/analysis/magnetism/heisenberg.py b/src/pymatgen/analysis/magnetism/heisenberg.py index ac627f903b1..35f4c12f116 100644 --- a/src/pymatgen/analysis/magnetism/heisenberg.py +++ b/src/pymatgen/analysis/magnetism/heisenberg.py @@ -15,6 +15,7 @@ import pandas as pd from monty.json import MSONable, jsanitize from monty.serialization import dumpfn + from pymatgen.analysis.graphs import StructureGraph from pymatgen.analysis.local_env import MinimumDistanceNN from pymatgen.analysis.magnetism import CollinearMagneticStructureAnalyzer, Ordering diff --git a/src/pymatgen/analysis/magnetism/jahnteller.py b/src/pymatgen/analysis/magnetism/jahnteller.py index ae53c08473b..61919cdaa19 100644 --- a/src/pymatgen/analysis/magnetism/jahnteller.py +++ b/src/pymatgen/analysis/magnetism/jahnteller.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Literal, cast import numpy as np + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.local_env import LocalStructOrderParams, get_neighbors_of_site_with_index from pymatgen.core import Species, get_el_sp @@ -271,7 +272,7 @@ def is_jahn_teller_active( quite distorted, this threshold is smaller than one might expect Returns: - boolean, True if might be Jahn-Teller active, False if not + bool: True if might be Jahn-Teller active, False if not """ active = False diff --git a/src/pymatgen/analysis/molecule_matcher.py b/src/pymatgen/analysis/molecule_matcher.py index 1b779e6898c..0474bfd88a4 100644 --- a/src/pymatgen/analysis/molecule_matcher.py +++ b/src/pymatgen/analysis/molecule_matcher.py @@ -22,12 +22,14 @@ import numpy as np from monty.dev import requires from monty.json import MSONable -from pymatgen.core.structure import Molecule from scipy.optimize import linear_sum_assignment from scipy.spatial.distance import cdist +from pymatgen.core.structure import Molecule + try: from openbabel import openbabel + from pymatgen.io.babel import BabelMolAdaptor except ImportError: openbabel = BabelMolAdaptor = None # type: ignore[misc] @@ -494,7 +496,7 @@ def _is_molecule_linear(self, mol): mol: The molecule. OpenBabel OBMol object. Returns: - Boolean value. + bool """ if mol.NumAtoms() < 3: return True @@ -587,7 +589,7 @@ def fit(self, mol1, mol2): mol2: Second molecule. OpenBabel OBMol or pymatgen Molecule object Returns: - A boolean value indicates whether two molecules are the same. + bool: True if two molecules are the same. """ return self.get_rmsd(mol1, mol2) < self._tolerance @@ -596,7 +598,7 @@ def get_rmsd(self, mol1, mol2): Returns: RMSD if topology of the two molecules are the same - Infinite if the topology is different + Infinite if the topology is different """ label1, label2 = self._mapper.uniform_labels(mol1, mol2) if label1 is None or label2 is None: diff --git a/src/pymatgen/analysis/molecule_structure_comparator.py b/src/pymatgen/analysis/molecule_structure_comparator.py index 28400b65a34..b52ed667042 100644 --- a/src/pymatgen/analysis/molecule_structure_comparator.py +++ b/src/pymatgen/analysis/molecule_structure_comparator.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.util.due import Doi, due if TYPE_CHECKING: @@ -193,7 +194,7 @@ def are_equal(self, mol1, mol2) -> bool: def get_13_bonds(priority_bonds): """ Args: - priority_bonds (): + priority_bonds (list[tuple]): 12 bonds Returns: tuple: 13 bonds diff --git a/src/pymatgen/analysis/nmr.py b/src/pymatgen/analysis/nmr.py index 91a7ccd21d4..288309513f8 100644 --- a/src/pymatgen/analysis/nmr.py +++ b/src/pymatgen/analysis/nmr.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np + from pymatgen.core import Site, Species from pymatgen.core.tensors import SquareTensor from pymatgen.core.units import FloatWithUnit @@ -121,9 +122,9 @@ def from_maryland_notation(cls, sigma_iso, omega, kappa) -> Self: Initialize from Maryland notation. Args: - sigma_iso (): - omega (): - kappa (): + sigma_iso (float): isotropic chemical shielding + omega (float): anisotropy + kappa (float): asymmetry parameter Returns: ChemicalShielding diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 6369e2369a0..c5632f82812 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -21,6 +21,11 @@ from matplotlib.colors import LinearSegmentedColormap, Normalize from matplotlib.font_manager import FontProperties from monty.json import MontyDecoder, MSONable +from scipy import interpolate +from scipy.optimize import minimize +from scipy.spatial import ConvexHull +from tqdm import tqdm + from pymatgen.analysis.reaction_calculator import Reaction, ReactionError from pymatgen.core import DummySpecies, Element, get_el_sp from pymatgen.core.composition import Composition @@ -29,10 +34,6 @@ from pymatgen.util.due import Doi, due from pymatgen.util.plotting import pretty_plot from pymatgen.util.string import htmlify, latexify -from scipy import interpolate -from scipy.optimize import minimize -from scipy.spatial import ConvexHull -from tqdm import tqdm if TYPE_CHECKING: from collections.abc import Collection, Iterator, Sequence @@ -95,9 +96,7 @@ def energy(self) -> float: def as_dict(self): """Get MSONable dict representation of PDEntry.""" - return_dict = super().as_dict() - return_dict.update({"name": self.name, "attribute": self.attribute}) - return return_dict + return super().as_dict() | {"name": self.name, "attribute": self.attribute} @classmethod def from_dict(cls, dct: dict) -> Self: @@ -854,7 +853,7 @@ def get_decomp_and_phase_separation_energy( **kwargs: Passed to get_decomp_and_e_above_hull. Returns: - tuple[decomp, energy]: The decomposition is given as a dict of {PDEntry, amount} + tuple[decomp, energy]: The decomposition is given as a dict of {PDEntry, amount} for all entries in the decomp reaction where amount is the amount of the fractional composition. The phase separation energy is given per atom. """ @@ -1548,7 +1547,7 @@ class PatchedPhaseDiagram(PhaseDiagram): Note that this does not mean that all these entries are actually used in the phase diagram. For example, this includes the positive formation energy entries that are filtered out before Phase Diagram construction. - min_entries (list[PDEntry]): List of the lowest energy entries for each composition + min_entries (list[PDEntry]): List of the lowest energy entries for each composition in the data provided for Phase Diagram construction. el_refs (list[PDEntry]): List of elemental references for the phase diagrams. These are entries corresponding to the lowest energy element entries for @@ -1572,8 +1571,8 @@ def __init__( the entries themselves and are sorted alphabetically. If specified, element ordering (e.g. for pd coordinates) is preserved. - keep_all_spaces (bool): Boolean control on whether to keep chemical spaces - that are subspaces of other spaces. + keep_all_spaces (bool): Pass True to keep chemical spaces that are subspaces + of other spaces. verbose (bool): Whether to show progress bar during convex hull construction. """ if elements is None: @@ -1902,7 +1901,7 @@ def __init__(self, entry1, entry2, all_entries, tol: float = 1e-4, float_fmt="%. """ elem_set = set() for entry in [entry1, entry2]: - elem_set.update([el.symbol for el in entry.elements]) + elem_set |= {el.symbol for el in entry.elements} elements = tuple(elem_set) # Fix elements to ensure order. @@ -2947,7 +2946,7 @@ def _create_plotly_element_annotations(self): for d in ["xref", "yref"]: annotation.pop(d) # Scatter3d cannot contain xref, yref if self._dim == 3: - annotation.update({"x": y, "y": x}) + annotation.update(x=x, y=y) if entry.composition.is_element: z = 0.9 * self._min_energy # place label 10% above base @@ -3095,26 +3094,24 @@ def get_marker_props(coords, entries): if highlight_entries: highlight_markers = plotly_layouts["default_unary_marker_settings"].copy() highlight_markers.update( - { - "x": [0] * len(highlight_props["y"]), - "y": list(highlight_props["x"]), - "name": "Highlighted", - "marker": { - "color": "mediumvioletred", - "size": 22, - "line": {"color": "black", "width": 2}, - "symbol": "square", - }, - "opacity": 0.9, - "hovertext": highlight_props["texts"], - "error_y": { - "array": list(highlight_props["uncertainties"]), - "type": "data", - "color": "gray", - "thickness": 2.5, - "width": 5, - }, - } + x=[0] * len(highlight_props["y"]), + y=list(highlight_props["x"]), + name="Highlighted", + marker={ + "color": "mediumvioletred", + "size": 22, + "line": {"color": "black", "width": 2}, + "symbol": "square", + }, + opacity=0.9, + hovertext=highlight_props["texts"], + error_y={ + "array": list(highlight_props["uncertainties"]), + "type": "data", + "color": "gray", + "thickness": 2.5, + "width": 5, + }, ) if self._dim == 2: @@ -3136,22 +3133,20 @@ def get_marker_props(coords, entries): "width": 5, }, ) - unstable_markers.update( - { - "x": list(unstable_props["x"]), - "y": list(unstable_props["y"]), - "name": "Above Hull", - "marker": { - "color": unstable_props["energies"], - "colorscale": plotly_layouts["unstable_colorscale"], - "size": 7, - "symbol": "diamond", - "line": {"color": "black", "width": 1}, - "opacity": 0.8, - }, - "hovertext": unstable_props["texts"], - } - ) + unstable_markers |= { + "x": list(unstable_props["x"]), + "y": list(unstable_props["y"]), + "name": "Above Hull", + "marker": { + "color": unstable_props["energies"], + "colorscale": plotly_layouts["unstable_colorscale"], + "size": 7, + "symbol": "diamond", + "line": {"color": "black", "width": 1}, + "opacity": 0.8, + }, + "hovertext": unstable_props["texts"], + } if highlight_entries: highlight_markers = plotly_layouts["default_binary_marker_settings"].copy() highlight_markers.update( @@ -3179,209 +3174,191 @@ def get_marker_props(coords, entries): stable_markers = plotly_layouts["default_ternary_2d_marker_settings"].copy() unstable_markers = plotly_layouts["default_ternary_2d_marker_settings"].copy() - stable_markers.update( - { - "a": list(stable_props["x"]), - "b": list(stable_props["y"]), - "c": list(stable_props["z"]), - "name": "Stable", - "hovertext": stable_props["texts"], - "marker": { - "color": "green", - "line": {"width": 2.0, "color": "black"}, - "symbol": "circle", - "size": 15, + stable_markers |= { + "a": list(stable_props["x"]), + "b": list(stable_props["y"]), + "c": list(stable_props["z"]), + "name": "Stable", + "hovertext": stable_props["texts"], + "marker": { + "color": "green", + "line": {"width": 2.0, "color": "black"}, + "symbol": "circle", + "size": 15, + }, + } + unstable_markers |= { + "a": unstable_props["x"], + "b": unstable_props["y"], + "c": unstable_props["z"], + "name": "Above Hull", + "hovertext": unstable_props["texts"], + "marker": { + "color": unstable_props["energies"], + "opacity": 0.8, + "colorscale": plotly_layouts["unstable_colorscale"], + "line": {"width": 1, "color": "black"}, + "size": 7, + "symbol": "diamond", + "colorbar": { + "title": "Energy Above Hull
(eV/atom)", + "x": 0, + "y": 1, + "yanchor": "top", + "xpad": 0, + "ypad": 0, + "thickness": 0.02, + "thicknessmode": "fraction", + "len": 0.5, }, - } - ) - unstable_markers.update( - { - "a": unstable_props["x"], - "b": unstable_props["y"], - "c": unstable_props["z"], - "name": "Above Hull", - "hovertext": unstable_props["texts"], + }, + } + if highlight_entries: + highlight_markers = plotly_layouts["default_ternary_2d_marker_settings"].copy() + highlight_markers |= { + "a": list(highlight_props["x"]), + "b": list(highlight_props["y"]), + "c": list(highlight_props["z"]), + "name": "Highlighted", + "hovertext": highlight_props["texts"], "marker": { - "color": unstable_props["energies"], - "opacity": 0.8, - "colorscale": plotly_layouts["unstable_colorscale"], - "line": {"width": 1, "color": "black"}, - "size": 7, - "symbol": "diamond", - "colorbar": { - "title": "Energy Above Hull
(eV/atom)", - "x": 0, - "y": 1, - "yanchor": "top", - "xpad": 0, - "ypad": 0, - "thickness": 0.02, - "thicknessmode": "fraction", - "len": 0.5, - }, + "color": "mediumvioletred", + "line": {"width": 2.0, "color": "black"}, + "symbol": "square", + "size": 16, }, } - ) - if highlight_entries: - highlight_markers = plotly_layouts["default_ternary_2d_marker_settings"].copy() - highlight_markers.update( - { - "a": list(highlight_props["x"]), - "b": list(highlight_props["y"]), - "c": list(highlight_props["z"]), - "name": "Highlighted", - "hovertext": highlight_props["texts"], - "marker": { - "color": "mediumvioletred", - "line": {"width": 2.0, "color": "black"}, - "symbol": "square", - "size": 16, - }, - } - ) elif self._dim == 3 and self.ternary_style == "3d": stable_markers = plotly_layouts["default_ternary_3d_marker_settings"].copy() unstable_markers = plotly_layouts["default_ternary_3d_marker_settings"].copy() - stable_markers.update( - { - "x": list(stable_props["y"]), - "y": list(stable_props["x"]), - "z": list(stable_props["z"]), - "name": "Stable", + stable_markers |= { + "x": list(stable_props["y"]), + "y": list(stable_props["x"]), + "z": list(stable_props["z"]), + "name": "Stable", + "marker": { + "color": "#1e1e1f", + "size": 11, + "opacity": 0.99, + }, + "hovertext": stable_props["texts"], + "error_z": { + "array": list(stable_props["uncertainties"]), + "type": "data", + "color": "darkgray", + "width": 10, + "thickness": 5, + }, + } + unstable_markers |= { + "x": unstable_props["y"], + "y": unstable_props["x"], + "z": unstable_props["z"], + "name": "Above Hull", + "hovertext": unstable_props["texts"], + "marker": { + "color": unstable_props["energies"], + "colorscale": plotly_layouts["unstable_colorscale"], + "size": 5, + "line": {"color": "black", "width": 1}, + "symbol": "diamond", + "opacity": 0.7, + "colorbar": { + "title": "Energy Above Hull
(eV/atom)", + "x": 0, + "y": 1, + "yanchor": "top", + "xpad": 0, + "ypad": 0, + "thickness": 0.02, + "thicknessmode": "fraction", + "len": 0.5, + }, + }, + } + if highlight_entries: + highlight_markers = plotly_layouts["default_ternary_3d_marker_settings"].copy() + highlight_markers |= { + "x": list(highlight_props["y"]), + "y": list(highlight_props["x"]), + "z": list(highlight_props["z"]), + "name": "Highlighted", "marker": { - "color": "#1e1e1f", - "size": 11, + "size": 12, "opacity": 0.99, + "symbol": "square", + "color": "mediumvioletred", }, - "hovertext": stable_props["texts"], + "hovertext": highlight_props["texts"], "error_z": { - "array": list(stable_props["uncertainties"]), + "array": list(highlight_props["uncertainties"]), "type": "data", "color": "darkgray", "width": 10, "thickness": 5, }, } - ) - unstable_markers.update( - { - "x": unstable_props["y"], - "y": unstable_props["x"], - "z": unstable_props["z"], - "name": "Above Hull", - "hovertext": unstable_props["texts"], - "marker": { - "color": unstable_props["energies"], - "colorscale": plotly_layouts["unstable_colorscale"], - "size": 5, - "line": {"color": "black", "width": 1}, - "symbol": "diamond", - "opacity": 0.7, - "colorbar": { - "title": "Energy Above Hull
(eV/atom)", - "x": 0, - "y": 1, - "yanchor": "top", - "xpad": 0, - "ypad": 0, - "thickness": 0.02, - "thicknessmode": "fraction", - "len": 0.5, - }, - }, - } - ) - if highlight_entries: - highlight_markers = plotly_layouts["default_ternary_3d_marker_settings"].copy() - highlight_markers.update( - { - "x": list(highlight_props["y"]), - "y": list(highlight_props["x"]), - "z": list(highlight_props["z"]), - "name": "Highlighted", - "marker": { - "size": 12, - "opacity": 0.99, - "symbol": "square", - "color": "mediumvioletred", - }, - "hovertext": highlight_props["texts"], - "error_z": { - "array": list(highlight_props["uncertainties"]), - "type": "data", - "color": "darkgray", - "width": 10, - "thickness": 5, - }, - } - ) elif self._dim == 4: stable_markers = plotly_layouts["default_quaternary_marker_settings"].copy() unstable_markers = plotly_layouts["default_quaternary_marker_settings"].copy() - stable_markers.update( - { - "x": stable_props["x"], - "y": stable_props["y"], - "z": stable_props["z"], - "name": "Stable", - "marker": { - "size": 7, - "opacity": 0.99, - "color": "darkgreen", - "line": {"color": "black", "width": 1}, + stable_markers |= { + "x": stable_props["x"], + "y": stable_props["y"], + "z": stable_props["z"], + "name": "Stable", + "marker": { + "size": 7, + "opacity": 0.99, + "color": "darkgreen", + "line": {"color": "black", "width": 1}, + }, + "hovertext": stable_props["texts"], + } + unstable_markers |= { + "x": unstable_props["x"], + "y": unstable_props["y"], + "z": unstable_props["z"], + "name": "Above Hull", + "marker": { + "color": unstable_props["energies"], + "colorscale": plotly_layouts["unstable_colorscale"], + "size": 5, + "symbol": "diamond", + "line": {"color": "black", "width": 1}, + "colorbar": { + "title": "Energy Above Hull
(eV/atom)", + "x": 0, + "y": 1, + "yanchor": "top", + "xpad": 0, + "ypad": 0, + "thickness": 0.02, + "thicknessmode": "fraction", + "len": 0.5, }, - "hovertext": stable_props["texts"], - } - ) - unstable_markers.update( - { - "x": unstable_props["x"], - "y": unstable_props["y"], - "z": unstable_props["z"], - "name": "Above Hull", + }, + "hovertext": unstable_props["texts"], + "visible": "legendonly", + } + if highlight_entries: + highlight_markers = plotly_layouts["default_quaternary_marker_settings"].copy() + highlight_markers |= { + "x": highlight_props["x"], + "y": highlight_props["y"], + "z": highlight_props["z"], + "name": "Highlighted", "marker": { - "color": unstable_props["energies"], - "colorscale": plotly_layouts["unstable_colorscale"], - "size": 5, - "symbol": "diamond", + "size": 9, + "opacity": 0.99, + "symbol": "square", + "color": "mediumvioletred", "line": {"color": "black", "width": 1}, - "colorbar": { - "title": "Energy Above Hull
(eV/atom)", - "x": 0, - "y": 1, - "yanchor": "top", - "xpad": 0, - "ypad": 0, - "thickness": 0.02, - "thicknessmode": "fraction", - "len": 0.5, - }, }, - "hovertext": unstable_props["texts"], - "visible": "legendonly", + "hovertext": highlight_props["texts"], } - ) - if highlight_entries: - highlight_markers = plotly_layouts["default_quaternary_marker_settings"].copy() - highlight_markers.update( - { - "x": highlight_props["x"], - "y": highlight_props["y"], - "z": highlight_props["z"], - "name": "Highlighted", - "marker": { - "size": 9, - "opacity": 0.99, - "symbol": "square", - "color": "mediumvioletred", - "line": {"color": "black", "width": 1}, - }, - "hovertext": highlight_props["texts"], - } - ) highlight_marker_plot = None diff --git a/src/pymatgen/analysis/piezo.py b/src/pymatgen/analysis/piezo.py index 7a5b32ac8d0..2f168e1eef4 100644 --- a/src/pymatgen/analysis/piezo.py +++ b/src/pymatgen/analysis/piezo.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.core.tensors import Tensor if TYPE_CHECKING: diff --git a/src/pymatgen/analysis/piezo_sensitivity.py b/src/pymatgen/analysis/piezo_sensitivity.py index d5b4f232199..0a93140dc15 100644 --- a/src/pymatgen/analysis/piezo_sensitivity.py +++ b/src/pymatgen/analysis/piezo_sensitivity.py @@ -7,6 +7,7 @@ import numpy as np from monty.dev import requires + from pymatgen.core.tensors import Tensor from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/analysis/pourbaix_diagram.py b/src/pymatgen/analysis/pourbaix_diagram.py index dad540db5fb..135f2abbabf 100644 --- a/src/pymatgen/analysis/pourbaix_diagram.py +++ b/src/pymatgen/analysis/pourbaix_diagram.py @@ -16,6 +16,9 @@ import numpy as np from monty.json import MontyDecoder, MSONable +from scipy.spatial import ConvexHull, HalfspaceIntersection +from scipy.special import comb + from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram from pymatgen.analysis.reaction_calculator import Reaction, ReactionError from pymatgen.core import Composition, Element @@ -26,8 +29,6 @@ from pymatgen.util.due import Doi, due from pymatgen.util.plotting import pretty_plot from pymatgen.util.string import Stringify -from scipy.spatial import ConvexHull, HalfspaceIntersection -from scipy.special import comb if TYPE_CHECKING: from typing import Any @@ -84,10 +85,9 @@ class PourbaixEntry(MSONable, Stringify): def __init__(self, entry, entry_id=None, concentration=1e-6): """ Args: - entry (ComputedEntry/ComputedStructureEntry/PDEntry/IonEntry): An - entry object - entry_id (): - concentration (): + entry (ComputedEntry | ComputedStructureEntry | PDEntry | IonEntry): An entry object + entry_id (str): A string id for the entry + concentration (float): Concentration of the entry in M. Defaults to 1e-6. """ self.entry = entry if isinstance(entry, IonEntry): @@ -101,7 +101,7 @@ def __init__(self, entry, entry_id=None, concentration=1e-6): self.uncorrected_energy = entry.energy if entry_id is not None: self.entry_id = entry_id - elif hasattr(entry, "entry_id") and entry.entry_id: + elif getattr(entry, "entry_id", None): self.entry_id = entry.entry_id else: self.entry_id = None @@ -794,8 +794,8 @@ def get_decomposition_energy(self, entry, pH, V): Args: entry (PourbaixEntry): PourbaixEntry corresponding to compound to find the decomposition for - pH (float, [float]): pH at which to find the decomposition - V (float, [float]): voltage at which to find the decomposition + pH (float, list[float]): pH at which to find the decomposition + V (float, list[float]): voltage at which to find the decomposition Returns: Decomposition energy for the entry, i. e. the energy above @@ -817,13 +817,13 @@ def get_decomposition_energy(self, entry, pH, V): decomposition_energy /= entry.composition.num_atoms return decomposition_energy - def get_hull_energy(self, pH, V): + def get_hull_energy(self, pH: float | list[float], V: float | list[float]) -> np.ndarray: """Get the minimum energy of the Pourbaix "basin" that is formed from the stable Pourbaix planes. Vectorized. Args: - pH (float or [float]): pH at which to find the hull energy - V (float or [float]): V at which to find the hull energy + pH (float | list[float]): pH at which to find the hull energy + V (float | list[float]): V at which to find the hull energy Returns: np.array: minimum Pourbaix energy at conditions diff --git a/src/pymatgen/analysis/prototypes.py b/src/pymatgen/analysis/prototypes.py index 6423e53170c..a9cdf29c70c 100644 --- a/src/pymatgen/analysis/prototypes.py +++ b/src/pymatgen/analysis/prototypes.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING from monty.serialization import loadfn + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.util.due import Doi, due diff --git a/src/pymatgen/analysis/quasiharmonic.py b/src/pymatgen/analysis/quasiharmonic.py index f7b8aa08a05..e41e072acad 100644 --- a/src/pymatgen/analysis/quasiharmonic.py +++ b/src/pymatgen/analysis/quasiharmonic.py @@ -15,14 +15,15 @@ import numpy as np from monty.dev import deprecated -from pymatgen.analysis.eos import EOS, PolynomialEOS -from pymatgen.core.units import FloatWithUnit -from pymatgen.util.due import Doi, due from scipy.constants import physical_constants from scipy.integrate import quadrature from scipy.misc import derivative from scipy.optimize import minimize +from pymatgen.analysis.eos import EOS, PolynomialEOS +from pymatgen.core.units import FloatWithUnit +from pymatgen.util.due import Doi, due + __author__ = "Kiran Mathew, Brandon Bocklund" __credits__ = "Cormac Toher" @@ -145,7 +146,7 @@ def optimize_gibbs_free_energy(self): def optimizer(self, temperature): """Evaluate G(V, T, P) at the given temperature(and pressure) and minimize it w.r.t. V. - 1. Compute the vibrational Helmholtz free energy, A_vib. + 1. Compute the vibrational Helmholtz free energy, A_vib. 2. Compute the Gibbs free energy as a function of volume, temperature and pressure, G(V,T,P). 3. Perform an equation of state fit to get the functional form of @@ -247,7 +248,7 @@ def debye_temperature(self, volume: float) -> float: @staticmethod def debye_integral(y): """ - Debye integral. Eq(5) in doi.org/10.1016/j.comphy.2003.12.001. + Debye integral. Eq(5) in doi.org/10.1016/j.comphy.2003.12.001. Args: y (float): Debye temperature / T, upper limit diff --git a/src/pymatgen/analysis/quasirrho.py b/src/pymatgen/analysis/quasirrho.py index 5846f19339b..0c66f730db3 100644 --- a/src/pymatgen/analysis/quasirrho.py +++ b/src/pymatgen/analysis/quasirrho.py @@ -15,14 +15,16 @@ import numpy as np import scipy.constants as const + from pymatgen.core.units import kb as kb_ev from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Molecule from pymatgen.io.gaussian import GaussianOutput from pymatgen.io.qchem.outputs import QCOutput - from typing_extensions import Self __author__ = "Alex Epstein" __copyright__ = "Copyright 2020, The Materials Project" diff --git a/src/pymatgen/analysis/reaction_calculator.py b/src/pymatgen/analysis/reaction_calculator.py index 638e9189318..26f5766dd36 100644 --- a/src/pymatgen/analysis/reaction_calculator.py +++ b/src/pymatgen/analysis/reaction_calculator.py @@ -10,16 +10,18 @@ import numpy as np from monty.fractions import gcd_float from monty.json import MontyDecoder, MSONable +from uncertainties import ufloat + from pymatgen.core.composition import Composition from pymatgen.entries.computed_entries import ComputedEntry -from uncertainties import ufloat if TYPE_CHECKING: from collections.abc import Mapping + from typing_extensions import Self + from pymatgen.core import Element, Species from pymatgen.util.typing import CompositionLike - from typing_extensions import Self __author__ = "Shyue Ping Ong, Anubhav Jain" __copyright__ = "Copyright 2011, The Materials Project" @@ -137,7 +139,7 @@ def normalize_to_element(self, element: Species | Element, factor: float = 1) -> Another factor can be specified. Args: - element (Element/Species): Element to normalize to. + element (SpeciesLike): Element to normalize to. factor (float): Factor to normalize to. Defaults to 1. """ all_comp = self._all_comp @@ -150,7 +152,7 @@ def get_el_amount(self, element: Element | Species) -> float: """Get the amount of the element in the reaction. Args: - element (Element/Species): Element in the reaction + element (SpeciesLike): Element in the reaction Returns: Amount of that element in the reaction. @@ -260,7 +262,7 @@ def from_dict(cls, dct: dict) -> Self: dct (dict): from as_dict(). Returns: - A BalancedReaction object. + BalancedReaction """ reactants = {Composition(comp): coeff for comp, coeff in dct["reactants"].items()} products = {Composition(comp): coeff for comp, coeff in dct["products"].items()} diff --git a/src/pymatgen/analysis/solar/slme.py b/src/pymatgen/analysis/solar/slme.py index 976ef49b701..cc7fdf2dabf 100644 --- a/src/pymatgen/analysis/solar/slme.py +++ b/src/pymatgen/analysis/solar/slme.py @@ -22,9 +22,10 @@ from scipy.integrate import simpson except ImportError: from scipy.integrate import simps as simpson +from scipy.interpolate import interp1d + from pymatgen.io.vasp.outputs import Vasprun from pymatgen.util.due import Doi, due -from scipy.interpolate import interp1d due.cite( Doi("10.1021/acs.chemmater.9b02166"), diff --git a/src/pymatgen/analysis/structure_analyzer.py b/src/pymatgen/analysis/structure_analyzer.py index a044216444a..b3d33dfaef5 100644 --- a/src/pymatgen/analysis/structure_analyzer.py +++ b/src/pymatgen/analysis/structure_analyzer.py @@ -10,10 +10,11 @@ import matplotlib.pyplot as plt import numpy as np +from scipy.spatial import Voronoi + from pymatgen.analysis.local_env import JmolNN, VoronoiNN from pymatgen.core import Composition, Element, PeriodicSite, Species from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from scipy.spatial import Voronoi if TYPE_CHECKING: from pymatgen.core import Structure @@ -394,7 +395,7 @@ def contains_peroxide(structure, relative_cutoff=1.1): atoms must be to each other to be considered a peroxide. Returns: - Boolean indicating if structure contains a peroxide anion. + bool: True if structure contains a peroxide anion. """ return oxide_type(structure, relative_cutoff) == "peroxide" diff --git a/src/pymatgen/analysis/structure_matcher.py b/src/pymatgen/analysis/structure_matcher.py index 7a2650bef71..cf6e890c9e7 100644 --- a/src/pymatgen/analysis/structure_matcher.py +++ b/src/pymatgen/analysis/structure_matcher.py @@ -8,6 +8,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.core import Composition, Lattice, Structure, get_el_sp from pymatgen.optimization.linear_assignment import LinearAssignment from pymatgen.util.coord import lattice_points_in_supercell @@ -17,9 +18,10 @@ from collections.abc import Mapping, Sequence from typing import Literal - from pymatgen.util.typing import SpeciesLike from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike + __author__ = "William Davidson Richards, Stephen Dacek, Shyue Ping Ong" __copyright__ = "Copyright 2011, The Materials Project" __version__ = "1.0" @@ -51,7 +53,7 @@ def are_equal(self, sp1, sp2) -> bool: definition in Site and PeriodicSite. Returns: - Boolean indicating whether species are considered equal. + bool: True if species are considered equal. """ return False @@ -120,7 +122,7 @@ def are_equal(self, sp1, sp2) -> bool: definition in Site and PeriodicSite. Returns: - Boolean indicating whether species are equal. + bool: True if species are equal. """ return sp1 == sp2 @@ -149,7 +151,7 @@ def are_equal(self, sp1, sp2) -> bool: definition in Site and PeriodicSite. Returns: - Boolean indicating whether species are equal. + bool: True if species are equal. """ for s1 in sp1: spin1 = getattr(s1, "spin", 0) or 0 @@ -186,8 +188,7 @@ def are_equal(self, sp1, sp2) -> bool: definition in Site and PeriodicSite. Returns: - Boolean indicating whether species are the same based on element - and amounts. + bool: True if species are the same based on element and amounts. """ comp1 = Composition(sp1) comp2 = Composition(sp2) @@ -1050,7 +1051,7 @@ def fit_anonymous( If True, skip to get a primitive structure and perform Niggli reduction for struct1 and struct2 Returns: - bool: Whether a species mapping can map struct1 to struct2 + bool: True if a species mapping can map struct1 to struct2 """ struct1, struct2 = self._process_species([struct1, struct2]) struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli, skip_structure_reduction) diff --git a/src/pymatgen/analysis/structure_prediction/dopant_predictor.py b/src/pymatgen/analysis/structure_prediction/dopant_predictor.py index dcc750ec9f3..e21603e13ee 100644 --- a/src/pymatgen/analysis/structure_prediction/dopant_predictor.py +++ b/src/pymatgen/analysis/structure_prediction/dopant_predictor.py @@ -5,6 +5,7 @@ import warnings import numpy as np + from pymatgen.analysis.structure_prediction.substitution_probability import SubstitutionPredictor from pymatgen.core import Element, Species diff --git a/src/pymatgen/analysis/structure_prediction/substitution_probability.py b/src/pymatgen/analysis/structure_prediction/substitution_probability.py index 75720ce9419..01a6d32cb2e 100644 --- a/src/pymatgen/analysis/structure_prediction/substitution_probability.py +++ b/src/pymatgen/analysis/structure_prediction/substitution_probability.py @@ -13,12 +13,15 @@ from typing import TYPE_CHECKING from monty.design_patterns import cached_class + from pymatgen.core import Species, get_el_sp from pymatgen.util.due import Doi, due if TYPE_CHECKING: from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike + __author__ = "Will Richards, Geoffroy Hautier" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.2" @@ -83,8 +86,8 @@ def __init__(self, lambda_table=None, alpha=-5): def get_lambda(self, s1, s2): """ Args: - s1 (Element/Species/str/int): Describes Ion in 1st Structure - s2 (Element/Species/str/int): Describes Ion in 2nd Structure. + s1 (SpeciesLike): Ion in 1st structure. + s2 (SpeciesLike): Ion in 2nd structure. Returns: Lambda values @@ -92,13 +95,13 @@ def get_lambda(self, s1, s2): key = frozenset([get_el_sp(s1), get_el_sp(s2)]) return self._l.get(key, self.alpha) - def get_px(self, sp): + def get_px(self, sp: SpeciesLike) -> float: """ Args: - sp (Species/Element): Species. + sp (SpeciesLike): Species. Returns: - Probability + float: Probability """ return self._px[get_el_sp(sp)] @@ -184,7 +187,7 @@ class SubstitutionPredictor: def __init__(self, lambda_table=None, alpha=-5, threshold=1e-3): """ Args: - lambda_table (): Input lambda table. + lambda_table (dict): Input lambda table. alpha (float): weight function for never observed substitutions threshold (float): Threshold to use to identify high probability structures. """ diff --git a/src/pymatgen/analysis/structure_prediction/substitutor.py b/src/pymatgen/analysis/structure_prediction/substitutor.py index c36df98f030..2921355d637 100644 --- a/src/pymatgen/analysis/structure_prediction/substitutor.py +++ b/src/pymatgen/analysis/structure_prediction/substitutor.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.alchemy.filters import RemoveDuplicatesFilter, RemoveExistingFilter from pymatgen.alchemy.materials import TransformedStructure from pymatgen.alchemy.transmuters import StandardTransmuter diff --git a/src/pymatgen/analysis/structure_prediction/volume_predictor.py b/src/pymatgen/analysis/structure_prediction/volume_predictor.py index 5ce74b652f4..dd749e8416a 100644 --- a/src/pymatgen/analysis/structure_prediction/volume_predictor.py +++ b/src/pymatgen/analysis/structure_prediction/volume_predictor.py @@ -7,6 +7,7 @@ import numpy as np from monty.serialization import loadfn + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Structure diff --git a/src/pymatgen/analysis/surface_analysis.py b/src/pymatgen/analysis/surface_analysis.py index f395f90afd0..ff49d82088f 100644 --- a/src/pymatgen/analysis/surface_analysis.py +++ b/src/pymatgen/analysis/surface_analysis.py @@ -42,6 +42,9 @@ import matplotlib.pyplot as plt import numpy as np +from sympy import Symbol +from sympy.solvers import linsolve, solve + from pymatgen.analysis.wulff import WulffShape from pymatgen.core import Structure from pymatgen.core.composition import Composition @@ -51,13 +54,12 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.due import Doi, due from pymatgen.util.plotting import pretty_plot -from sympy import Symbol -from sympy.solvers import linsolve, solve if TYPE_CHECKING: - from pymatgen.util.typing import Tuple3Ints from typing_extensions import Self + from pymatgen.util.typing import Tuple3Ints + EV_PER_ANG2_TO_JOULES_PER_M2 = 16.0217656 __author__ = "Richard Tran" diff --git a/src/pymatgen/analysis/topological/spillage.py b/src/pymatgen/analysis/topological/spillage.py index e088b98553b..d0da47ccfe1 100644 --- a/src/pymatgen/analysis/topological/spillage.py +++ b/src/pymatgen/analysis/topological/spillage.py @@ -8,6 +8,7 @@ from __future__ import annotations import numpy as np + from pymatgen.io.vasp.outputs import Wavecar diff --git a/src/pymatgen/analysis/transition_state.py b/src/pymatgen/analysis/transition_state.py index e7c6d9b2924..5dbbe1fbb54 100644 --- a/src/pymatgen/analysis/transition_state.py +++ b/src/pymatgen/analysis/transition_state.py @@ -15,11 +15,12 @@ import matplotlib.pyplot as plt import numpy as np from monty.json import MSONable, jsanitize +from scipy.interpolate import CubicSpline + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Structure from pymatgen.io.vasp import Outcar from pymatgen.util.plotting import pretty_plot -from scipy.interpolate import CubicSpline if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/pymatgen/analysis/wulff.py b/src/pymatgen/analysis/wulff.py index 23a6986d87f..420d77c4161 100644 --- a/src/pymatgen/analysis/wulff.py +++ b/src/pymatgen/analysis/wulff.py @@ -25,10 +25,11 @@ import matplotlib.pyplot as plt import numpy as np import plotly.graph_objects as go +from scipy.spatial import ConvexHull + from pymatgen.core.structure import Structure from pymatgen.util.coord import get_angle from pymatgen.util.string import unicodeify_spacegroup -from scipy.spatial import ConvexHull if TYPE_CHECKING: from pymatgen.core.lattice import Lattice diff --git a/src/pymatgen/analysis/xas/spectrum.py b/src/pymatgen/analysis/xas/spectrum.py index 97af8b85282..007c4596642 100644 --- a/src/pymatgen/analysis/xas/spectrum.py +++ b/src/pymatgen/analysis/xas/spectrum.py @@ -7,10 +7,11 @@ from typing import TYPE_CHECKING import numpy as np +from scipy.interpolate import interp1d + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core.spectrum import Spectrum from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from scipy.interpolate import interp1d if TYPE_CHECKING: from typing import Literal diff --git a/src/pymatgen/analysis/xps.py b/src/pymatgen/analysis/xps.py index 4e9acff9540..1aca37a1eee 100644 --- a/src/pymatgen/analysis/xps.py +++ b/src/pymatgen/analysis/xps.py @@ -25,14 +25,16 @@ import numpy as np import pandas as pd + from pymatgen.core import Element from pymatgen.core.spectrum import Spectrum from pymatgen.util.due import Doi, due if TYPE_CHECKING: - from pymatgen.electronic_structure.dos import CompleteDos from typing_extensions import Self + from pymatgen.electronic_structure.dos import CompleteDos + due.cite( Doi("10.21105/joss.007733"), diff --git a/src/pymatgen/apps/battery/analyzer.py b/src/pymatgen/apps/battery/analyzer.py index 3d08e147c35..2543b44da92 100644 --- a/src/pymatgen/apps/battery/analyzer.py +++ b/src/pymatgen/apps/battery/analyzer.py @@ -6,6 +6,7 @@ from collections import defaultdict import scipy.constants as const + from pymatgen.core import Composition, Element, Species __author__ = "Anubhav Jain" diff --git a/src/pymatgen/apps/battery/battery_abc.py b/src/pymatgen/apps/battery/battery_abc.py index 8465f85445f..78f4d984498 100644 --- a/src/pymatgen/apps/battery/battery_abc.py +++ b/src/pymatgen/apps/battery/battery_abc.py @@ -12,9 +12,10 @@ from typing import TYPE_CHECKING from monty.json import MSONable -from pymatgen.core import Composition, Element from scipy.constants import N_A +from pymatgen.core import Composition, Element + if TYPE_CHECKING: from pymatgen.entries.computed_entries import ComputedEntry diff --git a/src/pymatgen/apps/battery/conversion_battery.py b/src/pymatgen/apps/battery/conversion_battery.py index ee274bcd40d..b84da63efd6 100644 --- a/src/pymatgen/apps/battery/conversion_battery.py +++ b/src/pymatgen/apps/battery/conversion_battery.py @@ -5,19 +5,21 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from scipy.constants import N_A + from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.reaction_calculator import BalancedReaction from pymatgen.apps.battery.battery_abc import AbstractElectrode, AbstractVoltagePair from pymatgen.core import Composition, Element from pymatgen.core.units import Charge, Time -from scipy.constants import N_A if TYPE_CHECKING: from collections.abc import Iterable - from pymatgen.entries.computed_entries import ComputedEntry from typing_extensions import Self + from pymatgen.entries.computed_entries import ComputedEntry + @dataclass class ConversionElectrode(AbstractElectrode): diff --git a/src/pymatgen/apps/battery/insertion_battery.py b/src/pymatgen/apps/battery/insertion_battery.py index 5f073fbc916..09e53128606 100644 --- a/src/pymatgen/apps/battery/insertion_battery.py +++ b/src/pymatgen/apps/battery/insertion_battery.py @@ -9,12 +9,13 @@ from typing import TYPE_CHECKING from monty.json import MontyDecoder +from scipy.constants import N_A + from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram from pymatgen.apps.battery.battery_abc import AbstractElectrode, AbstractVoltagePair from pymatgen.core import Composition, Element from pymatgen.core.units import Charge, Time from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry -from scipy.constants import N_A if TYPE_CHECKING: from collections.abc import Iterable @@ -319,19 +320,17 @@ def get_summary_dict(self, print_subelectrodes=True) -> dict: chg_comp = self.fully_charged_entry.composition dischg_comp = self.fully_discharged_entry.composition - dct.update( - { - "id_charge": self.fully_charged_entry.entry_id, - "formula_charge": chg_comp.reduced_formula, - "id_discharge": self.fully_discharged_entry.entry_id, - "formula_discharge": dischg_comp.reduced_formula, - "max_instability": self.get_max_instability(), - "min_instability": self.get_min_instability(), - "material_ids": [itr_ent.entry_id for itr_ent in self.get_all_entries()], - "stable_material_ids": [itr_ent.entry_id for itr_ent in self.get_stable_entries()], - "unstable_material_ids": [itr_ent.entry_id for itr_ent in self.get_unstable_entries()], - } - ) + dct |= { + "id_charge": self.fully_charged_entry.entry_id, + "formula_charge": chg_comp.reduced_formula, + "id_discharge": self.fully_discharged_entry.entry_id, + "formula_discharge": dischg_comp.reduced_formula, + "max_instability": self.get_max_instability(), + "min_instability": self.get_min_instability(), + "material_ids": [itr_ent.entry_id for itr_ent in self.get_all_entries()], + "stable_material_ids": [itr_ent.entry_id for itr_ent in self.get_stable_entries()], + "unstable_material_ids": [itr_ent.entry_id for itr_ent in self.get_unstable_entries()], + } if all("decomposition_energy" in itr_ent.data for itr_ent in self.get_all_entries()): dct.update( stability_charge=self.fully_charged_entry.data["decomposition_energy"], @@ -342,7 +341,7 @@ def get_summary_dict(self, print_subelectrodes=True) -> dict: ) if all("muO2" in itr_ent.data for itr_ent in self.get_all_entries()): - dct.update({"muO2_data": {itr_ent.entry_id: itr_ent.data["muO2"] for itr_ent in self.get_all_entries()}}) + dct |= {"muO2_data": {itr_ent.entry_id: itr_ent.data["muO2"] for itr_ent in self.get_all_entries()}} return dct diff --git a/src/pymatgen/apps/battery/plotter.py b/src/pymatgen/apps/battery/plotter.py index 150f413784b..f8725ef25e6 100644 --- a/src/pymatgen/apps/battery/plotter.py +++ b/src/pymatgen/apps/battery/plotter.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import plotly.graph_objects as go + from pymatgen.util.plotting import pretty_plot if TYPE_CHECKING: diff --git a/src/pymatgen/apps/borg/hive.py b/src/pymatgen/apps/borg/hive.py index 87aaeeb2500..a90380d0dee 100644 --- a/src/pymatgen/apps/borg/hive.py +++ b/src/pymatgen/apps/borg/hive.py @@ -12,6 +12,7 @@ from monty.io import zopen from monty.json import MSONable + from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry from pymatgen.io.gaussian import GaussianOutput from pymatgen.io.vasp.inputs import Incar, Poscar, Potcar diff --git a/src/pymatgen/cli/feff_plot_cross_section.py b/src/pymatgen/cli/feff_plot_cross_section.py index 2bf9d89fe23..a9a59f3a650 100755 --- a/src/pymatgen/cli/feff_plot_cross_section.py +++ b/src/pymatgen/cli/feff_plot_cross_section.py @@ -7,6 +7,7 @@ import argparse import matplotlib.pyplot as plt + from pymatgen.io.feff.outputs import Xmu from pymatgen.util.plotting import pretty_plot diff --git a/src/pymatgen/cli/pmg.py b/src/pymatgen/cli/pmg.py index 1ee6a93ed7e..491d4ab2c11 100755 --- a/src/pymatgen/cli/pmg.py +++ b/src/pymatgen/cli/pmg.py @@ -7,6 +7,8 @@ import argparse import itertools +from tabulate import tabulate, tabulate_formats + from pymatgen.cli.pmg_analyze import analyze from pymatgen.cli.pmg_config import configure_pmg from pymatgen.cli.pmg_plot import plot @@ -15,7 +17,6 @@ from pymatgen.core import SETTINGS from pymatgen.core.structure import Structure from pymatgen.io.vasp import Incar, Potcar -from tabulate import tabulate, tabulate_formats def parse_view(args): diff --git a/src/pymatgen/cli/pmg_analyze.py b/src/pymatgen/cli/pmg_analyze.py index 4db3fbcb2f6..c12ef4186ef 100644 --- a/src/pymatgen/cli/pmg_analyze.py +++ b/src/pymatgen/cli/pmg_analyze.py @@ -7,10 +7,11 @@ import os import re +from tabulate import tabulate + from pymatgen.apps.borg.hive import SimpleVaspToComputedEntryDrone, VaspToComputedEntryDrone from pymatgen.apps.borg.queen import BorgQueen from pymatgen.io.vasp import Outcar -from tabulate import tabulate __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/src/pymatgen/cli/pmg_config.py b/src/pymatgen/cli/pmg_config.py index 8a0f4a04317..d8253983e19 100755 --- a/src/pymatgen/cli/pmg_config.py +++ b/src/pymatgen/cli/pmg_config.py @@ -13,10 +13,11 @@ from monty.json import jsanitize from monty.serialization import dumpfn, loadfn +from ruamel import yaml + from pymatgen.core import OLD_SETTINGS_FILE, SETTINGS_FILE, Element from pymatgen.io.cp2k.inputs import GaussianTypeOrbitalBasisSet, GthPotential from pymatgen.io.cp2k.utils import chunk -from ruamel import yaml if TYPE_CHECKING: from argparse import Namespace diff --git a/src/pymatgen/cli/pmg_plot.py b/src/pymatgen/cli/pmg_plot.py index 56e7d629660..89d0fa269ae 100755 --- a/src/pymatgen/cli/pmg_plot.py +++ b/src/pymatgen/cli/pmg_plot.py @@ -5,6 +5,7 @@ from __future__ import annotations import matplotlib.pyplot as plt + from pymatgen.analysis.diffraction.xrd import XRDCalculator from pymatgen.core.structure import Structure from pymatgen.electronic_structure.plotter import DosPlotter diff --git a/src/pymatgen/cli/pmg_structure.py b/src/pymatgen/cli/pmg_structure.py index b8046b699ce..3f85040fede 100755 --- a/src/pymatgen/cli/pmg_structure.py +++ b/src/pymatgen/cli/pmg_structure.py @@ -4,10 +4,11 @@ from __future__ import annotations +from tabulate import tabulate + from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher from pymatgen.core.structure import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from tabulate import tabulate __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/src/pymatgen/command_line/bader_caller.py b/src/pymatgen/command_line/bader_caller.py index c85baf861ed..dbbb743acc9 100644 --- a/src/pymatgen/command_line/bader_caller.py +++ b/src/pymatgen/command_line/bader_caller.py @@ -27,6 +27,7 @@ from monty.dev import deprecated from monty.shutil import decompress_file from monty.tempfile import ScratchDir + from pymatgen.io.common import VolumetricData from pymatgen.io.vasp.inputs import Potcar from pymatgen.io.vasp.outputs import Chgcar @@ -34,9 +35,10 @@ if TYPE_CHECKING: from typing import Any - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + __author__ = "shyuepingong" __version__ = "0.1" __maintainer__ = "Shyue Ping Ong" diff --git a/src/pymatgen/command_line/chargemol_caller.py b/src/pymatgen/command_line/chargemol_caller.py index a22e4b345a0..d9b17a27404 100644 --- a/src/pymatgen/command_line/chargemol_caller.py +++ b/src/pymatgen/command_line/chargemol_caller.py @@ -51,6 +51,7 @@ import numpy as np from monty.tempfile import ScratchDir + from pymatgen.core import Element from pymatgen.io.vasp.inputs import Potcar from pymatgen.io.vasp.outputs import Chgcar diff --git a/src/pymatgen/command_line/critic2_caller.py b/src/pymatgen/command_line/critic2_caller.py index 9855345d7ff..4f6261afe3e 100644 --- a/src/pymatgen/command_line/critic2_caller.py +++ b/src/pymatgen/command_line/critic2_caller.py @@ -52,17 +52,19 @@ from monty.json import MSONable from monty.serialization import loadfn from monty.tempfile import ScratchDir +from scipy.spatial import KDTree + from pymatgen.analysis.graphs import StructureGraph from pymatgen.core import DummySpecies from pymatgen.io.vasp.inputs import Potcar from pymatgen.io.vasp.outputs import Chgcar, VolumetricData from pymatgen.util.due import Doi, due -from scipy.spatial import KDTree if TYPE_CHECKING: - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/src/pymatgen/command_line/enumlib_caller.py b/src/pymatgen/command_line/enumlib_caller.py index 619060b8386..bdae0256e0e 100644 --- a/src/pymatgen/command_line/enumlib_caller.py +++ b/src/pymatgen/command_line/enumlib_caller.py @@ -39,6 +39,7 @@ from monty.dev import requires from monty.fractions import lcm from monty.tempfile import ScratchDir + from pymatgen.core import DummySpecies, PeriodicSite, Structure from pymatgen.io.vasp.inputs import Poscar from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/command_line/gulp_caller.py b/src/pymatgen/command_line/gulp_caller.py index fa3e250b99a..b13912f2433 100644 --- a/src/pymatgen/command_line/gulp_caller.py +++ b/src/pymatgen/command_line/gulp_caller.py @@ -10,6 +10,7 @@ import subprocess from monty.tempfile import ScratchDir + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.core import Element, Lattice, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/command_line/mcsqs_caller.py b/src/pymatgen/command_line/mcsqs_caller.py index 980f04e2f13..dd1cca70b1c 100644 --- a/src/pymatgen/command_line/mcsqs_caller.py +++ b/src/pymatgen/command_line/mcsqs_caller.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, NamedTuple from monty.dev import requires + from pymatgen.core.structure import Structure if TYPE_CHECKING: diff --git a/src/pymatgen/command_line/vampire_caller.py b/src/pymatgen/command_line/vampire_caller.py index 44aa50c3895..8709088514d 100644 --- a/src/pymatgen/command_line/vampire_caller.py +++ b/src/pymatgen/command_line/vampire_caller.py @@ -21,6 +21,7 @@ import pandas as pd from monty.dev import requires from monty.json import MSONable + from pymatgen.analysis.magnetism.heisenberg import HeisenbergMapper __author__ = "ncfrey" diff --git a/src/pymatgen/core/__init__.py b/src/pymatgen/core/__init__.py index 848c251a8b4..840b16e9b69 100644 --- a/src/pymatgen/core/__init__.py +++ b/src/pymatgen/core/__init__.py @@ -7,6 +7,8 @@ from importlib.metadata import PackageNotFoundError, version from typing import Any +from ruamel.yaml import YAML + from pymatgen.core.composition import Composition from pymatgen.core.lattice import Lattice from pymatgen.core.operations import SymmOp @@ -14,7 +16,6 @@ from pymatgen.core.sites import PeriodicSite, Site from pymatgen.core.structure import IMolecule, IStructure, Molecule, PeriodicNeighbor, SiteCollection, Structure from pymatgen.core.units import ArrayWithUnit, FloatWithUnit, Unit -from ruamel.yaml import YAML __author__ = "Pymatgen Development Team" __email__ = "pymatgen@googlegroups.com" diff --git a/src/pymatgen/core/bonds.py b/src/pymatgen/core/bonds.py index 7faed475c54..6c1ced8fa47 100644 --- a/src/pymatgen/core/bonds.py +++ b/src/pymatgen/core/bonds.py @@ -103,7 +103,7 @@ def is_bonded( bond length. If None, a ValueError will be thrown. Returns: - bool: whether two sites are bonded. + bool: True if two sites are bonded. """ sp1 = next(iter(site1.species)) sp2 = next(iter(site2.species)) diff --git a/src/pymatgen/core/composition.py b/src/pymatgen/core/composition.py index 9d889d83780..168591cbff4 100644 --- a/src/pymatgen/core/composition.py +++ b/src/pymatgen/core/composition.py @@ -18,6 +18,7 @@ from monty.fractions import gcd, gcd_float from monty.json import MSONable from monty.serialization import loadfn + from pymatgen.core.periodic_table import DummySpecies, Element, ElementType, Species, get_el_sp from pymatgen.core.units import Mass from pymatgen.util.string import Stringify, formula_double_format @@ -26,9 +27,10 @@ from collections.abc import Generator, Iterator from typing import Any, ClassVar - from pymatgen.util.typing import SpeciesLike from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike + module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -505,7 +507,7 @@ def get_atomic_fraction(self, el: SpeciesLike) -> float: """Calculate atomic fraction of an Element or Species. Args: - el (Element/Species): Element or Species to get fraction for. + el (SpeciesLike): Element or Species to get fraction for. Returns: Atomic fraction for element el in Composition @@ -534,7 +536,7 @@ def contains_element_type(self, category: str) -> bool: "actinoid", "radioactive", "quadrupolar", "s-block", "p-block", "d-block", "f-block". Returns: - bool: Whether any elements in Composition match category. + bool: True if any elements in Composition match category. """ allowed_categories = [element.value for element in ElementType] diff --git a/src/pymatgen/core/interface.py b/src/pymatgen/core/interface.py index 2856130e014..a7a49004409 100644 --- a/src/pymatgen/core/interface.py +++ b/src/pymatgen/core/interface.py @@ -15,6 +15,9 @@ import numpy as np from monty.fractions import lcm from numpy.testing import assert_allclose +from scipy.cluster.hierarchy import fcluster, linkage +from scipy.spatial.distance import squareform + from pymatgen.analysis.adsorption import AdsorbateSiteFinder from pymatgen.core.lattice import Lattice from pymatgen.core.sites import PeriodicSite, Site @@ -22,17 +25,16 @@ from pymatgen.core.surface import Slab from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.typing import Tuple3Ints -from scipy.cluster.hierarchy import fcluster, linkage -from scipy.spatial.distance import squareform if TYPE_CHECKING: from collections.abc import Sequence from typing import Any, Callable from numpy.typing import ArrayLike, NDArray + from typing_extensions import Self + from pymatgen.core import Element from pymatgen.util.typing import CompositionLike, Matrix3D, MillerIndex, Tuple3Floats, Vector3D - from typing_extensions import Self Tuple4Ints = tuple[int, int, int, int] logger = logging.getLogger(__name__) @@ -81,8 +83,8 @@ def __init__( """A Structure with additional information and methods pertaining to GBs. Args: - lattice (Lattice/3x3 array): The lattice, either as an instance or - any 2D array. Each row should correspond to a lattice vector. + lattice (Lattice | np.ndarray): The lattice, either as an instance or + a 3x3 array. Each row should correspond to a lattice vector. species ([Species]): Sequence of species on each site. Can take in flexible input, including: @@ -1766,7 +1768,7 @@ def enum_sigma_ort( e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. Returns: - dict: sigmas dictionary with keys as the possible integer sigma values + dict: sigmas dictionary with keys as the possible integer sigma values and values as list of the possible rotation angles to the corresponding sigma values. e.g. the format as {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} @@ -2388,7 +2390,7 @@ def symm_group_cubic(mat: NDArray) -> list: """Obtain cubic symmetric equivalents of the list of vectors. Args: - mat (n by 3 array/matrix): lattice matrix + mat (np.ndarray): n x 3 lattice matrix Returns: @@ -2449,9 +2451,8 @@ def __init__( and methods pertaining to interfaces. Args: - lattice (Lattice/3x3 array): The lattice, either as a - pymatgen.core.Lattice or - simply as any 2D array. Each row should correspond to a lattice + lattice (Lattice | np.ndarray): The lattice, either as a pymatgen.core.Lattice + or a 3x3 array. Each row should correspond to a lattice vector. e.g. [[10,0,0], [20,10,0], [0,0,30]] specifies a lattice with lattice vectors [10,0,0], [20,10,0] and [0,0,30]. species ([Species]): Sequence of species on each site. Can take in diff --git a/src/pymatgen/core/ion.py b/src/pymatgen/core/ion.py index 52408713463..03055da8095 100644 --- a/src/pymatgen/core/ion.py +++ b/src/pymatgen/core/ion.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.core.composition import Composition, reduce_formula from pymatgen.util.string import Stringify, charge_string, formula_double_format diff --git a/src/pymatgen/core/lattice.py b/src/pymatgen/core/lattice.py index 89f28628e6d..7782ac10ebb 100644 --- a/src/pymatgen/core/lattice.py +++ b/src/pymatgen/core/lattice.py @@ -16,17 +16,19 @@ import numpy as np from monty.dev import deprecated from monty.json import MSONable +from scipy.spatial import Voronoi + from pymatgen.util.coord import pbc_shortest_vectors from pymatgen.util.due import Doi, due -from scipy.spatial import Voronoi if TYPE_CHECKING: from collections.abc import Iterator from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.core.operations import SymmOp from pymatgen.util.typing import MillerIndex, PbcLike, Vector3D - from typing_extensions import Self __author__ = "Shyue Ping Ong, Michael Kocher" __copyright__ = "Copyright 2011, The Materials Project" @@ -1298,8 +1300,7 @@ def dot( Args: coords_a: Array-like coordinates. coords_b: Array-like coordinates. - frac_coords (bool): Boolean stating whether the vector - corresponds to fractional or Cartesian coordinates. + frac_coords (bool): True if the vectors are fractional (as opposed to Cartesian) coordinates. Returns: one-dimensional `numpy` array. diff --git a/src/pymatgen/core/operations.py b/src/pymatgen/core/operations.py index c67603c288f..54f24f8515b 100644 --- a/src/pymatgen/core/operations.py +++ b/src/pymatgen/core/operations.py @@ -10,6 +10,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.electronic_structure.core import Magmom from pymatgen.util.due import Doi, due from pymatgen.util.string import transformation_to_string diff --git a/src/pymatgen/core/periodic_table.py b/src/pymatgen/core/periodic_table.py index 7fe90d45ce5..d6992dc3c15 100644 --- a/src/pymatgen/core/periodic_table.py +++ b/src/pymatgen/core/periodic_table.py @@ -17,6 +17,7 @@ import numpy as np from monty.dev import deprecated from monty.json import MSONable + from pymatgen.core.units import SUPPORTED_UNIT_NAMES, FloatWithUnit, Ha_to_eV, Length, Mass, Unit from pymatgen.io.core import ParseError from pymatgen.util.string import Stringify, formula_double_format @@ -24,9 +25,10 @@ if TYPE_CHECKING: from typing import Any, Callable, Literal - from pymatgen.util.typing import SpeciesLike from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike + # Load element data from JSON file with open(Path(__file__).absolute().parent / "periodic_table.json", encoding="utf-8") as ptable_json: _pt_data = json.load(ptable_json) @@ -1607,9 +1609,8 @@ def get_el_sp(obj: int | SpeciesLike) -> Element | Species | DummySpecies: will be attempted. Args: - obj (Element/Species/str/int): An arbitrary object. Supported objects - are actual Element/Species objects, integers (representing atomic - numbers) or strings (element symbols or species strings). + obj (SpeciesLike): An arbitrary object. Supported objects are actual Element/Species, + integers (representing atomic numbers) or strings (element symbols or species strings). Raises: ValueError: if obj cannot be converted into an Element or Species. diff --git a/src/pymatgen/core/sites.py b/src/pymatgen/core/sites.py index fd8a61118ef..1e948a2d3e1 100644 --- a/src/pymatgen/core/sites.py +++ b/src/pymatgen/core/sites.py @@ -8,6 +8,7 @@ import numpy as np from monty.json import MontyDecoder, MontyEncoder, MSONable + from pymatgen.core.composition import Composition from pymatgen.core.lattice import Lattice from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp @@ -17,9 +18,10 @@ from typing import Any from numpy.typing import ArrayLike - from pymatgen.util.typing import CompositionLike, SpeciesLike, Vector3D from typing_extensions import Self + from pymatgen.util.typing import CompositionLike, SpeciesLike, Vector3D + class Site(collections.abc.Hashable, MSONable): """A generalized *non-periodic* site. This is essentially a composition @@ -44,7 +46,7 @@ def __init__( Args: species: Species on the site. Can be: i. A Composition-type object (preferred) - ii. An element / species specified either as a string + ii. An element / species specified either as a string symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, e.g. 3, 56, or actual Element or Species objects. iii.Dict of elements/species and occupancies, e.g. @@ -303,7 +305,7 @@ def __init__( Args: species: Species on the site. Can be: i. A Composition-type object (preferred) - ii. An element / species specified either as a string + ii. An element / species specified either as a string symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, e.g. 3, 56, or actual Element or Species objects. iii.Dict of elements/species and occupancies, e.g. diff --git a/src/pymatgen/core/spectrum.py b/src/pymatgen/core/spectrum.py index 94ade3ce43e..6567f3a63a2 100644 --- a/src/pymatgen/core/spectrum.py +++ b/src/pymatgen/core/spectrum.py @@ -8,10 +8,11 @@ import numpy as np from monty.json import MSONable -from pymatgen.util.coord import get_linear_interpolated_value from scipy import stats from scipy.ndimage import convolve1d +from pymatgen.util.coord import get_linear_interpolated_value + if TYPE_CHECKING: from typing import Callable, Literal diff --git a/src/pymatgen/core/structure.py b/src/pymatgen/core/structure.py index fb864bbd0bb..14aa7c19866 100644 --- a/src/pymatgen/core/structure.py +++ b/src/pymatgen/core/structure.py @@ -31,6 +31,12 @@ from monty.json import MSONable from numpy import cross, eye from numpy.linalg import norm +from ruamel.yaml import YAML +from scipy.cluster.hierarchy import fcluster, linkage +from scipy.linalg import expm, polar +from scipy.spatial.distance import squareform +from tabulate import tabulate + from pymatgen.core.bonds import CovalentBond, get_bond_length from pymatgen.core.composition import Composition from pymatgen.core.lattice import Lattice, get_points_in_spheres @@ -41,11 +47,6 @@ from pymatgen.electronic_structure.core import Magmom from pymatgen.symmetry.maggroups import MagneticSpaceGroup from pymatgen.util.coord import all_distances, get_angle, lattice_points_in_supercell -from ruamel.yaml import YAML -from scipy.cluster.hierarchy import fcluster, linkage -from scipy.linalg import expm, polar -from scipy.spatial.distance import squareform -from tabulate import tabulate if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -58,9 +59,10 @@ from ase.optimize.optimize import Optimizer from matgl.ext.ase import TrajectoryObserver from numpy.typing import ArrayLike, NDArray - from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike from typing_extensions import Self + from pymatgen.util.typing import CompositionLike, MillerIndex, PathLike, PbcLike, SpeciesLike + FileFormats = Literal["cif", "poscar", "cssr", "json", "yaml", "yml", "xsf", "mcsqs", "res", "pwmat", ""] StructureSources = Literal["Materials Project", "COD"] @@ -827,6 +829,7 @@ def _relax( from ase.constraints import ExpCellFilter from ase.io import read from ase.optimize.optimize import Optimizer + from pymatgen.io.ase import AseAtomsAdaptor opt_kwargs = opt_kwargs or {} @@ -1253,7 +1256,7 @@ def from_spacegroup( are generated from the spacegroup operations. Args: - sg (str/int): The spacegroup. If a string, it will be interpreted + sg (str | int): The spacegroup. If a string, it will be interpreted as one of the notations supported by pymatgen.symmetry.groups.Spacegroup. e.g. "R-3c" or "Fm-3m". If an int, it will be interpreted as an international number. @@ -1542,7 +1545,7 @@ def matches( Basically a convenience method to call structure matching. Args: - other (IStructure/Structure): Another structure. + other (IStructure | Structure): Another structure. anonymous (bool): Whether to use anonymous structure matching which allows distinct species in one structure to map to another. **kwargs: Same **kwargs as in @@ -1813,7 +1816,7 @@ def get_symmetric_neighbor_list( Args: r (float): Radius of sphere - sg (str/int): The spacegroup the symmetry operations of which will be + sg (str | int): The spacegroup the symmetry operations of which will be used to classify the neighbors. If a string, it will be interpreted as one of the notations supported by pymatgen.symmetry.groups.Spacegroup. e.g. "R-3c" or "Fm-3m". @@ -3938,9 +3941,8 @@ def __setitem__( # type: ignore[override] """Modify a site in the structure. Args: - idx (int, [int], slice, Species-like): Indices to change. You can - specify these as an int, a list of int, or a species-like - string. + idx (int, list[int], slice, Species-like): Indices to change. You can + specify these as an int, a list of int, or a species-like string. site (PeriodicSite | Species | dict[SpeciesLike, float] | Sequence): 4 options exist. You can provide a PeriodicSite directly (lattice will be checked). Or more conveniently, you can provide a species-like object (or a dict mapping SpeciesLike to occupancy floats) @@ -4770,9 +4772,8 @@ def __setitem__( # type: ignore[override] """Modify a site in the molecule. Args: - idx (int, [int], slice, Species-like): Indices to change. You can - specify these as an int, a list of int, or a species-like - string. + idx (int, list[int], slice, Species-like): Indices to change. You can + specify these as an int, a list of int, or a species-like string. site (PeriodicSite/Species/Sequence): Three options exist. You can provide a Site directly, or for convenience, you can provide simply a Species-like string/object, or finally a (Species, diff --git a/src/pymatgen/core/surface.py b/src/pymatgen/core/surface.py index f59d3a3cbc0..6bc5678f70d 100644 --- a/src/pymatgen/core/surface.py +++ b/src/pymatgen/core/surface.py @@ -27,24 +27,26 @@ import numpy as np from monty.fractions import lcm +from scipy.cluster.hierarchy import fcluster, linkage +from scipy.spatial.distance import squareform + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Lattice, PeriodicSite, Structure, get_el_sp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.coord import in_coord_list from pymatgen.util.due import Doi, due from pymatgen.util.typing import Tuple3Ints -from scipy.cluster.hierarchy import fcluster, linkage -from scipy.spatial.distance import squareform if TYPE_CHECKING: from collections.abc import Sequence from typing import Any from numpy.typing import ArrayLike, NDArray + from typing_extensions import Self + from pymatgen.core.composition import Element, Species from pymatgen.symmetry.groups import CrystalSystem from pymatgen.util.typing import MillerIndex - from typing_extensions import Self __author__ = "Richard Tran, Wenhao Sun, Zihan Xu, Shyue Ping Ong" @@ -288,7 +290,7 @@ def is_symmetric(self, symprec: float = 0.1) -> bool: symprec (float): Symmetry precision used for SpaceGroup analyzer. Returns: - bool: Whether surfaces are symmetric. + bool: True if surfaces are symmetric. """ spg_analyzer = SpacegroupAnalyzer(self, symprec=symprec) symm_ops = spg_analyzer.get_point_group_operations() diff --git a/src/pymatgen/core/tensors.py b/src/pymatgen/core/tensors.py index 9a8b5373713..18575b0c172 100644 --- a/src/pymatgen/core/tensors.py +++ b/src/pymatgen/core/tensors.py @@ -15,20 +15,22 @@ import numpy as np from monty.json import MSONable from monty.serialization import loadfn +from scipy.linalg import polar + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core.lattice import Lattice from pymatgen.core.operations import SymmOp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from scipy.linalg import polar if TYPE_CHECKING: from collections.abc import Sequence from typing import Any from numpy.typing import NDArray - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + __author__ = "Joseph Montoya" __credits__ = "Maarten de Jong, Shyam Dwaraknath, Wei Chen, Mark Asta, Anubhav Jain, Terence Lew" diff --git a/src/pymatgen/core/trajectory.py b/src/pymatgen/core/trajectory.py index 1d50e2a8df5..ad595bb0e45 100644 --- a/src/pymatgen/core/trajectory.py +++ b/src/pymatgen/core/trajectory.py @@ -13,6 +13,7 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.core.structure import Composition, DummySpecies, Element, Lattice, Molecule, Species, Structure from pymatgen.io.ase import AseAtomsAdaptor @@ -20,9 +21,10 @@ from collections.abc import Iterator from typing import Any - from pymatgen.util.typing import Matrix3D, PathLike, SitePropsType, Vector3D from typing_extensions import Self + from pymatgen.util.typing import Matrix3D, PathLike, SitePropsType, Vector3D + __author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith" __version__ = "0.1" diff --git a/src/pymatgen/core/xcfunc.py b/src/pymatgen/core/xcfunc.py index 4b2e46d7376..a17d9c939fb 100644 --- a/src/pymatgen/core/xcfunc.py +++ b/src/pymatgen/core/xcfunc.py @@ -6,6 +6,7 @@ from monty.functools import lazy_property from monty.json import MSONable + from pymatgen.core.libxcfunc import LibxcFunc if TYPE_CHECKING: diff --git a/src/pymatgen/electronic_structure/bandstructure.py b/src/pymatgen/electronic_structure/bandstructure.py index f03160bb9ea..74a25d4edb6 100644 --- a/src/pymatgen/electronic_structure/bandstructure.py +++ b/src/pymatgen/electronic_structure/bandstructure.py @@ -1,4 +1,4 @@ -"""This module provides classes to define everything related to band structures.""" +"""This module provides classes to define things related to band structures.""" from __future__ import annotations @@ -7,10 +7,11 @@ import re import warnings from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload import numpy as np from monty.json import MSONable + from pymatgen.core import Element, Lattice, Structure, get_el_sp from pymatgen.electronic_structure.core import Orbital, Spin from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -19,6 +20,7 @@ if TYPE_CHECKING: from typing import Any + from numpy.typing import NDArray from typing_extensions import Self __author__ = "Geoffroy Hautier, Shyue Ping Ong, Michael Kocher" @@ -31,14 +33,13 @@ class Kpoint(MSONable): - """Store kpoint objects. A kpoint is defined with a lattice and frac - or Cartesian coordinates syntax similar than the site object in - pymatgen.core.structure. + """A kpoint defined with a lattice and frac or Cartesian coordinates, + similar to the Site object in pymatgen.core.structure. """ def __init__( self, - coords: np.ndarray, + coords: NDArray, lattice: Lattice, to_unit_cell: bool = False, coords_are_cartesian: bool = False, @@ -46,15 +47,14 @@ def __init__( ) -> None: """ Args: - coords: coordinate of the kpoint as a numpy array - lattice: A pymatgen.core.Lattice object representing - the reciprocal lattice of the kpoint - to_unit_cell: Translates fractional coordinate to the basic unit + coords (NDArray): Coordinate of the Kpoint. + lattice (Lattice): The reciprocal lattice of the kpoint. + to_unit_cell (bool): Translate fractional coordinate to the basic unit cell, i.e., all fractional coordinates satisfy 0 <= a < 1. Defaults to False. - coords_are_cartesian: Boolean indicating if the coordinates given are - in Cartesian or fractional coordinates (by default fractional) - label: the label of the kpoint if any (None by default). + coords_are_cartesian (bool): Whether the coordinates given are + in Cartesian (True) or fractional coordinates (by default fractional). + label (str): The label of the Kpoint if any (None by default). """ self._lattice = lattice self._frac_coords = lattice.get_fractional_coords(coords) if coords_are_cartesian else coords @@ -66,11 +66,24 @@ def __init__( self._cart_coords = lattice.get_cartesian_coords(self._frac_coords) + def __str__(self) -> str: + """String with fractional, Cartesian coordinates and label.""" + return f"{self.frac_coords} {self.cart_coords} {self.label}" + + def __eq__(self, other: object) -> bool: + """Whether two Kpoints are equal.""" + if not isinstance(other, type(self)): + return NotImplemented + + return ( + np.allclose(self.frac_coords, other.frac_coords) + and self.lattice == other.lattice + and self.label == other.label + ) + @property def lattice(self) -> Lattice: - """The lattice associated with the kpoint. It's a - pymatgen.core.Lattice object. - """ + """The lattice associated with the kpoint, as a Lattice object.""" return self._lattice @property @@ -84,13 +97,13 @@ def label(self, label: str | None) -> None: self._label = label @property - def frac_coords(self) -> np.ndarray: - """The fractional coordinates of the kpoint as a numpy array.""" + def frac_coords(self) -> NDArray: + """The fractional coordinates of the kpoint as a NumPy array.""" return np.copy(self._frac_coords) @property - def cart_coords(self) -> np.ndarray: - """The Cartesian coordinates of the kpoint as a numpy array.""" + def cart_coords(self) -> NDArray: + """The Cartesian coordinates of the kpoint as a NumPy array.""" return np.copy(self._cart_coords) @property @@ -108,22 +121,8 @@ def c(self) -> float: """Fractional c coordinate of the kpoint.""" return self._frac_coords[2] - def __str__(self) -> str: - """Get a string with fractional, Cartesian coordinates and label.""" - return f"{self.frac_coords} {self.cart_coords} {self.label}" - - def __eq__(self, other: object) -> bool: - """Check if two kpoints are equal.""" - if not isinstance(other, Kpoint): - return NotImplemented - return ( - np.allclose(self.frac_coords, other.frac_coords) - and self.lattice == other.lattice - and self.label == other.label - ) - def as_dict(self) -> dict[str, Any]: - """JSON-serializable dict representation of a kpoint.""" + """JSON-serializable dict representation of the kpoint.""" return { "lattice": self.lattice.as_dict(), "fcoords": self.frac_coords.tolist(), @@ -135,7 +134,7 @@ def as_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, dct: dict) -> Self: - """Create from dict. + """Create from a dict. Args: dct (dict): A dict with all data for a kpoint object. @@ -148,59 +147,57 @@ def from_dict(cls, dct: dict) -> Self: class BandStructure: - """This is the most generic band structure data possible - it's defined by a list of kpoints + energies for each of them. + """Generic band structure data, defined by a list of Kpoints + and corresponding energies for each of them. Attributes: - kpoints (list): The list of kpoints (as Kpoint objects) in the band structure. + kpoints (list[Kpoint]): Kpoints in the band structure. lattice_rec (Lattice): The reciprocal lattice of the band structure. - efermi (float): The Fermi energy. - is_spin_polarized (bool): True if the band structure is spin-polarized. - bands (dict): The energy eigenvalues as a {spin: array}. Note that the use of an - array is necessary for computational as well as memory efficiency due to the large - amount of numerical data. The indices of the array are [band_index, kpoint_index]. - nb_bands (int): Returns the number of bands in the band structure. - structure (Structure): Returns the structure. - projections (dict): The projections as a {spin: array}. Note that the use of an - array is necessary for computational as well as memory efficiency due to the large - amount of numerical data. The indices of the array are [band_index, kpoint_index, - orbital_index, ion_index]. + efermi (float): The Fermi level. + is_spin_polarized (bool): Whether the band structure is spin-polarized. + bands (dict[Spin, NDArray]): The energy eigenvalues. Note that the use of an + array is necessary for computational and memory efficiency due to the large + amount of numerical data. The indices of the array are (band_index, kpoint_index). + nb_bands (int): The number of bands in the band structure. + structure (Structure): The structure. + projections (dict[Spin, NDArray]): The projections. Note that the use of an + array is necessary for computational and memory efficiency due to the large + amount of numerical data. The indices of the array are (band_index, kpoint_index, + orbital_index, ion_index). """ def __init__( self, - kpoints: np.ndarray, - eigenvals: dict[Spin, np.ndarray], + kpoints: NDArray, + eigenvals: dict[Spin, NDArray], lattice: Lattice, efermi: float, - labels_dict=None, + labels_dict: dict[str, Kpoint] | None = None, coords_are_cartesian: bool = False, structure: Structure | None = None, - projections: dict[Spin, np.ndarray] | None = None, + projections: dict[Spin, NDArray] | None = None, ) -> None: """ Args: - kpoints: list of kpoint as numpy arrays, in frac_coords of the - given lattice by default - eigenvals: dict of energies for spin up and spin down - {Spin.up:[][],Spin.down:[][]}, the first index of the array - [][] refers to the band and the second to the index of the - kpoint. The kpoints are ordered according to the order of the - kpoints array. If the band structure is not spin polarized, we - only store one data set under Spin.up - lattice: The reciprocal lattice as a pymatgen Lattice object. - Pymatgen uses the physics convention of reciprocal lattice vectors - WITH a 2*pi coefficient - efermi (float): Fermi energy - labels_dict: (dict) of {} this links a kpoint (in frac coords or - Cartesian coordinates depending on the coords) to a label. - coords_are_cartesian: Whether coordinates are cartesian. - structure: The crystal structure (as a pymatgen Structure object) + kpoints (NDArray): Kpoint as NumPy array, in frac_coords of the + given lattice by default. + eigenvals (dict): Energies for spin up and spin down as + {Spin.up:[][], Spin.down:[][]}, the first index of the array + [][] refers to the band and the second to the index of the kpoint. + The kpoints are ordered according to the kpoints array. + If the band structure is not spin polarized, we + only store one data set under Spin.up. + lattice (Lattice): The reciprocal lattice. Pymatgen uses the physics + convention of reciprocal lattice vectors with a 2*pi coefficient. + efermi (float): The Fermi level. + labels_dict (dict[str, Kpoint]): Dict mapping label to Kpoint. + coords_are_cartesian (bool): Whether coordinates are cartesian. + structure (Structure): The crystal structure associated with the band structure. This is needed if we - provide projections to the band structure - projections: dict of orbital projections as {spin: array}. The - indices of the array are [band_index, kpoint_index, orbital_index, - ion_index].If the band structure is not spin polarized, we only + provide projections to the band structure. + projections (dict[Spin, NDArray]): Orbital projections. The + indices of the array are (band_index, kpoint_index, orbital_index, + ion_index). If the band structure is not spin polarized, we only store one data set under Spin.up. """ self.efermi = efermi @@ -214,58 +211,58 @@ def __init__( if labels_dict is None: labels_dict = {} - if len(self.projections) != 0 and self.structure is None: + if self.projections and self.structure is None: raise RuntimeError("if projections are provided a structure object is also required") - for k in kpoints: - # let see if this kpoint has been assigned a label + for kpt in kpoints: + # Check if this Kpoint has a label label = None for c in labels_dict: - if np.linalg.norm(k - np.array(labels_dict[c])) < 0.0001: + if np.linalg.norm(kpt - np.array(labels_dict[c])) < 0.0001: label = c self.labels_dict[label] = Kpoint( - k, + kpt, lattice, label=label, coords_are_cartesian=coords_are_cartesian, ) - self.kpoints.append(Kpoint(k, lattice, label=label, coords_are_cartesian=coords_are_cartesian)) + self.kpoints.append(Kpoint(kpt, lattice, label=label, coords_are_cartesian=coords_are_cartesian)) self.bands = {spin: np.array(v) for spin, v in eigenvals.items()} self.nb_bands = len(eigenvals[Spin.up]) self.is_spin_polarized = len(self.bands) == 2 - def get_projection_on_elements(self): - """Get a dictionary of projections on elements. + def get_projection_on_elements(self) -> dict[Spin, NDArray]: + """Get projections on elements. Returns: - a dictionary in the {Spin.up:[][{Element: [values]}], - Spin.down:[][{Element: [values]}]} format - if there is no projections in the band structure - returns an empty dict + dict[Spin, NDArray]: Dict in {Spin.up:[][{Element: [values]}], + Spin.down: [][{Element: [values]}]} format. + If there is no projections in the band structure, return {}. """ - result = {} - for spin, v in self.projections.items(): - result[spin] = [[defaultdict(float) for i in range(len(self.kpoints))] for j in range(self.nb_bands)] + assert self.structure is not None + result: dict[Spin, NDArray] = {} + for spin, val in self.projections.items(): + result[spin] = [[defaultdict(float) for _ in range(len(self.kpoints))] for _ in range(self.nb_bands)] for i, j, k in itertools.product( range(self.nb_bands), range(len(self.kpoints)), range(len(self.structure)), ): - result[spin][i][j][str(self.structure[k].specie)] += np.sum(v[i, j, :, k]) + result[spin][i][j][str(self.structure[k].specie)] += np.sum(val[i, j, :, k]) return result def get_projections_on_elements_and_orbitals(self, el_orb_spec: dict[str, list[str]]): - """Get a dictionary of projections on elements and specific orbitals. + """Get projections on elements and specific orbitals. Args: - el_orb_spec (dict[str, list[str]]): A dictionary of elements and orbitals which - to project onto. Format is {Element: [orbitals]}, e.g. {'Cu':['d','s']}. + el_orb_spec (dict[str, list[str]]): Elements and orbitals to project onto. + Format is {Element: [orbitals]}, e.g. {"Cu": ["d", "s"]}. Returns: - A dictionary of projections on elements in the - {Spin.up:[][{Element:{orb:values}}], - Spin.down:[][{Element:{orb:values}}]} format - if there is no projections in the band structure returns an empty dict. + dict[str, list[str]: Projections on elements in the + {Spin.up: [][{Element: {orb: values}}], + Spin.down: [][{Element: {orb: values}}]} format. + If there is no projections in the band structure, return {}. """ if self.structure is None: raise ValueError("Structure is required for this method") @@ -273,8 +270,8 @@ def get_projections_on_elements_and_orbitals(self, el_orb_spec: dict[str, list[s species_orb_spec = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()} for spin, v in self.projections.items(): result[spin] = [ - [{str(e): defaultdict(float) for e in species_orb_spec} for i in range(len(self.kpoints))] - for j in range(self.nb_bands) + [{str(e): defaultdict(float) for e in species_orb_spec} for _ in range(len(self.kpoints))] + for _ in range(self.nb_bands) ] for i, j, k in itertools.product( @@ -289,12 +286,12 @@ def get_projections_on_elements_and_orbitals(self, el_orb_spec: dict[str, list[s result[spin][i][j][str(sp)][o] += v[i][j][orb_i][k] return result - def is_metal(self, efermi_tol=1e-4) -> bool: - """Check if the band structure indicates a metal by looking if the fermi - level crosses a band. + def is_metal(self, efermi_tol: float = 1e-4) -> bool: + """Check if the band structure indicates a metal, + by looking at if the fermi level crosses a band. Returns: - bool: True if a metal. + bool: True if is metal. """ for vals in self.bands.values(): for idx in range(self.nb_bands): @@ -302,26 +299,26 @@ def is_metal(self, efermi_tol=1e-4) -> bool: return True return False - def get_vbm(self): - """Get data about the VBM. + def get_vbm(self) -> dict[str, Any]: + """Get data about the valence band maximum (VBM). Returns: - dict: With keys "band_index", "kpoint_index", "kpoint", "energy" - - "band_index": A dict with spin keys pointing to a list of the - indices of the band containing the VBM (please note that you - can have several bands sharing the VBM) {Spin.up:[], - Spin.down:[]} - - "kpoint_index": The list of indices in self.kpoints for the - kpoint VBM. Please note that there can be several - kpoint_indices relating to the same kpoint (e.g., Gamma can - occur at different spots in the band structure line plot) - - "kpoint": The kpoint (as a kpoint object) - - "energy": The energy of the VBM - - "projections": The projections along sites and orbitals of the - VBM if any projection data is available (else it is an empty - dictionary). The format is similar to the projections field in - BandStructure: {spin:{'Orbital': [proj]}} where the array - [proj] is ordered according to the sites in structure + dict with keys "band_index", "kpoint_index", "kpoint", "energy": + - "band_index" (dict): A dict with spin keys pointing to a list of the + indices of the band containing the VBM (please note that you + can have several bands sharing the VBM) {Spin.up:[], + Spin.down:[]}. + - "kpoint_index": The list of indices in self.kpoints for the + kpoint VBM. Please note that there can be several + kpoint_indices relating to the same kpoint (e.g., Gamma can + occur at different spots in the band structure line plot). + - "kpoint" (Kpoint): The kpoint. + - "energy" (float): The energy of the VBM. + - "projections": The projections along sites and orbitals of the + VBM if any projection data is available (else it is an empty + dictionary). The format is similar to the projections field in + BandStructure: {spin:{'Orbital': [proj]}} where the array + [proj] is ordered according to the sites in structure. """ if self.is_metal(): return { @@ -331,33 +328,36 @@ def get_vbm(self): "energy": None, "projections": {}, } + max_tmp = -float("inf") index = kpoint_vbm = None for value in self.bands.values(): - for i, j in zip(*np.where(value < self.efermi)): - if value[i, j] > max_tmp: - max_tmp = float(value[i, j]) + for idx, j in zip(*np.where(value < self.efermi)): + if value[idx, j] > max_tmp: + max_tmp = float(value[idx, j]) index = j kpoint_vbm = self.kpoints[j] list_ind_kpts = [] - if kpoint_vbm.label is not None: - for i, kpt in enumerate(self.kpoints): + if kpoint_vbm is not None and kpoint_vbm.label is not None: + for idx, kpt in enumerate(self.kpoints): if kpt.label == kpoint_vbm.label: - list_ind_kpts.append(i) + list_ind_kpts.append(idx) else: list_ind_kpts.append(index) - # get all other bands sharing the vbm + + # Get all other bands sharing the VBM list_ind_band = defaultdict(list) for spin in self.bands: - for i in range(self.nb_bands): - if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001: - list_ind_band[spin].append(i) + for idx in range(self.nb_bands): + if math.fabs(self.bands[spin][idx][index] - max_tmp) < 0.001: + list_ind_band[spin].append(idx) proj = {} for spin, value in self.projections.items(): if len(list_ind_band[spin]) == 0: continue proj[spin] = value[list_ind_band[spin][0]][list_ind_kpts[0]] + return { "band_index": list_ind_band, "kpoint_index": list_ind_kpts, @@ -366,25 +366,25 @@ def get_vbm(self): "projections": proj, } - def get_cbm(self): - """Get data about the CBM. + def get_cbm(self) -> dict[str, Any]: + """Get data about the conduction band minimum (CBM). Returns: - dict[str, Any]: with keys band_index, kpoint_index, kpoint, energy. - - "band_index": A dict with spin keys pointing to a list of the + dict with keys "band_index", "kpoint_index", "kpoint", "energy": + - "band_index" (dict): A dict with spin keys pointing to a list of the indices of the band containing the CBM (please note that you - can have several bands sharing the CBM) {Spin.up:[], Spin.down:[]} + can have several bands sharing the CBM) {Spin.up:[], Spin.down:[]}. - "kpoint_index": The list of indices in self.kpoints for the kpoint CBM. Please note that there can be several kpoint_indices relating to the same kpoint (e.g., Gamma can - occur at different spots in the band structure line plot) - - "kpoint": The kpoint (as a kpoint object) - - "energy": The energy of the CBM + occur at different spots in the band structure line plot). + - "kpoint" (Kpoint): The kpoint. + - "energy" (float): The energy of the CBM. - "projections": The projections along sites and orbitals of the CBM if any projection data is available (else it is an empty dictionary). The format is similar to the projections field in BandStructure: {spin:{'Orbital': [proj]}} where the array - [proj] is ordered according to the sites in structure + [proj] is ordered according to the sites in structure. """ if self.is_metal(): return { @@ -394,30 +394,30 @@ def get_cbm(self): "energy": None, "projections": {}, } - max_tmp = float("inf") + max_tmp = float("inf") index = kpoint_cbm = None for value in self.bands.values(): - for i, j in zip(*np.where(value >= self.efermi)): - if value[i, j] < max_tmp: - max_tmp = float(value[i, j]) + for idx, j in zip(*np.where(value >= self.efermi)): + if value[idx, j] < max_tmp: + max_tmp = float(value[idx, j]) index = j kpoint_cbm = self.kpoints[j] list_index_kpoints = [] - if kpoint_cbm.label is not None: - for i, kpt in enumerate(self.kpoints): + if kpoint_cbm is not None and kpoint_cbm.label is not None: + for idx, kpt in enumerate(self.kpoints): if kpt.label == kpoint_cbm.label: - list_index_kpoints.append(i) + list_index_kpoints.append(idx) else: list_index_kpoints.append(index) - # get all other bands sharing the cbm + # Get all other bands sharing the CBM list_index_band = defaultdict(list) for spin in self.bands: - for i in range(self.nb_bands): - if math.fabs(self.bands[spin][i][index] - max_tmp) < 0.001: - list_index_band[spin].append(i) + for idx in range(self.nb_bands): + if math.fabs(self.bands[spin][idx][index] - max_tmp) < 0.001: + list_index_band[spin].append(idx) proj = {} for spin, value in self.projections.items(): if len(list_index_band[spin]) == 0: @@ -432,22 +432,25 @@ def get_cbm(self): "projections": proj, } - def get_band_gap(self): - r"""Get band gap data. + def get_band_gap(self) -> dict[str, Any]: + r"""Get band gap. Returns: - A dict {"energy","direct","transition"}: - "energy": band gap energy - "direct": A boolean telling if the gap is direct or not - "transition": kpoint labels of the transition (e.g., "\\Gamma-X") + dict with keys "energy", "direct", "transition": + "energy" (float): Band gap energy. + "direct" (bool): Whether the gap is direct. + "transition" (str): Kpoint labels of the transition (e.g., "\\Gamma-X"). """ if self.is_metal(): return {"energy": 0.0, "direct": False, "transition": None} + cbm = self.get_cbm() vbm = self.get_vbm() - result = {"direct": False, "energy": 0.0, "transition": None} - - result["energy"] = cbm["energy"] - vbm["energy"] + result = { + "direct": False, + "transition": None, + "energy": cbm["energy"] - vbm["energy"], + } if (cbm["kpoint"].label is not None and cbm["kpoint"].label == vbm["kpoint"].label) or np.linalg.norm( cbm["kpoint"].cart_coords - vbm["kpoint"].cart_coords @@ -463,16 +466,16 @@ def get_band_gap(self): return result - def get_direct_band_gap_dict(self): - """Get a dictionary of information about the direct - band gap. + def get_direct_band_gap_dict(self) -> dict[Spin, dict[str, Any]]: + """Get information about the direct band gap. Returns: - a dictionary of the band gaps indexed by spin - along with their band indices and k-point index + dict[Spin, dict[str, Any]]: The band gaps indexed by spin + along with their band indices and kpoint index. """ if self.is_metal(): raise ValueError("get_direct_band_gap_dict should only be used with non-metals") + direct_gap_dict = {} for spin, v in self.bands.items(): above = v[np.all(v > self.efermi, axis=1)] @@ -492,35 +495,42 @@ def get_direct_band_gap_dict(self): } return direct_gap_dict - def get_direct_band_gap(self): + def get_direct_band_gap(self) -> float: """Get the direct band gap. Returns: - the value of the direct band gap + float: The direct band gap value. """ if self.is_metal(): return 0.0 + dg = self.get_direct_band_gap_dict() return min(v["value"] for v in dg.values()) - def get_sym_eq_kpoints(self, kpoint, cartesian=False, tol: float = 1e-2): - """Get a list of unique symmetrically equivalent k-points. + def get_sym_eq_kpoints( + self, + kpoint: NDArray, + cartesian: bool = False, + tol: float = 1e-2, + ) -> NDArray: + """Get unique symmetrically equivalent Kpoints. Args: - kpoint (1x3 array): coordinate of the k-point - cartesian (bool): kpoint is in Cartesian or fractional coordinates - tol (float): tolerance below which coordinates are considered equal + kpoint (1x3 array): Coordinate of the Kpoint. + cartesian (bool): Whether kpoint is in Cartesian or fractional coordinates. + tol (float): Tolerance below which coordinates are considered equal. Returns: - list[1x3 array] | None: if structure is not available returns None + (1x3 NDArray) | None: None if structure is not available. """ if not self.structure: return None + sg = SpacegroupAnalyzer(self.structure) symm_ops = sg.get_point_group_operations(cartesian=cartesian) points = np.dot(kpoint, [m.rotation_matrix for m in symm_ops]) rm_list = [] - # identify and remove duplicates from the list of equivalent k-points: + # Identify and remove duplicates from equivalent k-points for i in range(len(points) - 1): for j in range(i + 1, len(points)): if np.allclose(pbc_diff(points[i], points[j]), [0, 0, 0], tol): @@ -528,25 +538,28 @@ def get_sym_eq_kpoints(self, kpoint, cartesian=False, tol: float = 1e-2): break return np.delete(points, rm_list, axis=0) - def get_kpoint_degeneracy(self, kpoint, cartesian=False, tol: float = 1e-2): - """Get degeneracy of a given k-point based on structure symmetry. + def get_kpoint_degeneracy( + self, + kpoint: NDArray, + cartesian: bool = False, + tol: float = 1e-2, + ) -> NDArray | None: + """Get degeneracy of a given kpoint based on structure symmetry. Args: - kpoint (1x3 array): coordinate of the k-point - cartesian (bool): kpoint is in Cartesian or fractional coordinates - tol (float): tolerance below which coordinates are considered equal. + kpoint (1x3 NDArray): Coordinate of the k-point. + cartesian (bool): Whether kpoint is in Cartesian or fractional coordinates. + tol (float): Tolerance below which coordinates are considered equal. Returns: - int | None: degeneracy or None if structure is not available + int | None: Degeneracy, or None if structure is not available. """ all_kpts = self.get_sym_eq_kpoints(kpoint, cartesian, tol=tol) - if all_kpts is not None: - return len(all_kpts) - return None + return len(all_kpts) if all_kpts is not None else None - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of BandStructure.""" - dct = { + dct: dict[str, Any] = { "@module": type(self).__module__, "@class": type(self).__name__, "lattice_rec": self.lattice_rec.as_dict(), @@ -554,7 +567,7 @@ def as_dict(self): "kpoints": [], } # kpoints are not kpoint objects dicts but are frac coords (this makes - # the dict smaller and avoids the repetition of the lattice + # the dict smaller and avoids the repetition of the lattice). for k in self.kpoints: dct["kpoints"].append(k.as_dict()["fcoords"]) @@ -578,39 +591,40 @@ def as_dict(self): dct["labels_dict"] = {} dct["is_spin_polarized"] = self.is_spin_polarized - # MongoDB does not accept keys starting with $. Add a blank space to fix the problem + # MongoDB does not accept keys starting with "$", add a space to fix this. for c, label in self.labels_dict.items(): - mongo_key = c if not c.startswith("$") else f" {c}" + mongo_key = f" {c}" if c.startswith("$") else c dct["labels_dict"][mongo_key] = label.as_dict()["fcoords"] dct["projections"] = {} - if len(self.projections) != 0: + if len(self.projections) != 0 and self.structure is not None: dct["structure"] = self.structure.as_dict() dct["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()} return dct @classmethod - def from_dict(cls, dct: dict) -> Self: - """Create from dict. + def from_dict(cls, dct: dict[str, Any]) -> Self: + """Create from a dict. Args: - dct: A dict with all data for a band structure object. + dct: A dict with all data for a BandStructure. Returns: - A BandStructure object + BandStructure """ # Strip the label to recover initial string - # (see trick used in as_dict to handle $ chars) + # (see trick used in as_dict to handle "$"" chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} - projections = {} - structure = None + if isinstance(next(iter(dct["bands"].values())), dict): eigenvals = {Spin(int(k)): np.array(dct["bands"][k]["data"]) for k in dct["bands"]} else: eigenvals = {Spin(int(k)): dct["bands"][k] for k in dct["bands"]} + structure = None if "structure" in dct: structure = Structure.from_dict(dct["structure"]) + projections = {} try: if dct.get("projections"): if isinstance(dct["projections"]["1"][0][0], dict): @@ -637,15 +651,16 @@ def from_dict(cls, dct: dict) -> Self: return cls.from_old_dict(dct) @classmethod - def from_old_dict(cls, dct) -> Self: + def from_old_dict(cls, dct: dict[str, Any]) -> Self: """ Args: - dct (dict): A dict with all data for a band structure symmetry line object. + dct (dict): A dict with all data for a BandStructure object. Returns: - A BandStructureSymmLine object + BandStructure """ - # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) + # Strip the label to recover initial string + # (see trick used in as_dict to handle "$" chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} projections: dict = {} structure = None @@ -686,37 +701,35 @@ class BandStructureSymmLine(BandStructure, MSONable): def __init__( self, - kpoints, - eigenvals, - lattice, - efermi, - labels_dict, - coords_are_cartesian=False, - structure=None, - projections=None, + kpoints: NDArray, + eigenvals: dict[Spin, list], + lattice: Lattice, + efermi: float, + labels_dict: dict[str, Kpoint], + coords_are_cartesian: bool = False, + structure: Structure | None = None, + projections: dict[Spin, NDArray] | None = None, ) -> None: """ Args: - kpoints: list of kpoint as numpy arrays, in frac_coords of the + kpoints (NDArray): Array of kpoint, in frac_coords of the given lattice by default - eigenvals: dict of energies for spin up and spin down + eigenvals (dict[Spin, list]): Energies for spin up and spin down {Spin.up:[][],Spin.down:[][]}, the first index of the array [][] refers to the band and the second to the index of the kpoint. The kpoints are ordered according to the order of the kpoints array. If the band structure is not spin polarized, we only store one data set under Spin.up. - lattice: The reciprocal lattice. - Pymatgen uses the physics convention of reciprocal lattice vectors - WITH a 2*pi coefficient - efermi: fermi energy - labels_dict: (dict) of {} this link a kpoint (in frac coords or - Cartesian coordinates depending on the coords). - coords_are_cartesian: Whether coordinates are cartesian. - structure: The crystal structure (as a pymatgen Structure object) - associated with the band structure. This is needed if we - provide projections to the band structure. - projections: dict of orbital projections as {spin: array}. The - indices of the array are [band_index, kpoint_index, orbital_index, + lattice (Lattice): The reciprocal lattice. Pymatgen uses the physics + convention of reciprocal lattice vectors with a 2*pi coefficient. + efermi (float): The Fermi level. + labels_dict (dict[str, Kpoint]): Dict mapping label to Kpoint. + coords_are_cartesian (bool): Whether coordinates are cartesian. + structure (Structure): The crystal structure associated with the + band structure. This is needed if we provide projections to + the band structure. + projections (dict[Spin, NDArray]): Orbital projections as {spin: array}. + The indices of the array are [band_index, kpoint_index, orbital_index, ion_index].If the band structure is not spin polarized, we only store one data set under Spin.up. """ @@ -734,7 +747,7 @@ def __init__( self.branches = [] one_group: list = [] branches_tmp = [] - # get labels and distance for each kpoint + # Get labels and distance for each kpoint previous_kpoint = self.kpoints[0] previous_distance = 0.0 @@ -769,44 +782,40 @@ def __init__( if len(self.bands) == 2: self.is_spin_polarized = True - def get_equivalent_kpoints(self, index): - """Get the list of kpoint indices equivalent (meaning they are the - same frac coords) to the given one. + def get_equivalent_kpoints(self, index: int) -> list[int]: + """Get kpoint indices equivalent (having the same coords) to the given one. Args: - index: the kpoint index + index (int): The kpoint index Returns: - a list of equivalent indices + list[int]: Equivalent indices. - TODO: now it uses the label we might want to use coordinates instead - (in case there was a mislabel) + TODO: now it uses the label, we might want to use coordinates + instead in case there was a mislabel. """ - # if the kpoint has no label it can't have a repetition along the band - # structure line object - + # If the kpoint has no label it can't have a repetition + # along the BandStructureSymmLine object if self.kpoints[index].label is None: return [index] list_index_kpoints = [] - for i, kpt in enumerate(self.kpoints): + for idx, kpt in enumerate(self.kpoints): if kpt.label == self.kpoints[index].label: - list_index_kpoints.append(i) + list_index_kpoints.append(idx) return list_index_kpoints - def get_branch(self, index): - r"""Get in what branch(es) is the kpoint. There can be several - branches. + def get_branch(self, index: int) -> list[dict[str, Any]]: + """Get what branch(es) is the kpoint. It takes into account the + fact that one kpoint (e.g., Gamma) can be in several branches. Args: - index: the kpoint index + index (int): The kpoint index. Returns: - A list of dictionaries [{"name","start_index","end_index","index"}] - indicating all branches in which the k_point is. It takes into - account the fact that one kpoint (e.g., \\Gamma) can be in several - branches + A list of dicts [{"name", "start_index", "end_index", "index"}] + indicating all branches in which the k_point is. """ to_return = [] for idx in self.get_equivalent_kpoints(index): @@ -822,19 +831,19 @@ def get_branch(self, index): ) return to_return - def apply_scissor(self, new_band_gap): + def apply_scissor(self, new_band_gap: float) -> Self: """Apply a scissor operator (shift of the CBM) to fit the given band gap. If it's a metal, we look for the band crossing the Fermi level and shift this one up. This will not work all the time for metals! Args: - new_band_gap: the band gap the scissor band structure need to have. + new_band_gap (float): The band gap the scissor band structure need to have. Returns: - BandStructureSymmLine: with the applied scissor shift + BandStructureSymmLine: With the applied scissor shift. """ if self.is_metal(): - # moves then the highest index band crossing the Fermi level find this band... + # Move then the highest index band crossing the Fermi level find this band... max_index = -1000 # spin_index = None for idx in range(self.nb_bands): @@ -866,6 +875,7 @@ def apply_scissor(self, new_band_gap): for v in range(len(old_dict["bands"][spin][k])): if k >= max_index: old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift + else: shift = new_band_gap - self.get_band_gap()["energy"] old_dict = self.as_dict() @@ -875,9 +885,10 @@ def apply_scissor(self, new_band_gap): if old_dict["bands"][spin][k][v] >= old_dict["cbm"]["energy"]: old_dict["bands"][spin][k][v] = old_dict["bands"][spin][k][v] + shift old_dict["efermi"] = old_dict["efermi"] + shift + return self.from_dict(old_dict) - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of BandStructureSymmLine.""" dct = super().as_dict() dct["branches"] = self.branches @@ -885,11 +896,11 @@ def as_dict(self): class LobsterBandStructureSymmLine(BandStructureSymmLine): - """Lobster subclass of BandStructure with customized functions.""" + """LOBSTER subclass of BandStructure with customized functions.""" - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of BandStructureSymmLine.""" - dct = { + dct: dict[str, Any] = { "@module": type(self).__module__, "@class": type(self).__name__, "lattice_rec": self.lattice_rec.as_dict(), @@ -920,27 +931,28 @@ def as_dict(self): dct["band_gap"] = self.get_band_gap() dct["labels_dict"] = {} dct["is_spin_polarized"] = self.is_spin_polarized - # MongoDB does not accept keys starting with $. Add a blank space to fix the problem + + # MongoDB does not accept keys starting with "$", add a space to fix this. for c, label in self.labels_dict.items(): - mongo_key = c if not c.startswith("$") else " " + c + mongo_key = f" {c}" if c.startswith("$") else c dct["labels_dict"][mongo_key] = label.as_dict()["fcoords"] - if len(self.projections) != 0: + if len(self.projections) != 0 and self.structure is not None: dct["structure"] = self.structure.as_dict() dct["projections"] = {str(int(spin)): np.array(v).tolist() for spin, v in self.projections.items()} return dct @classmethod - def from_dict(cls, dct: dict) -> Self: + def from_dict(cls, dct: dict[str, Any]) -> Self: """ Args: - dct (dict): A dict with all data for a band structure symmetry line - object. + dct (dict): All data for a LobsterBandStructureSymmLine object. Returns: - A BandStructureSymmLine object + A LobsterBandStructureSymmLine object. """ try: - # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) + # Strip the label to recover initial string + # (see trick used in as_dict to handle "$" chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} projections = {} structure = None @@ -969,16 +981,16 @@ def from_dict(cls, dct: dict) -> Self: return cls.from_old_dict(dct) @classmethod - def from_old_dict(cls, dct) -> Self: + def from_old_dict(cls, dct: dict[str, Any]) -> Self: """ Args: - dct (dict): A dict with all data for a band structure symmetry line - object. + dct (dict): All data for a LobsterBandStructureSymmLine object. Returns: - A BandStructureSymmLine object + A LobsterBandStructureSymmLine object """ - # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) + # Strip the label to recover initial string + # (see trick used in as_dict to handle "$" chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} projections: dict = {} structure = None @@ -1004,19 +1016,18 @@ def from_old_dict(cls, dct) -> Self: projections=projections, ) - def get_projection_on_elements(self): - """Get a dictionary of projections on elements. It sums over all available orbitals + def get_projection_on_elements(self) -> dict[Spin, list]: + """Get projections on elements. It sums over all available orbitals for each element. Returns: - a dictionary in the {Spin.up:[][{Element:values}], - Spin.down:[][{Element:values}]} format - if there is no projections in the band structure - returns an empty dict + dict[Spin, list]: dict in the {Spin.up:[][{Element:values}], + Spin.down:[][{Element:values}]} format. + If there is no projections in the band structure, return {}. """ - result = {} + result: dict[Spin, list] = {} for spin, v in self.projections.items(): - result[spin] = [[defaultdict(float) for i in range(len(self.kpoints))] for j in range(self.nb_bands)] + result[spin] = [[defaultdict(float) for _ in range(len(self.kpoints))] for _ in range(self.nb_bands)] for i, j in itertools.product(range(self.nb_bands), range(len(self.kpoints))): for key, item in v[i][j].items(): for item2 in item.values(): @@ -1024,13 +1035,17 @@ def get_projection_on_elements(self): result[spin][i][j][specie] += item2 return result - def get_projections_on_elements_and_orbitals(self, el_orb_spec): - """Return a dictionary of projections on elements and specific orbitals. + def get_projections_on_elements_and_orbitals( + self, + el_orb_spec: dict[Element, list], # type: ignore[override] + ) -> dict[Spin, list]: + """Get projections on elements and specific orbitals. Args: - el_orb_spec: A dictionary of Elements and Orbitals for which we want - to have projections on. It is given as: {Element:[orbitals]}, - e.g. {'Si':['3s','3p']} or {'Si':['3s','3p_x', '3p_y', '3p_z']} depending on input files + el_orb_spec (dict): Elements and Orbitals for which we want + to project on. It is given as {Element: [orbitals]}, + e.g. {"Si": ["3s", "3p"]} or {"Si": ["3s", "3p_x", "3p_y", "3p_z']} + depending on input files. Returns: A dictionary of projections on elements in the @@ -1039,12 +1054,12 @@ def get_projections_on_elements_and_orbitals(self, el_orb_spec): if there is no projections in the band structure returns an empty dict. """ - result = {} + result: dict[Spin, list] = {} el_orb_spec = {get_el_sp(el): orbs for el, orbs in el_orb_spec.items()} for spin, v in self.projections.items(): result[spin] = [ - [{str(e): defaultdict(float) for e in el_orb_spec} for i in range(len(self.kpoints))] - for j in range(self.nb_bands) + [{str(e): defaultdict(float) for e in el_orb_spec} for _ in range(len(self.kpoints))] + for _ in range(self.nb_bands) ] for i, j in itertools.product(range(self.nb_bands), range(len(self.kpoints))): @@ -1056,39 +1071,52 @@ def get_projections_on_elements_and_orbitals(self, el_orb_spec): return result -def get_reconstructed_band_structure(list_bs, efermi=None): - """Take a list of band structures and reconstructs - one band structure object from all of them. +@overload +def get_reconstructed_band_structure( # type: ignore[overload-overlap] + list_bs: list[BandStructure], + efermi: float | None = None, +) -> BandStructure: + pass - This is typically very useful when you split non self consistent - band structure runs in several independent jobs and want to merge back - the results + +@overload +def get_reconstructed_band_structure( + list_bs: list[BandStructureSymmLine], + efermi: float | None = None, +) -> BandStructureSymmLine: + pass + + +def get_reconstructed_band_structure( + list_bs: list[BandStructure] | list[BandStructureSymmLine], + efermi: float | None = None, +) -> BandStructure | BandStructureSymmLine: + """Merge multiple BandStructure(SymmLine) objects to a single one. + + This is typically useful when you split non self-consistent band + structure runs to several independent jobs and want to merge the results. Args: - list_bs: A list of BandStructure or BandStructureSymmLine objects. - efermi: The Fermi energy of the reconstructed band structure. If - None is assigned an average of all the Fermi energy in each + list_bs (list): BandStructure or BandStructureSymmLine objects. + efermi (float): The Fermi level of the reconstructed band structure. + If None, an average of all the Fermi levels in each object in the list_bs is used. Returns: A BandStructure or BandStructureSymmLine object (depending on - the type of the list_bs objects) + the type of the objects in list_bs). """ if efermi is None: efermi = sum(b.efermi for b in list_bs) / len(list_bs) - kpoints = [] - labels_dict = {} rec_lattice = list_bs[0].lattice_rec nb_bands = min(list_bs[i].nb_bands for i in range(len(list_bs))) - kpoints = np.concatenate([[k.frac_coords for k in bs.kpoints] for bs in list_bs]) + kpoints = np.concatenate([[kpt.frac_coords for kpt in bs.kpoints] for bs in list_bs]) dicts = [bs.labels_dict for bs in list_bs] - labels_dict = {k: v.frac_coords for d in dicts for k, v in d.items()} - - eigenvals = {} - eigenvals[Spin.up] = np.concatenate([bs.bands[Spin.up][:nb_bands] for bs in list_bs], axis=1) + labels_dict = {key: val.frac_coords for dct in dicts for key, val in dct.items()} + eigenvals = {Spin.up: np.concatenate([bs.bands[Spin.up][:nb_bands] for bs in list_bs], axis=1)} if list_bs[0].is_spin_polarized: eigenvals[Spin.down] = np.concatenate([bs.bands[Spin.down][:nb_bands] for bs in list_bs], axis=1) diff --git a/src/pymatgen/electronic_structure/boltztrap.py b/src/pymatgen/electronic_structure/boltztrap.py index 7f9e60f7db7..c1db968f2a4 100644 --- a/src/pymatgen/electronic_structure/boltztrap.py +++ b/src/pymatgen/electronic_structure/boltztrap.py @@ -28,6 +28,10 @@ from monty.dev import requires from monty.json import MSONable, jsanitize from monty.os import cd +from scipy import constants +from scipy.optimize import fsolve +from scipy.spatial import distance + from pymatgen.core.lattice import Lattice from pymatgen.core.units import Energy, Length from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine, Kpoint @@ -35,17 +39,15 @@ from pymatgen.electronic_structure.dos import CompleteDos, Dos, Spin from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.symmetry.bandstructure import HighSymmKpath -from scipy import constants -from scipy.optimize import fsolve -from scipy.spatial import distance if TYPE_CHECKING: from typing import Literal from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.core.sites import PeriodicSite from pymatgen.core.structure import Structure - from typing_extensions import Self __author__ = "Geoffroy Hautier, Zachary Gibbs, Francesco Ricci, Anubhav Jain" __copyright__ = "Copyright 2013, The Materials Project" @@ -1936,7 +1938,7 @@ def from_files(cls, path_dir: str, dos_spin: Literal[-1, 1] = 1) -> Self: dos_spin: in DOS mode, set to 1 for spin up and -1 for spin down Returns: - a BoltztrapAnalyzer object + BoltztrapAnalyzer """ run_type, warning, efermi, gap, doping_levels = cls.parse_outputtrans(path_dir) diff --git a/src/pymatgen/electronic_structure/boltztrap2.py b/src/pymatgen/electronic_structure/boltztrap2.py index d84df06e2a9..49a6a34b3e4 100644 --- a/src/pymatgen/electronic_structure/boltztrap2.py +++ b/src/pymatgen/electronic_structure/boltztrap2.py @@ -1,29 +1,27 @@ -"""BoltzTraP2 is a python software interpolating band structures and -computing materials properties from dft band structure using Boltzmann -semi-classical transport theory. -This module provides a pymatgen interface to BoltzTraP2. +"""This module provides an interface to BoltzTraP2. Some of the code is written following the examples provided in BoltzTraP2. -BoltzTraP2 has been developed by Georg Madsen, Jesús Carrete, Matthieu J. Verstraete. +BoltzTraP2 is a Python software interpolating band structures and +computing materials properties from DFT band structure using Boltzmann +semi-classical transport theory, developed by Georg Madsen, Jesús Carrete, +Matthieu J. Verstraete. https://gitlab.com/sousaw/BoltzTraP2 https://www.sciencedirect.com/science/article/pii/S0010465518301632 -References are: - +References: Georg K.H.Madsen, Jesús Carrete, Matthieu J.Verstraete BoltzTraP2, a program for interpolating band structures and calculating semi-classical transport coefficients - Computer Physics Communications 231, 140-145, 2018 + Computer Physics Communications 231, 140-145, 2018. Madsen, G. K. H., and Singh, D. J. (2006). BoltzTraP. A code for calculating band-structure dependent quantities. - Computer Physics Communications, 175, 67-71 + Computer Physics Communications, 175, 67-71. Todo: -- DONE: spin polarized bands -- read first derivative of the eigenvalues from vasprun.xml (mommat) -- handle magnetic moments (magmom) +- Read first derivative of the eigenvalues from vasprun.xml (mommat) +- Handle magnetic moments (MAGMOM) """ from __future__ import annotations @@ -34,6 +32,8 @@ import matplotlib.pyplot as plt import numpy as np from monty.serialization import dumpfn, loadfn +from tqdm import tqdm + from pymatgen.electronic_structure.bandstructure import BandStructure, BandStructureSymmLine, Spin from pymatgen.electronic_structure.boltztrap import BoltztrapError from pymatgen.electronic_structure.dos import CompleteDos, Dos, Orbital @@ -41,10 +41,10 @@ from pymatgen.io.ase import AseAtomsAdaptor from pymatgen.io.vasp import Vasprun from pymatgen.symmetry.bandstructure import HighSymmKpath -from tqdm import tqdm if TYPE_CHECKING: from pathlib import Path + from typing import Literal from typing_extensions import Self @@ -962,20 +962,20 @@ def __init__(self, bzt_transP=None, bzt_interp=None) -> None: def plot_props( self, - prop_y, - prop_x, - prop_z="temp", - output="avg_eigs", - dop_type="n", - doping=None, - temps=None, - xlim=(-2, 2), - ax: plt.Axes = None, - ): + prop_y: str, + prop_x: Literal["mu", "doping", "temp"], + prop_z: Literal["doping", "temp"] = "temp", + output: Literal["avg_eigs", "eigs"] = "avg_eigs", + dop_type: Literal["n", "p"] = "n", + doping: list[float] | None = None, + temps: list[float] | None = None, + xlim: tuple[float, float] = (-2, 2), + ax: plt.Axes | None = None, + ) -> plt.Axes | plt.Figure: """Plot the transport properties. Args: - prop_y: property to plot among ("Conductivity","Seebeck","Kappa","Carrier_conc", + prop_y: property to plot among ("Conductivity", "Seebeck", "Kappa", "Carrier_conc", "Hall_carrier_conc_trace"). Abbreviations are possible, like "S" for "Seebeck" prop_x: independent variable in the x-axis among ('mu','doping','temp') prop_z: third variable to plot multiple curves ('doping','temp') @@ -992,7 +992,9 @@ def plot_props( ax: figure.axes where to plot. If None, a new figure is produced. Returns: - plt.Axes: matplotlib Axes object + plt.Axes: matplotlib Axes object if ax provided + OR + plt.Figure: matplotlib Figure object if ax is None Example: bztPlotter.plot_props('S','mu','temp',temps=[600,900,1200]).show() @@ -1027,15 +1029,15 @@ def plot_props( r"$(cm^{-3})$", ) - props_short = [p[: len(prop_y)] for p in props] + props_short = tuple(p[: len(prop_y)] for p in props) if prop_y not in props_short: raise BoltztrapError("prop_y not valid") - if prop_x not in ("mu", "doping", "temp"): + if prop_x not in {"mu", "doping", "temp"}: raise BoltztrapError("prop_x not valid") - if prop_z not in ("doping", "temp"): + if prop_z not in {"doping", "temp"}: raise BoltztrapError("prop_z not valid") idx_prop = props_short.index(prop_y) @@ -1049,8 +1051,7 @@ def plot_props( else: p_array = getattr(self.bzt_transP, f"{props[idx_prop]}_{prop_x}") - if ax is None: - plt.figure(figsize=(10, 8)) + fig = plt.figure(figsize=(10, 8)) if ax is None else None temps_all = self.bzt_transP.temp_r.tolist() if temps is None: @@ -1113,6 +1114,9 @@ def plot_props( leg_title = f"{dop_type}-type" elif prop_z == "doping" and prop_x == "temp": + if doping is None: + raise ValueError("doping cannot be None when prop_z is doping") + for dop in doping: dop_idx = doping_all.index(dop) prop_out = np.linalg.eigh(p_array[dop_type][:, dop_idx])[0] @@ -1138,10 +1142,11 @@ def plot_props( plt.ylabel(f"{props_lbl[idx_prop]} {props_unit[idx_prop]}", fontsize=30) plt.xticks(fontsize=25) plt.yticks(fontsize=25) - plt.legend(title=leg_title if leg_title != "" else "", fontsize=15) + plt.legend(title=leg_title or "", fontsize=15) plt.tight_layout() plt.grid() - return ax + + return fig if ax is None else ax def plot_bands(self): """Plot a band structure on symmetry line using BSPlotter().""" diff --git a/src/pymatgen/electronic_structure/cohp.py b/src/pymatgen/electronic_structure/cohp.py index 17bcdce0be5..be757cbce95 100644 --- a/src/pymatgen/electronic_structure/cohp.py +++ b/src/pymatgen/electronic_structure/cohp.py @@ -1,6 +1,8 @@ -"""This module defines classes to represent crystal orbital Hamilton -populations (COHP) and integrated COHP (ICOHP), but can also be used -for crystal orbital overlap populations (COOP) or crystal orbital bond indices (COBIs). +"""This module defines classes to represent: + - Crystal orbital Hamilton population (COHP) and integrated COHP (ICOHP). + - Crystal orbital overlap population (COOP). + - Crystal orbital bond index (COBI). + If you use this module, please cite: J. George, G. Petretto, A. Naik, M. Esters, A. J. Jackson, R. Nelson, R. Dronskowski, G.-M. Rignanese, G. Hautier, "Automated Bonding Analysis with Crystal Orbital Hamilton Populations", @@ -17,6 +19,8 @@ import numpy as np from monty.json import MSONable +from scipy.interpolate import InterpolatedUnivariateSpline + from pymatgen.core.sites import PeriodicSite from pymatgen.core.structure import Structure from pymatgen.electronic_structure.core import Orbital, Spin @@ -25,13 +29,16 @@ from pymatgen.util.coord import get_linear_interpolated_value from pymatgen.util.due import Doi, due from pymatgen.util.num import round_to_sigfigs -from scipy.interpolate import InterpolatedUnivariateSpline if TYPE_CHECKING: - from typing import Any + from collections.abc import Sequence + from typing import Any, Literal + from numpy.typing import NDArray from typing_extensions import Self + from pymatgen.util.typing import PathLike, SpinLike, Vector3D + __author__ = "Marco Esters, Janine George" __copyright__ = "Copyright 2017, The Materials Project" __version__ = "0.2" @@ -49,17 +56,24 @@ class Cohp(MSONable): """Basic COHP object.""" def __init__( - self, efermi, energies, cohp, are_coops=False, are_cobis=False, are_multi_center_cobis=False, icohp=None + self, + efermi: float, + energies: Sequence[float], + cohp: dict[Spin, NDArray], + are_coops: bool = False, + are_cobis: bool = False, + are_multi_center_cobis: bool = False, + icohp: dict[Spin, NDArray] | None = None, ) -> None: """ Args: - are_coops: Indicates whether this object describes COOPs. - are_cobis: Indicates whether this object describes COBIs. - are_multi_center_cobis: Indicates whether this object describes multi-center COBIs - efermi: Fermi energy. - energies: A sequence of energies. - cohp ({Spin: np.array}): representing the COHP for each spin. - icohp ({Spin: np.array}): representing the ICOHP for each spin. + efermi (float): The Fermi level. + energies (Sequence[float]): Energies. + cohp ({Spin: NDArrary}): The COHP for each spin. + are_coops (bool): Whether this object describes COOPs. + are_cobis (bool): Whether this object describes COBIs. + are_multi_center_cobis (bool): Whether this object describes multi-center COBIs. + icohp ({Spin: NDArrary}): The ICOHP for each spin. """ self.are_coops = are_coops self.are_cobis = are_cobis @@ -70,7 +84,7 @@ def __init__( self.icohp = icohp def __repr__(self) -> str: - """Get a string that can be easily plotted (e.g. using gnuplot).""" + """A string that can be easily plotted (e.g. using gnuplot).""" if self.are_coops: cohp_str = "COOP" elif self.are_cobis or self.are_multi_center_cobis: @@ -96,7 +110,7 @@ def __repr__(self) -> str: str_arr.append(format_data.format(*(d[idx] for d in data))) return "\n".join(str_arr) - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of COHP.""" dct = { "@module": type(self).__module__, @@ -112,20 +126,22 @@ def as_dict(self): dct["ICOHP"] = {str(spin): pops.tolist() for spin, pops in self.icohp.items()} return dct - def get_cohp(self, spin=None, integrated=False): + def get_cohp( + self, + spin: SpinLike | None = None, + integrated: bool = False, + ) -> dict[Spin, NDArray] | None: """Get the COHP or ICOHP for a particular spin. Args: - spin: Spin. Can be parsed as spin object, integer (-1/1) - or str ("up"/"down") - integrated: Return COHP (False) or ICOHP (True) + spin (SpinLike): Selected spin. If is None and both + spins are present, both will be returned. + integrated: Return ICOHP (True) or COHP (False). Returns: - Returns the CHOP or ICOHP for the input spin. If Spin is - None and both spins are present, both spins will be returned - as a dictionary. + dict: The COHP or ICOHP for the selected spin. """ - populations = self.cohp if not integrated else self.icohp + populations = self.icohp if integrated else self.cohp if populations is None: return None @@ -137,118 +153,127 @@ def get_cohp(self, spin=None, integrated=False): spin = Spin({"up": 1, "down": -1}[spin.lower()]) return {spin: populations[spin]} - def get_icohp(self, spin=None): - """Convenient alternative to get the ICOHP for a particular spin.""" + def get_icohp( + self, + spin: SpinLike | None = None, + ) -> dict[Spin, NDArray] | None: + """Convenient wrapper to get the ICOHP for a particular spin.""" return self.get_cohp(spin=spin, integrated=True) - def get_interpolated_value(self, energy, integrated=False): - """Get the COHP for a particular energy. + def get_interpolated_value( + self, + energy: float, + integrated: bool = False, + ) -> dict[Spin, float]: + """Get the interpolated COHP for a particular energy. Args: - energy: Energy to return the COHP value for. - integrated: Return COHP (False) or ICOHP (True) + energy (float): Energy to get the COHP value for. + integrated (bool): Return ICOHP (True) or COHP (False). """ - inter = {} + inters = {} for spin in self.cohp: if not integrated: - inter[spin] = get_linear_interpolated_value(self.energies, self.cohp[spin], energy) + inters[spin] = get_linear_interpolated_value(self.energies, self.cohp[spin], energy) elif self.icohp is not None: - inter[spin] = get_linear_interpolated_value(self.energies, self.icohp[spin], energy) + inters[spin] = get_linear_interpolated_value(self.energies, self.icohp[spin], energy) else: raise ValueError("ICOHP is empty.") - return inter + return inters + + def has_antibnd_states_below_efermi( + self, + spin: SpinLike | None = None, + limit: float = 0.01, + ) -> dict[Spin, bool] | None: + """Get dict of antibonding states below the Fermi level for the spin. - def has_antibnd_states_below_efermi(self, spin=None, limit=0.01): - """Get dict indicating if there are antibonding states below the Fermi level depending on the spin - spin: Spin - limit: -COHP smaller -limit will be considered. + Args: + spin (SpinLike): Selected spin. + limit (float): Only COHP higher than this value will be considered. """ populations = self.cohp n_energies_below_efermi = len([energy for energy in self.energies if energy <= self.efermi]) if populations is None: return None + + dict_to_return = {} if spin is None: - dict_to_return = {} for sp, cohp_vals in populations.items(): - if (max(cohp_vals[:n_energies_below_efermi])) > limit: - dict_to_return[sp] = True - else: - dict_to_return[sp] = False + # NOTE: Casting to bool is necessary, otherwise ended up + # getting "bool_" instead of "bool" from NumPy + dict_to_return[sp] = bool((max(cohp_vals[:n_energies_below_efermi])) > limit) + else: - dict_to_return = {} if isinstance(spin, int): spin = Spin(spin) elif isinstance(spin, str): spin = Spin({"up": 1, "down": -1}[spin.lower()]) - if (max(populations[spin][:n_energies_below_efermi])) > limit: - dict_to_return[spin] = True - else: - dict_to_return[spin] = False + dict_to_return[spin] = bool((max(populations[spin][:n_energies_below_efermi])) > limit) return dict_to_return @classmethod def from_dict(cls, dct: dict[str, Any]) -> Self: - """Get a COHP object from a dict representation of the COHP.""" + """Generate Cohp from a dict representation.""" icohp = {Spin(int(key)): np.array(val) for key, val in dct["ICOHP"].items()} if "ICOHP" in dct else None - are_cobis = dct.get("are_cobis", False) - are_multi_center_cobis = dct.get("are_multi_center_cobis", False) + return cls( dct["efermi"], dct["energies"], {Spin(int(key)): np.array(val) for key, val in dct["COHP"].items()}, icohp=icohp, are_coops=dct["are_coops"], - are_cobis=are_cobis, - are_multi_center_cobis=are_multi_center_cobis, + are_cobis=dct.get("are_cobis", False), + are_multi_center_cobis=dct.get("are_multi_center_cobis", False), ) class CompleteCohp(Cohp): - """A wrapper class that defines an average COHP, and individual COHPs. + """A wrapper that defines an average COHP, and individual COHPs. Attributes: - are_coops (bool): Indicates whether the object is consisting of COOPs. - are_cobis (bool): Indicates whether the object is consisting of COBIs. - efermi (float): Fermi energy. + are_coops (bool): Whether the object consists of COOPs. + are_cobis (bool): Whether the object consists of COBIs. + efermi (float): The Fermi level. energies (Sequence[float]): Sequence of energies. - structure (pymatgen.Structure): Structure associated with the COHPs. + structure (Structure): Structure associated with the COHPs. cohp (Sequence[float]): The average COHP. icohp (Sequence[float]): The average ICOHP. - all_cohps (dict[str, Sequence[float]]): A dict of COHPs for individual bonds of the form {label: COHP}. + all_cohps (dict[str, Sequence[float]]): COHPs for individual bonds of the form {label: COHP}. orb_res_cohp (dict[str, Dict[str, Sequence[float]]]): Orbital-resolved COHPs. """ def __init__( self, - structure, - avg_cohp, - cohp_dict, - bonds=None, - are_coops=False, - are_cobis=False, - are_multi_center_cobis=False, - orb_res_cohp=None, + structure: Structure, + avg_cohp: Cohp, + cohp_dict: dict[str, Cohp], + bonds: dict[str, Any] | None = None, + are_coops: bool = False, + are_cobis: bool = False, + are_multi_center_cobis: bool = False, + orb_res_cohp: dict[str, dict] | None = None, ) -> None: """ Args: - structure: Structure associated with this COHP. - avg_cohp: The average cohp as a COHP object. - cohp_dict: A dict of COHP objects for individual bonds of the form - {label: COHP} - bonds: A dict containing information on the bonds of the form - {label: {key: val}}. The key-val pair can be any information - the user wants to put in, but typically contains the sites, - the bond length, and the number of bonds. If nothing is + structure (Structure): Structure associated with this COHP. + avg_cohp (Cohp): The average COHP. + cohp_dict (dict[str, Cohp]): COHP for individual bonds of the form + {label: COHP}. + bonds (dict[str, Any]): Information on the bonds of the form + {label: {key: val}}. The value can be any information, + but typically contains the sites, the bond length, + and the number of bonds. If nothing is supplied, it will default to an empty dict. - are_coops: indicates whether the Cohp objects are COOPs. + are_coops (bool): Whether the Cohp objects are COOPs. Defaults to False for COHPs. - are_cobis: indicates whether the Cohp objects are COBIs. + are_cobis (bool): Whether the Cohp objects are COBIs. Defaults to False for COHPs. - are_multi_center_cobis: indicates whether the Cohp objects are multi-center COBIs. + are_multi_center_cobis (bool): Whether the Cohp objects are multi-center COBIs. Defaults to False for COHPs. - orb_res_cohp: Orbital-resolved COHPs. + orb_res_cohp (dict): Orbital-resolved COHPs. """ if ( (are_coops and are_cobis) @@ -256,6 +281,7 @@ def __init__( or (are_cobis and are_multi_center_cobis) ): raise ValueError("You cannot have info about COOPs, COBIs and/or multi-center COBIS in the same file.") + super().__init__( avg_cohp.efermi, avg_cohp.energies, @@ -275,12 +301,15 @@ def __init__( def __str__(self) -> str: if self.are_coops: - return f"Complete COOPs for {self.structure}" - if self.are_cobis: - return f"Complete COBIs for {self.structure}" - return f"Complete COHPs for {self.structure}" + header = "COOPs" + elif self.are_cobis: + header = "COBIs" + else: + header = "COHPs" + + return f"Complete {header} for {self.structure}" - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of CompleteCohp.""" dct = { "@module": type(self).__module__, @@ -298,16 +327,14 @@ def as_dict(self): dct["ICOHP"] = {"average": {str(spin): pops.tolist() for spin, pops in self.icohp.items()}} for label in self.all_cohps: - dct["COHP"].update({label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].cohp.items()}}) - if self.all_cohps[label].icohp is not None: + dct["COHP"] |= {label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].cohp.items()}} + icohp = self.all_cohps[label].icohp + if icohp is not None: if "ICOHP" not in dct: - dct["ICOHP"] = { - label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].icohp.items()} - } + dct["ICOHP"] = {label: {str(spin): pops.tolist() for spin, pops in icohp.items()}} else: - dct["ICOHP"].update( - {label: {str(spin): pops.tolist() for spin, pops in self.all_cohps[label].icohp.items()}} - ) + dct["ICOHP"] |= {label: {str(spin): pops.tolist() for spin, pops in icohp.items()}} + if False in [bond_dict == {} for bond_dict in self.bonds.values()]: dct["bonds"] = { bond: { @@ -316,44 +343,55 @@ def as_dict(self): } for bond in self.bonds } + if self.orb_res_cohp: - orb_dict = {} + orb_dict: dict[str, Any] = {} for label in self.orb_res_cohp: orb_dict[label] = {} for orbs in self.orb_res_cohp[label]: - cohp = {str(spin): pops.tolist() for spin, pops in self.orb_res_cohp[label][orbs]["COHP"].items()} - orb_dict[label][orbs] = {"COHP": cohp} - icohp = {str(spin): pops.tolist() for spin, pops in self.orb_res_cohp[label][orbs]["ICOHP"].items()} - orb_dict[label][orbs]["ICOHP"] = icohp - orbitals = [[orb[0], orb[1].name] for orb in self.orb_res_cohp[label][orbs]["orbitals"]] - orb_dict[label][orbs]["orbitals"] = orbitals + orb_dict[label][orbs] = { + "COHP": { + str(spin): pops.tolist() for spin, pops in self.orb_res_cohp[label][orbs]["COHP"].items() + }, + "ICOHP": { + str(spin): pops.tolist() for spin, pops in self.orb_res_cohp[label][orbs]["ICOHP"].items() + }, + "orbitals": [[orb[0], orb[1].name] for orb in self.orb_res_cohp[label][orbs]["orbitals"]], + } + dct["orb_res_cohp"] = orb_dict return dct - def get_cohp_by_label(self, label, summed_spin_channels=False): - """Get specific COHP object. + def get_cohp_by_label( + self, + label: str, + summed_spin_channels: bool = False, + ) -> Cohp: + """Get specific Cohp by the label, to simplify plotting. Args: - label: string (for newer Lobster versions: a number) - summed_spin_channels: bool, will sum the spin channels and return the sum in Spin.up if true + label (str): Label for the interaction. + summed_spin_channels (bool): Sum the spin channels and return the sum as Spin.up. Returns: - Returns the COHP object to simplify plotting + The Cohp. """ if label.lower() == "average": - divided_cohp = self.cohp - divided_icohp = self.icohp - + divided_cohp: dict[Spin, Any] | None = self.cohp + divided_icohp: dict[Spin, Any] | None = self.icohp else: divided_cohp = self.all_cohps[label].get_cohp(spin=None, integrated=False) divided_icohp = self.all_cohps[label].get_icohp(spin=None) + assert divided_cohp is not None + if summed_spin_channels and Spin.down in self.cohp: - final_cohp = {} - final_icohp = {} - final_cohp[Spin.up] = np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0) - final_icohp[Spin.up] = np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0) + assert divided_icohp is not None + final_cohp: dict[Spin, Any] = {Spin.up: np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0)} + final_icohp: dict[Spin, Any] | None = { + Spin.up: np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0) + } else: final_cohp = divided_cohp final_icohp = divided_icohp @@ -367,46 +405,50 @@ def get_cohp_by_label(self, label, summed_spin_channels=False): icohp=final_icohp, ) - def get_summed_cohp_by_label_list(self, label_list, divisor=1, summed_spin_channels=False): - """Get a COHP object that includes a summed COHP divided by divisor. + def get_summed_cohp_by_label_list( + self, + label_list: list[str], + divisor: float = 1, + summed_spin_channels: bool = False, + ) -> Cohp: + """Get a Cohp object that includes a summed COHP divided by divisor. Args: - label_list: list of labels for the COHP that should be included in the summed cohp - divisor: float/int, the summed cohp will be divided by this divisor - summed_spin_channels: bool, will sum the spin channels and return the sum in Spin.up if true + label_list (list[str]): Labels for the COHP to include. + divisor (float): The summed COHP will be divided by this divisor. + summed_spin_channels (bool): Sum the spin channels and return the sum in Spin.up. Returns: - Returns a COHP object including a summed COHP + A Cohp object for the summed COHP. """ - # check if cohps are spinpolarized or not + # Check if COHPs are spin polarized first_cohpobject = self.get_cohp_by_label(label_list[0]) summed_cohp = first_cohpobject.cohp.copy() + assert first_cohpobject.icohp is not None summed_icohp = first_cohpobject.icohp.copy() for label in label_list[1:]: - cohp_here = self.get_cohp_by_label(label) - summed_cohp[Spin.up] = np.sum([summed_cohp[Spin.up], cohp_here.cohp[Spin.up]], axis=0) + cohp = self.get_cohp_by_label(label) + icohp = cohp.icohp + assert icohp is not None + summed_cohp[Spin.up] = np.sum([summed_cohp[Spin.up], cohp.cohp[Spin.up]], axis=0) if Spin.down in summed_cohp: - summed_cohp[Spin.down] = np.sum([summed_cohp[Spin.down], cohp_here.cohp[Spin.down]], axis=0) + summed_cohp[Spin.down] = np.sum([summed_cohp[Spin.down], cohp.cohp[Spin.down]], axis=0) - summed_icohp[Spin.up] = np.sum([summed_icohp[Spin.up], cohp_here.icohp[Spin.up]], axis=0) + summed_icohp[Spin.up] = np.sum([summed_icohp[Spin.up], icohp[Spin.up]], axis=0) if Spin.down in summed_icohp: - summed_icohp[Spin.down] = np.sum([summed_icohp[Spin.down], cohp_here.icohp[Spin.down]], axis=0) + summed_icohp[Spin.down] = np.sum([summed_icohp[Spin.down], icohp[Spin.down]], axis=0) - divided_cohp = {} - divided_icohp = {} - divided_cohp[Spin.up] = np.divide(summed_cohp[Spin.up], divisor) - divided_icohp[Spin.up] = np.divide(summed_icohp[Spin.up], divisor) + divided_cohp = {Spin.up: np.divide(summed_cohp[Spin.up], divisor)} + divided_icohp = {Spin.up: np.divide(summed_icohp[Spin.up], divisor)} if Spin.down in summed_cohp: divided_cohp[Spin.down] = np.divide(summed_cohp[Spin.down], divisor) divided_icohp[Spin.down] = np.divide(summed_icohp[Spin.down], divisor) if summed_spin_channels and Spin.down in summed_cohp: - final_cohp = {} - final_icohp = {} - final_cohp[Spin.up] = np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0) - final_icohp[Spin.up] = np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0) + final_cohp = {Spin.up: np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0)} + final_icohp = {Spin.up: np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0)} else: final_cohp = divided_cohp final_icohp = divided_icohp @@ -421,50 +463,56 @@ def get_summed_cohp_by_label_list(self, label_list, divisor=1, summed_spin_chann ) def get_summed_cohp_by_label_and_orbital_list( - self, label_list, orbital_list, divisor=1, summed_spin_channels=False - ): - """Get a COHP object that includes a summed COHP divided by divisor. + self, + label_list: list[str], + orbital_list: list[str], + divisor: float = 1, + summed_spin_channels: bool = False, + ) -> Cohp: + """Get a Cohp object that includes a summed COHP divided by divisor. Args: - label_list: list of labels for the COHP that should be included in the summed cohp - orbital_list: list of orbitals for the COHPs that should be included in the summed cohp (same order as - label_list) - divisor: float/int, the summed cohp will be divided by this divisor - summed_spin_channels: bool, will sum the spin channels and return the sum in Spin.up if true + label_list (list[str]): Labels for the COHP that should be included. + orbital_list (list[str]): Orbitals for the COHPs that should be included + (same order as label_list). + divisor (float): The summed COHP will be divided by this divisor. + summed_spin_channels (bool): Sum the spin channels and return the sum in Spin.up. Returns: - Returns a COHP object including a summed COHP + A Cohp object including the summed COHP. """ - # check length of label_list and orbital_list: + # Check length of label_list and orbital_list if not len(label_list) == len(orbital_list): raise ValueError("label_list and orbital_list don't have the same length!") - # check if cohps are spinpolarized or not + + # Check if COHPs are spin polarized first_cohpobject = self.get_orbital_resolved_cohp(label_list[0], orbital_list[0]) + assert first_cohpobject is not None + assert first_cohpobject.icohp is not None summed_cohp = first_cohpobject.cohp.copy() summed_icohp = first_cohpobject.icohp.copy() - for ilabel, label in enumerate(label_list[1:], start=1): - cohp_here = self.get_orbital_resolved_cohp(label, orbital_list[ilabel]) - summed_cohp[Spin.up] = np.sum([summed_cohp[Spin.up], cohp_here.cohp.copy()[Spin.up]], axis=0) + + for idx, label in enumerate(label_list[1:], start=1): + cohp = self.get_orbital_resolved_cohp(label, orbital_list[idx]) + assert cohp is not None + assert cohp.icohp is not None + summed_cohp[Spin.up] = np.sum([summed_cohp[Spin.up], cohp.cohp.copy()[Spin.up]], axis=0) if Spin.down in summed_cohp: - summed_cohp[Spin.down] = np.sum([summed_cohp[Spin.down], cohp_here.cohp.copy()[Spin.down]], axis=0) - summed_icohp[Spin.up] = np.sum([summed_icohp[Spin.up], cohp_here.icohp.copy()[Spin.up]], axis=0) + summed_cohp[Spin.down] = np.sum([summed_cohp[Spin.down], cohp.cohp.copy()[Spin.down]], axis=0) + + summed_icohp[Spin.up] = np.sum([summed_icohp[Spin.up], cohp.icohp.copy()[Spin.up]], axis=0) if Spin.down in summed_icohp: - summed_icohp[Spin.down] = np.sum([summed_icohp[Spin.down], cohp_here.icohp.copy()[Spin.down]], axis=0) + summed_icohp[Spin.down] = np.sum([summed_icohp[Spin.down], cohp.icohp.copy()[Spin.down]], axis=0) - divided_cohp = {} - divided_icohp = {} - divided_cohp[Spin.up] = np.divide(summed_cohp[Spin.up], divisor) - divided_icohp[Spin.up] = np.divide(summed_icohp[Spin.up], divisor) + divided_cohp = {Spin.up: np.divide(summed_cohp[Spin.up], divisor)} + divided_icohp = {Spin.up: np.divide(summed_icohp[Spin.up], divisor)} if Spin.down in summed_cohp: divided_cohp[Spin.down] = np.divide(summed_cohp[Spin.down], divisor) divided_icohp[Spin.down] = np.divide(summed_icohp[Spin.down], divisor) if summed_spin_channels and Spin.down in divided_cohp: - final_cohp = {} - final_icohp = {} - - final_cohp[Spin.up] = np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0) - final_icohp[Spin.up] = np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0) + final_cohp = {Spin.up: np.sum([divided_cohp[Spin.up], divided_cohp[Spin.down]], axis=0)} + final_icohp = {Spin.up: np.sum([divided_icohp[Spin.up], divided_icohp[Spin.down]], axis=0)} else: final_cohp = divided_cohp final_icohp = divided_icohp @@ -478,30 +526,34 @@ def get_summed_cohp_by_label_and_orbital_list( icohp=final_icohp, ) - def get_orbital_resolved_cohp(self, label, orbitals, summed_spin_channels=False): + def get_orbital_resolved_cohp( + self, + label: str, + orbitals: str | list[tuple[str, Orbital]] | tuple[tuple[str, Orbital], ...], + summed_spin_channels: bool = False, + ) -> Cohp | None: """Get orbital-resolved COHP. Args: - label: bond label (Lobster: labels as in ICOHPLIST/ICOOPLIST.lobster). - - orbitals: The orbitals as a label, or list or tuple of the form - [(n1, orbital1), (n2, orbital2)]. Orbitals can either be str, - int, or Orbital. - - summed_spin_channels: bool, will sum the spin channels and return the sum in Spin.up if true + label (str): Bond labels as in ICOHPLIST/ICOOPLIST.lobster. + orbitals: The orbitals as a label, or list/tuple of + [(n1, orbital1), (n2, orbital2), ...]. + Where each orbital can either be str, int, or Orbital. + summed_spin_channels (bool): Sum the spin channels and return the sum as Spin.up. Returns: - A Cohp object if CompleteCohp contains orbital-resolved cohp, - or None if it doesn't. + A Cohp object if CompleteCohp contains orbital-resolved COHP, + or None if it doesn't. Note: It currently assumes that orbitals are str if they aren't the - other valid types. This is not ideal, but the easiest way to - avoid unicode issues between python 2 and python 3. + other valid types. This is not ideal, but is the easiest way to + avoid unicode issues between Python 2 and Python 3. """ if self.orb_res_cohp is None: return None + if isinstance(orbitals, (list, tuple)): - cohp_orbs = [d["orbitals"] for d in self.orb_res_cohp[label].values()] + cohp_orbs = [val["orbitals"] for val in self.orb_res_cohp[label].values()] orbs = [] for orbital in orbitals: if isinstance(orbital[1], int): @@ -514,10 +566,12 @@ def get_orbital_resolved_cohp(self, label, orbitals, summed_spin_channels=False) raise TypeError("Orbital must be str, int, or Orbital.") orb_index = cohp_orbs.index(orbs) orb_label = list(self.orb_res_cohp[label])[orb_index] + elif isinstance(orbitals, str): orb_label = orbitals else: raise TypeError("Orbitals must be str, list, or tuple.") + try: icohp = self.orb_res_cohp[label][orb_label]["ICOHP"] except KeyError: @@ -546,9 +600,11 @@ def get_orbital_resolved_cohp(self, label, orbitals, summed_spin_channels=False) ) @classmethod - def from_dict(cls, dct: dict) -> Self: - """Get CompleteCohp object from dict representation.""" - # TODO: clean that mess up? + def from_dict(cls, dct: dict[str, Any]) -> Self: + """Get CompleteCohp object from a dict representation. + + TODO: Clean this up. + """ cohp_dict = {} efermi = dct["efermi"] energies = dct["energies"] @@ -569,6 +625,7 @@ def from_dict(cls, dct: dict) -> Self: } else: bonds = None + for label in dct["COHP"]: cohp = {Spin(int(spin)): np.array(dct["COHP"][label][spin]) for spin in dct["COHP"][label]} try: @@ -589,7 +646,7 @@ def from_dict(cls, dct: dict) -> Self: cohp_dict[label] = Cohp(efermi, energies, cohp, icohp=icohp) if "orb_res_cohp" in dct: - orb_cohp: dict[str, dict] = {} + orb_cohp: dict[str, dict[Orbital, dict[str, Any]]] = {} for label in dct["orb_res_cohp"]: orb_cohp[label] = {} for orb in dct["orb_res_cohp"][label]: @@ -611,9 +668,9 @@ def from_dict(cls, dct: dict) -> Self: "orbitals": orbitals, } # If no total COHPs are present, calculate the total - # COHPs from the single-orbital populations. Total COHPs - # may not be present when the COHP generator keyword is used - # in LOBSTER versions 2.2.0 and earlier. + # COHPs from the single-orbital populations. + # Total COHPs may not be present when the COHP generator keyword + # is used in LOBSTER versions 2.2.0 and earlier. if label not in dct["COHP"] or dct["COHP"][label] is None: cohp = { Spin.up: np.sum( @@ -647,52 +704,52 @@ def from_dict(cls, dct: dict) -> Self: else: orb_cohp = {} - are_cobis = dct.get("are_cobis", False) - + assert avg_cohp is not None return cls( structure, avg_cohp, cohp_dict, bonds=bonds, are_coops=dct["are_coops"], - are_cobis=are_cobis, + are_cobis=dct.get("are_cobis", False), are_multi_center_cobis=are_multi_center_cobis, orb_res_cohp=orb_cohp, ) @classmethod def from_file( - cls, fmt, filename=None, structure_file=None, are_coops=False, are_cobis=False, are_multi_center_cobis=False + cls, + fmt: Literal["LMTO", "LOBSTER"], + filename: PathLike | None = None, + structure_file: PathLike | None = None, + are_coops: bool = False, + are_cobis: bool = False, + are_multi_center_cobis: bool = False, ) -> Self: - """ - Creates a CompleteCohp object from an output file of a COHP - calculation. Valid formats are either LMTO (for the Stuttgart - LMTO-ASA code) or LOBSTER (for the LOBSTER code). + """Create CompleteCohp from an output file of a COHP calculation. Args: - fmt: A string for the code that was used to calculate - the COHPs so that the output file can be handled - correctly. Can take the values "LMTO" or "LOBSTER". - filename: Name of the COHP output file. Defaults to COPL - for LMTO and COHPCAR.lobster/COOPCAR.lobster for LOBSTER. - structure_file: Name of the file containing the structure. - If no file name is given, use CTRL for LMTO and POSCAR - for LOBSTER. - are_coops: Indicates whether the populations are COOPs or - COHPs. Defaults to False for COHPs. - are_cobis: Indicates whether the populations are COBIs or - COHPs. Defaults to False for COHPs. - are_multi_center_cobis: Indicates whether this file - includes information on multi-center COBIs + fmt (Literal["LMTO", "LOBSTER"]): The code used to calculate COHPs. + filename (PathLike): The COHP output file. Defaults to "COPL" + for LMTO and "COHPCAR.lobster/COOPCAR.lobster" for LOBSTER. + structure_file (PathLike): The file containing the structure. + If None, use "CTRL" for LMTO and "POSCAR" for LOBSTER. + are_coops (bool): Whether the populations are COOPs or COHPs. + Defaults to False for COHPs. + are_cobis (bool): Whether the populations are COBIs or COHPs. + Defaults to False for COHPs. + are_multi_center_cobis (bool): Whether this file + includes information on multi-center COBIs. Returns: A CompleteCohp object. """ if are_coops and are_cobis: raise ValueError("You cannot have info about COOPs and COBIs in the same file.") - fmt = fmt.upper() + + fmt = fmt.upper() # type: ignore[assignment] if fmt == "LMTO": - # LMTO COOPs and orbital-resolved COHP cannot be handled yet. + # TODO: LMTO COOPs and orbital-resolved COHP cannot be handled yet are_coops = False are_cobis = False orb_res_cohp = None @@ -700,7 +757,9 @@ def from_file( structure_file = "CTRL" if filename is None: filename = "COPL" + cohp_file: LMTOCopl | Cohpcar = LMTOCopl(filename=filename, to_eV=True) + elif fmt == "LOBSTER": if ( (are_coops and are_cobis) @@ -724,6 +783,7 @@ def from_file( are_multi_center_cobis=are_multi_center_cobis, ) orb_res_cohp = cohp_file.orb_res_cohp + else: raise ValueError(f"Unknown format {fmt}. Valid formats are LMTO and LOBSTER.") @@ -732,9 +792,8 @@ def from_file( cohp_data = cohp_file.cohp_data energies = cohp_file.energies - # Lobster shifts the energies so that the Fermi energy is at zero. + # LOBSTER shifts the energies so that the Fermi level is at zero. # Shifting should be done by the plotter object though. - spins = [Spin.up, Spin.down] if cohp_file.is_spin_polarized else [Spin.up] if fmt == "LOBSTER": energies += efermi @@ -744,6 +803,7 @@ def from_file( # COHPs from the single-orbital populations. Total COHPs # may not be present when the cohpgenerator keyword is used # in LOBSTER versions 2.2.0 and earlier. + # TODO: Test this more extensively for label in orb_res_cohp: @@ -765,16 +825,16 @@ def from_file( } if fmt == "LMTO": - # Calculate the average COHP for the LMTO file to be - # consistent with LOBSTER output. - avg_data: dict[str, dict] = {"COHP": {}, "ICOHP": {}} - for i in avg_data: + # Calculate the average COHP for the LMTO file to be consistent with LOBSTER + avg_data: dict[Literal["COHP", "ICOHP"], dict] = {"COHP": {}, "ICOHP": {}} + for dtype in avg_data: for spin in spins: - rows = np.array([v[i][spin] for v in cohp_data.values()]) + rows = np.array([v[dtype][spin] for v in cohp_data.values()]) avg = np.mean(rows, axis=0) - # LMTO COHPs have 5 significant figures - avg_data[i].update({spin: np.array([round_to_sigfigs(a, 5) for a in avg], dtype=float)}) + # LMTO COHPs have 5 significant digits + avg_data[dtype] |= {spin: np.array([round_to_sigfigs(a, 5) for a in avg], dtype=float)} avg_cohp = Cohp(efermi, energies, avg_data["COHP"], icohp=avg_data["ICOHP"]) + elif not are_multi_center_cobis: avg_cohp = Cohp( efermi, @@ -786,9 +846,9 @@ def from_file( are_multi_center_cobis=are_multi_center_cobis, ) del cohp_data["average"] + else: - # only include two-center cobis in average - # do this for both spin channels + # Only include two-center COBIs in average for both spin channels cohp = {} cohp[Spin.up] = np.array( [np.array(c["COHP"][Spin.up]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] @@ -799,6 +859,7 @@ def from_file( ).mean(axis=0) except KeyError: pass + try: icohp = {} icohp[Spin.up] = np.array( @@ -812,6 +873,7 @@ def from_file( pass except KeyError: icohp = None + avg_cohp = Cohp( efermi, energies, @@ -855,38 +917,53 @@ def from_file( class IcohpValue(MSONable): - """Store information on an ICOHP or ICOOP value. + """Information for an ICOHP or ICOOP value. Attributes: - energies (ndarray): Energy values for the COHP/ICOHP/COOP/ICOOP. - densities (ndarray): Density of states values for the COHP/ICOHP/COOP/ICOOP. - energies_are_cartesian (bool): Whether the energies are cartesian or not. - are_coops (bool): Whether the object is a COOP/ICOOP or not. - are_cobis (bool): Whether the object is a COBIS/ICOBIS or not. - icohp (dict): A dictionary of the ICOHP/COHP values. The keys are Spin.up and Spin.down. + energies (NDArray): Energy values for the COHP/ICOHP/COOP/ICOOP. + densities (NDArray): Density of states for the COHP/ICOHP/COOP/ICOOP. + energies_are_cartesian (bool): Whether the energies are cartesian. + are_coops (bool): Whether the object is COOP/ICOOP. + are_cobis (bool): Whether the object is COBIS/ICOBIS. + icohp (dict): The ICOHP/COHP values, whose keys are Spin.up and Spin.down. summed_icohp (float): The summed ICOHP/COHP values. - num_bonds (int): The number of bonds used for the average COHP (relevant for Lobster versions <3.0). + num_bonds (int): The number of bonds used for the average COHP (for LOBSTER versions <3.0). """ def __init__( - self, label, atom1, atom2, length, translation, num, icohp, are_coops=False, are_cobis=False, orbitals=None + self, + label: str, + atom1: str, + atom2: str, + length: float, + translation: Vector3D, + num: int, + icohp: dict[Spin, float], + are_coops: bool = False, + are_cobis: bool = False, + orbitals: dict[str, dict[Literal["icohp", "orbitals"], Any]] | None = None, ) -> None: """ Args: - label: label for the icohp - atom1: str of atom that is contributing to the bond - atom2: str of second atom that is contributing to the bond - length: float of bond lengths - translation: translation list, e.g. [0,0,0] - num: integer describing how often the bond exists - icohp: dict={Spin.up: icohpvalue for spin.up, Spin.down: icohpvalue for spin.down} - are_coops: if True, this are COOPs - are_cobis: if True, this are COBIs - orbitals: {[str(Orbital1)-str(Orbital2)]: {"icohp":{Spin.up: icohpvalue for spin.up, Spin.down: - icohpvalue for spin.down}, "orbitals":[Orbital1, Orbital2]}}. + label (str): Label for the ICOHP. + atom1 (str): The first atom that contributes to the bond. + atom2 (str): The second atom that contributes to the bond. + length (float): Bond length. + translation (Vector3D): cell translation vector, e.g. (0, 0, 0). + num (int): The number of equivalent bonds. + icohp (dict[Spin, float]): {Spin.up: ICOHP_up, Spin.down: ICOHP_down} + are_coops (bool): Whether these are COOPs. + are_cobis (bool): Whether these are COBIs. + orbitals (dict): {[str(Orbital1)-str(Orbital2)]: { + "icohp": { + Spin.up: IcohpValue for spin.up, + Spin.down: IcohpValue for spin.down + }, + "orbitals": [Orbital1, Orbital2, ...]}. """ if are_coops and are_cobis: raise ValueError("You cannot have info about COOPs and COBIs in the same file.") + self._are_coops = are_coops self._are_cobis = are_cobis self._label = label @@ -897,136 +974,127 @@ def __init__( self._num = num self._icohp = icohp self._orbitals = orbitals - if Spin.down in self._icohp: - self._is_spin_polarized = True - else: - self._is_spin_polarized = False + self._is_spin_polarized = Spin.down in self._icohp def __str__(self) -> str: """String representation of the ICOHP/ICOOP.""" - if not self._are_coops and not self._are_cobis: - if self._is_spin_polarized: - return ( - f"ICOHP {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up) and {self._icohp[Spin.down]} eV (Spin down)" - ) - return ( - f"ICOHP {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up)" - ) - if self._are_coops and not self._are_cobis: - if self._is_spin_polarized: - return ( - f"ICOOP {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up) and {self._icohp[Spin.down]} eV (Spin down)" - ) - return ( - f"ICOOP {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up)" - ) - if self._are_cobis and not self._are_coops: - if self._is_spin_polarized: - return ( - f"ICOBI {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up) and {self._icohp[Spin.down]} eV (Spin down)" - ) + # (are_coops and are_cobis) is never True + if self._are_coops: + header = "ICOOP" + elif self._are_cobis: + header = "ICOBI" + else: + header = "ICOHP" + + if self._is_spin_polarized: return ( - f"ICOBI {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " - f"{self._icohp[Spin.up]} eV (Spin up)" + f"{header} {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " + f"{self._icohp[Spin.up]} eV (Spin up) and {self._icohp[Spin.down]} eV (Spin down)" ) - return "" + return ( + f"{header} {self._label} between {self._atom1} and {self._atom2} ({self._translation}): " + f"{self._icohp[Spin.up]} eV (Spin up)" + ) @property - def num_bonds(self): - """Tells the number of bonds for which the ICOHP value is an average. + def num_bonds(self) -> int: + """The number of bonds for which the ICOHP value is an average. Returns: - Int. + int """ return self._num @property def are_coops(self) -> bool: - """Tells if ICOOPs or not. + """Whether these are ICOOPs. Returns: - Boolean. + bool """ return self._are_coops @property def are_cobis(self) -> bool: - """Tells if ICOBIs or not. + """Whether these are ICOBIs. Returns: - Boolean. + bool """ return self._are_cobis @property def is_spin_polarized(self) -> bool: - """Tells if spin polarized calculation or not. + """Whether this is a spin polarized calculation. Returns: - Boolean. + bool """ return self._is_spin_polarized - def icohpvalue(self, spin=Spin.up): + def icohpvalue(self, spin: Spin = Spin.up) -> float: """ Args: spin: Spin.up or Spin.down. Returns: - float: corresponding to chosen spin. + float: ICOHP value corresponding to chosen spin. """ if not self.is_spin_polarized and spin == Spin.down: raise ValueError("The calculation was not performed with spin polarization") return self._icohp[spin] - def icohpvalue_orbital(self, orbitals, spin=Spin.up) -> float: + def icohpvalue_orbital( + self, + orbitals: tuple[Orbital, Orbital] | str, + spin: Spin = Spin.up, + ) -> float: """ Args: - orbitals: List of Orbitals or "str(Orbital1)-str(Orbital2)" - spin: Spin.up or Spin.down. + orbitals: tuple[Orbital, Orbital] or "str(Orbital0)-str(Orbital1)". + spin (Spin): Spin.up or Spin.down. Returns: - float: corresponding to chosen spin. + float: ICOHP value corresponding to chosen spin. """ if not self.is_spin_polarized and spin == Spin.down: raise ValueError("The calculation was not performed with spin polarization") - if isinstance(orbitals, list): + + if isinstance(orbitals, (tuple, list)): orbitals = f"{orbitals[0]}-{orbitals[1]}" + + assert self._orbitals is not None return self._orbitals[orbitals]["icohp"][spin] @property - def icohp(self): + def icohp(self) -> dict[Spin, float]: """Dict with ICOHPs for spin up and spin down. Returns: - dict={Spin.up: icohpvalue for spin.up, Spin.down: icohpvalue for spin.down}. + dict[Spin, float]: {Spin.up: ICOHP_up, Spin.down: ICOHP_down}. """ return self._icohp @property - def summed_icohp(self): - """Sums ICOHPs of both spin channels for spin polarized compounds. + def summed_icohp(self) -> float: + """Summed ICOHPs of both spin channels if spin polarized. Returns: - float: icohp value in eV. + float: ICOHP value in eV. """ return self._icohp[Spin.down] + self._icohp[Spin.up] if self._is_spin_polarized else self._icohp[Spin.up] @property - def summed_orbital_icohp(self): - """Sums orbital-resolved ICOHPs of both spin channels for spin-polarized compounds. + def summed_orbital_icohp(self) -> dict[str, float]: + """Summed orbital-resolved ICOHPs of both spin channels if spin-polarized. Returns: - dict[str, float]: "str(Orbital1)-str(Ortibal2)" mapped to ICOHP value in eV. + dict[str, float]: "str(Orbital1)-str(Ortibal2)": ICOHP value in eV. """ orbital_icohp = {} + assert self._orbitals is not None for orb, item in self._orbitals.items(): orbital_icohp[orb] = ( item["icohp"][Spin.up] + item["icohp"][Spin.down] if self._is_spin_polarized else item["icohp"][Spin.up] @@ -1035,49 +1103,49 @@ def summed_orbital_icohp(self): class IcohpCollection(MSONable): - """Store IcohpValues. + """Collection of IcohpValues. Attributes: - are_coops (bool): Boolean to indicate if these are ICOOPs. - are_cobis (bool): Boolean to indicate if these are ICOOPs. - is_spin_polarized (bool): Boolean to indicate if the Lobster calculation was done spin polarized or not. + are_coops (bool): Whether these are ICOOPs. + are_cobis (bool): Whether these are ICOOPs. + is_spin_polarized (bool): Whether the calculation is spin polarized. """ def __init__( self, - list_labels, - list_atom1, - list_atom2, - list_length, - list_translation, - list_num, - list_icohp, - is_spin_polarized, - list_orb_icohp=None, - are_coops=False, - are_cobis=False, + list_labels: list[str], + list_atom1: list[str], + list_atom2: list[str], + list_length: list[float], + list_translation: list[Vector3D], + list_num: list[int], + list_icohp: list[dict[Spin, float]], + is_spin_polarized: bool, + list_orb_icohp: list[dict[str, dict[Literal["icohp", "orbitals"], Any]]] | None = None, + are_coops: bool = False, + are_cobis: bool = False, ) -> None: """ Args: - list_labels: list of labels for ICOHP/ICOOP values - list_atom1: list of str of atomnames e.g. "O1" - list_atom2: list of str of atomnames e.g. "O1" - list_length: list of lengths of corresponding bonds in Angstrom - list_translation: list of translation list, e.g. [0,0,0] - list_num: list of equivalent bonds, usually 1 starting from Lobster 3.0.0 - list_icohp: list of dict={Spin.up: icohpvalue for spin.up, Spin.down: icohpvalue for spin.down} - is_spin_polarized: Boolean to indicate if the Lobster calculation was done spin polarized or not Boolean to - indicate if the Lobster calculation was done spin polarized or not - list_orb_icohp: list of dict={[str(Orbital1)-str(Orbital2)]: {"icohp":{Spin.up: icohpvalue for spin.up, - Spin.down: icohpvalue for spin.down}, "orbitals":[Orbital1, Orbital2]}} - are_coops: Boolean to indicate whether ICOOPs are stored - are_cobis: Boolean to indicate whether ICOBIs are stored. + list_labels (list[str]): Labels for ICOHP/ICOOP values. + list_atom1 (list[str]): Atom names, e.g. "O1". + list_atom2 (list[str]): Atom names, e.g. "O1". + list_length (list[float]): Bond lengths in Angstrom. + list_translation (list[Vector3D]): Cell translation vectors. + list_num (list[int]): Numbers of equivalent bonds, usually 1 starting from LOBSTER 3.0.0. + list_icohp (list[dict]): Dicts as {Spin.up: ICOHP_up, Spin.down: ICOHP_down}. + is_spin_polarized (bool): Whether the calculation is spin polarized. + list_orb_icohp (list[dict]): Dicts as {[str(Orbital1)-str(Orbital2)]: { + "icohp": {Spin.up: IcohpValue for spin.up, Spin.down: IcohpValue for spin.down}, + "orbitals": [Orbital1, Orbital2]}. + are_coops (bool): Whether ICOOPs are stored. + are_cobis (bool): Whether ICOBIs are stored. """ if are_coops and are_cobis: raise ValueError("You cannot have info about COOPs and COBIs in the same file.") + self._are_coops = are_coops self._are_cobis = are_cobis - self._icohplist = {} self._is_spin_polarized = is_spin_polarized self._list_labels = list_labels self._list_atom1 = list_atom1 @@ -1088,132 +1156,153 @@ def __init__( self._list_icohp = list_icohp self._list_orb_icohp = list_orb_icohp - for ilist, listel in enumerate(list_labels): - self._icohplist[listel] = IcohpValue( - label=listel, - atom1=list_atom1[ilist], - atom2=list_atom2[ilist], - length=list_length[ilist], - translation=list_translation[ilist], - num=list_num[ilist], - icohp=list_icohp[ilist], + # TODO: DanielYang: self._icohplist name is misleading + # (not list), and confuses with self._list_icohp + self._icohplist: dict[str, IcohpValue] = {} + for idx, label in enumerate(list_labels): + self._icohplist[label] = IcohpValue( + label=label, + atom1=list_atom1[idx], + atom2=list_atom2[idx], + length=list_length[idx], + translation=list_translation[idx], + num=list_num[idx], + icohp=list_icohp[idx], are_coops=are_coops, are_cobis=are_cobis, - orbitals=None if list_orb_icohp is None else list_orb_icohp[ilist], + orbitals=None if list_orb_icohp is None else list_orb_icohp[idx], ) def __str__(self) -> str: - lst = [] - for value in self._icohplist.values(): - lst.append(str(value)) - return "\n".join(lst) + return "\n".join([str(value) for value in self._icohplist.values()]) - def get_icohp_by_label(self, label, summed_spin_channels=True, spin=Spin.up, orbitals=None) -> float: - """Get an icohp value for a certain bond as indicated by the label (bond labels starting by "1" as in - ICOHPLIST/ICOOPLIST). + def get_icohp_by_label( + self, + label: str, + summed_spin_channels: bool = True, + spin: Spin = Spin.up, + orbitals: str | tuple[Orbital, Orbital] | None = None, + ) -> float: + """Get an ICOHP value for a certain bond indicated by the label. Args: - label: label in str format (usually the bond number in Icohplist.lobster/Icooplist.lobster - summed_spin_channels: Boolean to indicate whether the ICOHPs/ICOOPs of both spin channels should be summed - spin: if summed_spin_channels is equal to False, this spin indicates which spin channel should be returned - orbitals: List of Orbital or "str(Orbital1)-str(Orbital2)" + label (str): The bond number in Icohplist.lobster/Icooplist.lobster, + starting from "1". + summed_spin_channels (bool): Whether the ICOHPs/ICOOPs of both + spin channels should be summed. + spin (Spin): If not summed_spin_channels, indicate + which spin channel should be returned. + orbitals: List of Orbital or "str(Orbital1)-str(Orbital2)". Returns: - float: ICOHP/ICOOP value + float: ICOHP/ICOOP value. """ - icohp_here: IcohpValue = self._icohplist[label] + icohp: IcohpValue = self._icohplist[label] + if orbitals is None: - if summed_spin_channels: - return icohp_here.summed_icohp - return icohp_here.icohpvalue(spin) + return icohp.summed_icohp if summed_spin_channels else icohp.icohpvalue(spin) - if isinstance(orbitals, list): + if isinstance(orbitals, (tuple, list)): orbitals = f"{orbitals[0]}-{orbitals[1]}" + if summed_spin_channels: - return icohp_here.summed_orbital_icohp[orbitals] + return icohp.summed_orbital_icohp[orbitals] - return icohp_here.icohpvalue_orbital(spin=spin, orbitals=orbitals) + return icohp.icohpvalue_orbital(spin=spin, orbitals=orbitals) - def get_summed_icohp_by_label_list(self, label_list, divisor=1.0, summed_spin_channels=True, spin=Spin.up) -> float: - """Get the sum of several ICOHP values that are indicated by a list of labels - (labels of the bonds are the same as in ICOHPLIST/ICOOPLIST). + def get_summed_icohp_by_label_list( + self, + label_list: list[str], + divisor: float = 1.0, + summed_spin_channels: bool = True, + spin: Spin = Spin.up, + ) -> float: + """Get the sum of ICOHP values. Args: - label_list: list of labels of the ICOHPs/ICOOPs that should be summed - divisor: is used to divide the sum - summed_spin_channels: Boolean to indicate whether the ICOHPs/ICOOPs of both spin channels should be summed - spin: if summed_spin_channels is equal to False, this spin indicates which spin channel should be returned + label_list (list[str]): Labels of the ICOHPs/ICOOPs that should be summed, + the same as in ICOHPLIST/ICOOPLIST. + divisor (float): Divisor used to divide the sum. + summed_spin_channels (bool): Whether the ICOHPs/ICOOPs of both + spin channels should be summed. + spin (Spin): If not summed_spin_channels, indicate + which spin channel should be returned. Returns: - float: sum of all ICOHPs/ICOOPs as indicated with label_list + float: Sum of ICOHPs selected with label_list. """ - sum_icohp = 0 + sum_icohp: float = 0 for label in label_list: - icohp_here = self._icohplist[label] - if icohp_here.num_bonds != 1: + icohp = self._icohplist[label] + if icohp.num_bonds != 1: warnings.warn("One of the ICOHP values is an average over bonds. This is currently not considered.") - if icohp_here._is_spin_polarized: - if summed_spin_channels: - sum_icohp = sum_icohp + icohp_here.summed_icohp - else: - sum_icohp = sum_icohp + icohp_here.icohpvalue(spin) + + if icohp._is_spin_polarized and summed_spin_channels: + sum_icohp = sum_icohp + icohp.summed_icohp else: - sum_icohp = sum_icohp + icohp_here.icohpvalue(spin) + sum_icohp = sum_icohp + icohp.icohpvalue(spin) + return sum_icohp / divisor - def get_icohp_dict_by_bondlengths(self, minbondlength=0.0, maxbondlength=8.0): - """Get a dict of IcohpValues corresponding to certain bond lengths. + def get_icohp_dict_by_bondlengths( + self, + minbondlength: float = 0.0, + maxbondlength: float = 8.0, + ) -> dict[str, IcohpValue]: + """Get IcohpValues within certain bond length range. Args: - minbondlength: defines the minimum of the bond lengths of the bonds - maxbondlength: defines the maximum of the bond lengths of the bonds. + minbondlength (float): The minimum bond length. + maxbondlength (float): The maximum bond length. Returns: - dict of IcohpValues, the keys correspond to the values from the initial list_labels. + dict[str, IcohpValue]: Keys are the labels from the initial list_labels. """ new_icohp_dict = {} for value in self._icohplist.values(): - if value._length >= minbondlength and value._length <= maxbondlength: + if minbondlength <= value._length <= maxbondlength: new_icohp_dict[value._label] = value return new_icohp_dict def get_icohp_dict_of_site( self, - site, - minsummedicohp=None, - maxsummedicohp=None, - minbondlength=0.0, - maxbondlength=8.0, - only_bonds_to=None, - ): - """Get a dict of IcohpValue for a certain site (indicated by integer). + site: int, + minsummedicohp: float | None = None, + maxsummedicohp: float | None = None, + minbondlength: float = 0.0, + maxbondlength: float = 8.0, + only_bonds_to: list[str] | None = None, + ) -> dict[str, IcohpValue]: + """Get IcohpValues for a certain site. Args: - site: integer describing the site of interest, order as in Icohplist.lobster/Icooplist.lobster, starts at 0 - minsummedicohp: float, minimal icohp/icoop of the bonds that are considered. It is the summed ICOHP value - from both spin channels for spin polarized cases - maxsummedicohp: float, maximal icohp/icoop of the bonds that are considered. It is the summed ICOHP value - from both spin channels for spin polarized cases - minbondlength: float, defines the minimum of the bond lengths of the bonds - maxbondlength: float, defines the maximum of the bond lengths of the bonds - only_bonds_to: list of strings describing the bonding partners that are allowed, e.g. ['O'] + site (int): The site of interest, ordered as in Icohplist.lobster/Icooplist.lobster, + starts from 0. + minsummedicohp (float): Minimal ICOHP/ICOOP of the bonds that are considered. + It is the summed ICOHP value from both spin channels for spin polarized cases + maxsummedicohp (float): Maximal ICOHP/ICOOP of the bonds that are considered. + It is the summed ICOHP value from both spin channels for spin polarized cases + minbondlength (float): The minimum bond length. + maxbondlength (float): The maximum bond length. + only_bonds_to (list[str]): The bonding partners that are allowed, e.g. ["O"]. Returns: - dict of IcohpValues, the keys correspond to the values from the initial list_labels + Dict of IcohpValues, the keys correspond to the values from the initial list_labels. """ new_icohp_dict = {} for key, value in self._icohplist.items(): atomnumber1 = int(re.split(r"(\d+)", value._atom1)[1]) - 1 atomnumber2 = int(re.split(r"(\d+)", value._atom2)[1]) - 1 if site in (atomnumber1, atomnumber2): - # manipulate order of atoms so that searched one is always atom1 + # Swap order of atoms so that searched one is always atom1 if site == atomnumber2: save = value._atom1 value._atom1 = value._atom2 value._atom2 = save second_test = True if only_bonds_to is None else re.split("(\\d+)", value._atom2)[0] in only_bonds_to - if value._length >= minbondlength and value._length <= maxbondlength and second_test: + if minbondlength <= value._length <= maxbondlength and second_test: + # TODO: DanielYang: merge the following condition blocks if minsummedicohp is not None: if value.summed_icohp >= minsummedicohp: if maxsummedicohp is not None: @@ -1229,16 +1318,21 @@ def get_icohp_dict_of_site( return new_icohp_dict - def extremum_icohpvalue(self, summed_spin_channels=True, spin=Spin.up): - """Get ICOHP/ICOOP of strongest bond. + def extremum_icohpvalue( + self, + summed_spin_channels: bool = True, + spin: Spin = Spin.up, + ) -> float: + """Get ICOHP/ICOOP of the strongest bond. Args: - summed_spin_channels: Boolean to indicate whether the ICOHPs/ICOOPs of both spin channels should be summed. - - spin: if summed_spin_channels is equal to False, this spin indicates which spin channel should be returned + summed_spin_channels (bool): Whether the ICOHPs/ICOOPs of both + spin channels should be summed. + spin (Spin): If not summed_spin_channels, this indicates which + spin channel should be returned. Returns: - lowest ICOHP/largest ICOOP value (i.e. ICOHP/ICOOP value of strongest bond) + Lowest ICOHP/largest ICOOP value (i.e. ICOHP/ICOOP value of strongest bond). """ extremum = -sys.float_info.max if self._are_coops or self._are_cobis else sys.float_info.max @@ -1254,60 +1348,71 @@ def extremum_icohpvalue(self, summed_spin_channels=True, spin=Spin.up): extremum = value.icohpvalue(spin) elif value.icohpvalue(spin) > extremum: extremum = value.icohpvalue(spin) + elif not self._are_coops and not self._are_cobis: if value.summed_icohp < extremum: extremum = value.summed_icohp + elif value.summed_icohp > extremum: extremum = value.summed_icohp + return extremum @property def is_spin_polarized(self) -> bool: - """Whether it is spin polarized.""" + """Whether this is spin polarized.""" return self._is_spin_polarized @property def are_coops(self) -> bool: - """Whether this is a coop.""" + """Whether this is COOP.""" return self._are_coops @property def are_cobis(self) -> bool: - """Whether this a cobi.""" + """Whether this is COBI.""" return self._are_cobis def get_integrated_cohp_in_energy_range( - cohp, label, orbital=None, energy_range=None, relative_E_Fermi=True, summed_spin_channels=True -): - """Integrate CompleteCohp objects which include data on integrated COHPs + cohp: CompleteCohp, + label: str, + orbital: str | None = None, + energy_range: float | tuple[float, float] | None = None, + relative_E_Fermi: bool = True, + summed_spin_channels: bool = True, +) -> float | dict[Spin, float]: + """Integrate CompleteCohps which include data of integrated COHPs (ICOHPs). + Args: - cohp: CompleteCohp object - label: label of the COHP data - orbital: If not None, a orbital resolved integrated COHP will be returned - energy_range: If None, returns icohp value at Fermi level. - If float, integrates from this float up to the Fermi level. - If [float,float], will integrate in between. - relative_E_Fermi: if True, energy scale with E_Fermi at 0 eV is chosen - summed_spin_channels: if True, Spin channels will be summed. + cohp (CompleteCohp): CompleteCohp object. + label (str): Label of the COHP data. + orbital (str): If not None, a orbital resolved integrated COHP will be returned. + energy_range: If None, return the ICOHP value at Fermi level. + If float, integrate from this value up to Fermi level. + If (float, float), integrate in between. + relative_E_Fermi (bool): Whether energy scale with Fermi level at 0 eV is chosen. + summed_spin_channels (bool): Whether Spin channels will be summed. Returns: - float indicating the integrated COHP if summed_spin_channels==True, otherwise dict of the following form { - Spin.up:float, Spin.down:float} + If summed_spin_channels: + float: the ICOHP. + else: + dict: {Spin.up: float, Spin.down: float} """ - summedicohp = {} if orbital is None: icohps = cohp.all_cohps[label].get_icohp(spin=None) - if summed_spin_channels and Spin.down in icohps: - summedicohp[Spin.up] = icohps[Spin.up] + icohps[Spin.down] - else: - summedicohp = icohps else: - icohps = cohp.get_orbital_resolved_cohp(label=label, orbitals=orbital).icohp - if summed_spin_channels and Spin.down in icohps: - summedicohp[Spin.up] = icohps[Spin.up] + icohps[Spin.down] - else: - summedicohp = icohps + _icohps = cohp.get_orbital_resolved_cohp(label=label, orbitals=orbital) + assert _icohps is not None + icohps = _icohps.icohp + + assert icohps is not None + summedicohp = {} + if summed_spin_channels and Spin.down in icohps: + summedicohp[Spin.up] = icohps[Spin.up] + icohps[Spin.down] + else: + summedicohp = icohps if energy_range is None: energies_corrected = cohp.energies - cohp.efermi @@ -1316,12 +1421,10 @@ def get_integrated_cohp_in_energy_range( if not summed_spin_channels and Spin.down in icohps: spl_spindown = InterpolatedUnivariateSpline(energies_corrected, summedicohp[Spin.down], ext=0) return {Spin.up: spl_spinup(0.0), Spin.down: spl_spindown(0.0)} - if summed_spin_channels: - return spl_spinup(0.0) - return {Spin.up: spl_spinup(0.0)} + return spl_spinup(0.0) if summed_spin_channels else {Spin.up: spl_spinup(0.0)} - # returns icohp value at the Fermi level! + # Return ICOHP value at the Fermi level if isinstance(energy_range, float): if relative_E_Fermi: energies_corrected = cohp.energies - cohp.efermi diff --git a/src/pymatgen/electronic_structure/core.py b/src/pymatgen/electronic_structure/core.py index 1297c8495aa..459f8de72a0 100644 --- a/src/pymatgen/electronic_structure/core.py +++ b/src/pymatgen/electronic_structure/core.py @@ -1,11 +1,11 @@ -"""This module provides core classes needed by all define electronic structure, -such as the Spin, Orbital, etc. +"""This module provides core classes to define electronic structure, +including Spin, Orbital and Magmom. """ from __future__ import annotations from enum import Enum, unique -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, cast import numpy as np from monty.json import MSONable @@ -13,9 +13,12 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pymatgen.core import Lattice + from numpy.typing import NDArray from typing_extensions import Self + from pymatgen.core import Lattice + from pymatgen.util.typing import MagMomentLike, Vector3D + @unique class Spin(Enum): @@ -23,14 +26,14 @@ class Spin(Enum): up, down = 1, -1 - def __int__(self) -> int: - return self.value + def __int__(self) -> Literal[-1, 1]: + return cast(Literal[-1, 1], self.value) def __float__(self) -> float: return float(self.value) - def __str__(self) -> str: - return str(self.value) + def __str__(self) -> Literal["-1", "1"]: + return cast(Literal["-1", "1"], str(self.value)) @unique @@ -42,8 +45,8 @@ class OrbitalType(Enum): d = 2 f = 3 - def __str__(self) -> str: - return str(self.name) + def __str__(self) -> Literal["s", "p", "d", "f"]: + return cast(Literal["s", "p", "d", "f"], str(self.name)) @unique @@ -76,185 +79,249 @@ def __str__(self) -> str: return str(self.name) @property - def orbital_type(self): + def orbital_type(self) -> OrbitalType: """OrbitalType of an orbital.""" return OrbitalType[self.name[0]] class Magmom(MSONable): - """New class in active development. Use with caution, feedback is - appreciated. + """In active development. Use with caution, feedback is appreciated. - Class to handle magnetic moments. Defines the magnetic moment of a + Class to handle magnetic moments. Define the magnetic moment of a site or species relative to a spin quantization axis. Designed for use in electronic structure calculations. - * For the general case, Magmom can be specified by a vector, - e.g. m = Magmom([1.0, 1.0, 2.0]), and subscripts will work as - expected, e.g. m[0] gives 1.0 + * For the general case, Magmom can be specified by a 3D vector, + e.g. m = Magmom([1.0, 1.0, 2.0]), and indexing will work as + expected, e.g. m[0] gives 1.0. - * For collinear calculations, Magmom can assumed to be scalar-like, - e.g. m = Magmom(5.0) will work as expected, e.g. float(m) gives 5.0 + * For collinear calculations, Magmom can assumed to be float-like, + e.g. m = Magmom(5.0) will work as expected, e.g. float(m) gives 5.0. - Both of these cases should be safe and shouldn't give any surprises, - but more advanced functionality is available if required. + Both cases should be safe and shouldn't give any surprise, + and more advanced functionality is available if required. - There also exist useful static methods for lists of magmoms: + There are also static methods for sequences of magmoms: - * Magmom.are_collinear(magmoms) - if true, a collinear electronic - structure calculation can be safely initialized, with float(Magmom) - giving the expected scalar magnetic moment value + * Magmom.are_collinear(magmoms) - If True, a collinear electronic + structure calculation can be safely initialized, with float(Magmom) + giving the expected scalar magnetic moment value. - * Magmom.get_consistent_set_and_saxis(magmoms) - for non-collinear - electronic structure calculations, a global, consistent spin axis - has to be used. This method returns a list of Magmoms which all - share a common spin axis, along with the global spin axis. + * Magmom.get_consistent_set_and_saxis(magmoms) - For non-collinear + electronic structure calculations, a global and consistent spin axis + has to be used. This method returns a list of Magmoms which all + share a common spin axis, along with the global spin axis. - All methods that take lists of magmoms will accept magmoms either as - Magmom objects or as scalars/lists and will automatically convert to - a Magmom representation internally. + All methods that take sequence of magmoms will accept either Magmom + objects, or as scalars/lists and will automatically convert to Magmom + representations internally. - The following methods are also particularly useful in the context of - VASP calculations: + The following methods are also useful for VASP calculations: + - Magmom.get_xyz_magmom_with_001_saxis() + - Magmom.get_00t_magmom_with_xyz_saxis() - * Magmom.get_xyz_magmom_with_001_saxis() - * Magmom.get_00t_magmom_with_xyz_saxis() - - See VASP documentation for more information: - - https://cms.mpi.univie.ac.at/wiki/index.php/SAXIS + See VASP documentation for more information: + https://cms.mpi.univie.ac.at/wiki/index.php/SAXIS """ def __init__( - self, moment: float | Sequence[float] | np.ndarray | Magmom, saxis: Sequence[float] = (0, 0, 1) + self, + moment: MagMomentLike, + saxis: Vector3D = (0, 0, 1), ) -> None: """ Args: - moment: magnetic moment, supplied as float or list/np.ndarray - saxis: spin axis, supplied as list/np.ndarray, parameter will - be converted to unit vector (default is [0, 0, 1]). - - Returns: - Magmom object + moment (float | Sequence[float] | NDArray, Magmom): Magnetic moment. + saxis (Vector3D): Spin axis, and will be converted to unit + vector (default is (0, 0, 1)). """ - # to init from another Magmom instance - if isinstance(moment, Magmom): + # Init from another Magmom instance + if isinstance(moment, type(self)): saxis = moment.saxis # type: ignore[has-type] moment = moment.moment # type: ignore[has-type] - moment = np.array(moment, dtype="d") - if moment.ndim == 0: - moment = moment * [0, 0, 1] + magmom: NDArray = np.array(moment, dtype="d") + if magmom.ndim == 0: + magmom = magmom * (0, 0, 1) - self.moment = moment + self.moment = magmom saxis = np.array(saxis, dtype="d") self.saxis = saxis / np.linalg.norm(saxis) - @classmethod - def from_global_moment_and_saxis(cls, global_moment, saxis) -> Self: - """Convenience method to initialize Magmom from a given global - magnetic moment, i.e. magnetic moment with saxis=(0,0,1), and - provided saxis. + def __getitem__(self, key): + return self.moment[key] - Method is useful if you do not know the components of your - magnetic moment in frame of your desired saxis. + def __iter__(self): + return iter(self.moment) - Args: - global_moment: global magnetic moment - saxis: desired saxis - """ - magmom = Magmom(global_moment) - return cls(magmom.get_moment(saxis=saxis), saxis=saxis) + def __abs__(self) -> float: + return np.linalg.norm(self.moment) - @classmethod - def _get_transformation_matrix(cls, saxis): - saxis = saxis / np.linalg.norm(saxis) + def __eq__(self, other: object) -> bool: + """Whether global magnetic moments are the same, saxis can differ.""" + try: + other_magmom = type(self)(other) + except (TypeError, ValueError): + return NotImplemented - alpha = np.arctan2(saxis[1], saxis[0]) - beta = np.arctan2(np.sqrt(saxis[0] ** 2 + saxis[1] ** 2), saxis[2]) + return np.allclose(self.global_moment, other_magmom.global_moment) - cos_a = np.cos(alpha) - cos_b = np.cos(beta) - sin_a = np.sin(alpha) - sin_b = np.sin(beta) + def __lt__(self, other: Self) -> bool: + return abs(self) < abs(other) - return [ - [cos_b * cos_a, -sin_a, sin_b * cos_a], - [cos_b * sin_a, cos_a, sin_b * sin_a], - [-sin_b, 0, cos_b], - ] + def __neg__(self) -> Self: + return type(self)(-self.moment, saxis=self.saxis) - @classmethod - def _get_transformation_matrix_inv(cls, saxis): - saxis = saxis / np.linalg.norm(saxis) + def __hash__(self) -> int: + return hash(tuple(self.moment) + tuple(self.saxis)) - alpha = np.arctan2(saxis[1], saxis[0]) - beta = np.arctan2(np.sqrt(saxis[0] ** 2 + saxis[1] ** 2), saxis[2]) + def __float__(self) -> float: + """Get magnitude of magnetic moment with a sign with respect to + an arbitrary direction. + + Should give unsurprising output if Magmom is treated like a + float or if a set of Magmoms describes a collinear structure. + + Implemented this way rather than simpler abs(self) so that + moments will have a consistent sign in case of e.g. + antiferromagnetic collinear structures without additional + user intervention. + + However, should be used with caution for non-collinear + structures and might give nonsensical results except in the case + of only slightly non-collinear structures (e.g. small canting). + + This method is also used to obtain "diff" VolumetricDensity + in pymatgen.io.vasp.outputs.VolumetricDensity when processing + CHGCARs from SOC calculations. + """ + return float(self.get_00t_magmom_with_xyz_saxis()[2]) + + def __str__(self) -> str: + return str(float(self)) + + def __repr__(self) -> str: + if np.allclose(self.saxis, (0, 0, 1)): + return f"Magnetic moment {self.moment}" + return f"Magnetic moment {self.moment} (spin axis = {self.saxis})" + + @classmethod + def from_global_moment_and_saxis( + cls, + global_moment: MagMomentLike, + saxis: Vector3D, + ) -> Self: + """Initialize Magmom from a given global magnetic moment, + i.e. magnetic moment with saxis=(0, 0, 1), and provided saxis. - cos_a = np.cos(alpha) - cos_b = np.cos(beta) - sin_a = np.sin(alpha) - sin_b = np.sin(beta) + Method is useful if you do not know the components of your + magnetic moment in frame of your desired spin axis. - return [ - [cos_b * cos_a, cos_b * sin_a, -sin_b], - [-sin_a, cos_a, 0], - [sin_b * cos_a, sin_b * sin_a, cos_b], - ] + Args: + global_moment (MagMomentLike): Global magnetic moment. + saxis (Vector3D): Spin axis. + """ + magmom = cls(global_moment) + return cls(magmom.get_moment(saxis=saxis), saxis=saxis) - def get_moment(self, saxis=(0, 0, 1)): + def get_moment(self, saxis: Vector3D = (0, 0, 1)) -> NDArray: """Get magnetic moment relative to a given spin quantization axis. If no axis is provided, moment will be given relative to the Magmom's internal spin quantization axis, i.e. equivalent to Magmom.moment. Args: - saxis: (list/numpy array) spin quantization axis + saxis (Vector3D): Spin quantization axis. Returns: - np.ndarray of length 3 + NDArray of length 3. """ - # transform back to moment with spin axis [0, 0, 1] - trafo_mat_inv = self._get_transformation_matrix_inv(self.saxis) + + def get_transformation_matrix( + saxis: Vector3D, + ) -> tuple[Vector3D, Vector3D, Vector3D]: + """Get the matrix to transform spin axis to z-axis.""" + saxis = saxis / np.linalg.norm(saxis) + + alpha = np.arctan2(saxis[1], saxis[0]) + beta = np.arctan2(np.sqrt(saxis[0] ** 2 + saxis[1] ** 2), saxis[2]) + + cos_a = np.cos(alpha) + cos_b = np.cos(beta) + sin_a = np.sin(alpha) + sin_b = np.sin(beta) + + return ( + (cos_b * cos_a, -sin_a, sin_b * cos_a), + (cos_b * sin_a, cos_a, sin_b * sin_a), + (-sin_b, 0, cos_b), + ) + + def get_transformation_matrix_inv( + saxis: Vector3D, + ) -> tuple[Vector3D, Vector3D, Vector3D]: + """Get the inverse of matrix to transform spin axis to z-axis.""" + saxis = saxis / np.linalg.norm(saxis) + + alpha = np.arctan2(saxis[1], saxis[0]) + beta = np.arctan2(np.sqrt(saxis[0] ** 2 + saxis[1] ** 2), saxis[2]) + + cos_a = np.cos(alpha) + cos_b = np.cos(beta) + sin_a = np.sin(alpha) + sin_b = np.sin(beta) + + return ( + (cos_b * cos_a, cos_b * sin_a, -sin_b), + (-sin_a, cos_a, 0), + (sin_b * cos_a, sin_b * sin_a, cos_b), + ) + + # Transform to moment with spin axis (0, 0, 1) + trafo_mat_inv = get_transformation_matrix_inv(self.saxis) moment = np.matmul(self.moment, trafo_mat_inv) - # transform to new saxis - trafo_mat = self._get_transformation_matrix(saxis) + # Transform to new saxis + trafo_mat = get_transformation_matrix(saxis) moment = np.matmul(moment, trafo_mat) - # round small values to zero + # Round small values to zero moment[np.abs(moment) < 1e-8] = 0 return moment @property - def global_moment(self) -> np.ndarray: - """The magnetic moment defined in an arbitrary global reference frame as an np.array of length 3.""" + def global_moment(self) -> NDArray: + """The magnetic moment defined in an arbitrary global reference frame, + as a np.array of length 3. + """ return self.get_moment() @property - def projection(self): - """Projects moment along spin quantization axis. Useful for obtaining - collinear approximation for slightly non-collinear magmoms. + def projection(self) -> float: + """Project moment along spin quantization axis. + + Useful for obtaining collinear approximation for slightly non-collinear magmoms. Returns: - float + float: The projected moment. """ return np.dot(self.moment, self.saxis) - def get_xyz_magmom_with_001_saxis(self): - """Get a Magmom in the default setting of saxis = [0, 0, 1] and + def get_xyz_magmom_with_001_saxis(self) -> Self: + """Get a Magmom in the default setting of saxis = (0, 0, 1) and the magnetic moment rotated as required. Returns: Magmom """ - return Magmom(self.get_moment()) + return type(self)(self.get_moment()) - def get_00t_magmom_with_xyz_saxis(self): - """For internal implementation reasons, in non-collinear calculations VASP prefers the following. + def get_00t_magmom_with_xyz_saxis(self) -> Self: + """For internal implementation reasons, the non-collinear calculations + in VASP prefer the following. MAGMOM = 0 0 total_magnetic_moment SAXIS = x y z @@ -264,100 +331,101 @@ def get_00t_magmom_with_xyz_saxis(self): MAGMOM = x y z SAXIS = 0 0 1 - This method returns a Magmom object with magnetic moment [0, 0, t], - where t is the total magnetic moment, and saxis rotated as required. - - A consistent direction of saxis is applied such that t might be positive - or negative depending on the direction of the initial moment. This is useful - in the case of collinear structures, rather than constraining assuming - t is always positive. - Returns: - Magmom + Magmom: With magnetic moment (0, 0, t), where t is the total magnetic + moment, and saxis rotated as required. + + A consistent direction of saxis is applied such that t might be + positive or negative depending on the direction of the initial moment. + This is useful in the case of collinear structures, rather than + assuming t is always positive. """ - # reference direction gives sign of moment - # entirely arbitrary, there will always be a pathological case - # where a consistent sign is not possible if the magnetic moments - # are aligned along the reference direction, but in practice this - # is unlikely to happen + # Reference direction gives sign of moment arbitrarily, + # there can be a pathological case where a consistent sign + # is not possible if the magnetic moments are aligned along the + # reference direction, but in practice this is unlikely to happen. ref_direction = np.array([1.01, 1.02, 1.03]) - t = abs(self) - if t != 0: + total_magmom = abs(self) + if total_magmom != 0: new_saxis = self.moment / np.linalg.norm(self.moment) if np.dot(ref_direction, new_saxis) < 0: - t = -t + total_magmom = -total_magmom new_saxis = -new_saxis - return Magmom([0, 0, t], saxis=new_saxis) - return Magmom(self) + return type(self)([0, 0, total_magmom], saxis=new_saxis) + return type(self)(self) @staticmethod - def have_consistent_saxis(magmoms) -> bool: - """Check that all Magmom objects in a list have a consistent spin quantization axis. - To write MAGMOM tags to a VASP INCAR, a global SAXIS value for all magmoms has to be used. - If saxis are inconsistent, can create consistent set with: - Magmom.get_consistent_set(magmoms). + def have_consistent_saxis(magmoms: Sequence[MagMomentLike]) -> bool: + """Check whether all Magmoms have a consistent spin quantization axis. + + To write MAGMOM tags to a VASP INCAR, a consistent global SAXIS value for + all magmoms has to be used. + + If spin axes are inconsistent, can create a consistent set with: + Magmom.get_consistent_set(magmoms). Args: - magmoms: list of magmoms (Magmoms, scalars or vectors) + magmoms (Sequence[MagMomentLike]): Magmoms. Returns: bool """ - magmoms = [Magmom(magmom) for magmom in magmoms] - ref_saxis = magmoms[0].saxis - match_ref = [magmom.saxis == ref_saxis for magmom in magmoms] - return np.all(match_ref) + _magmoms: list[Magmom] = [Magmom(magmom) for magmom in magmoms] + ref_saxis = _magmoms[0].saxis + match_ref = [magmom.saxis == ref_saxis for magmom in _magmoms] + return bool(np.all(match_ref)) @staticmethod - def get_consistent_set_and_saxis(magmoms, saxis=None): - """Ensure a list of magmoms use the same spin axis. - Returns a tuple of a list of Magmoms and their global spin axis. + def get_consistent_set_and_saxis( + magmoms: Sequence[MagMomentLike], + saxis: Vector3D | None = None, + ) -> tuple[list[Magmom], NDArray]: + """Ensure magmoms use the same spin axis. Args: - magmoms: list of magmoms (Magmoms, scalars or vectors) - saxis: can provide a specific global spin axis + magmoms (Sequence[MagMomentLike]): Magmoms, floats or vectors. + saxis (Vector3D): An optional global spin axis. Returns: - tuple[list[Magmom], np.ndarray]: (list of Magmoms, global spin axis) + tuple[list[Magmom], NDArray]: Magmoms and their global spin axes. """ - magmoms = [Magmom(magmom) for magmom in magmoms] - saxis = Magmom.get_suggested_saxis(magmoms) if saxis is None else saxis / np.linalg.norm(saxis) - magmoms = [magmom.get_moment(saxis=saxis) for magmom in magmoms] - return magmoms, saxis + _magmoms: list[Magmom] = [Magmom(magmom) for magmom in magmoms] + _saxis: NDArray = Magmom.get_suggested_saxis(_magmoms) if saxis is None else saxis / np.linalg.norm(saxis) + moments: list[NDArray] = [magmom.get_moment(saxis=_saxis) for magmom in _magmoms] + return moments, _saxis @staticmethod - def get_suggested_saxis(magmoms): - """Get a suggested spin axis for a set of magmoms, - taking the largest magnetic moment as the reference. For calculations - with collinear spins, this would give a sensible saxis for a ncl - calculation. + def get_suggested_saxis(magmoms: Sequence[MagMomentLike]) -> NDArray: + """Get a suggested spin axis for magmoms, taking the largest magnetic + moment as the reference. For calculations with collinear spins, + this would give a sensible saxis for a NCL calculation. Args: - magmoms: list of magmoms (Magmoms, scalars or vectors) + magmoms (Sequence[MagMomentLike]): Magmoms, floats or vectors. Returns: - np.ndarray of length 3 + NDArray of length 3 """ - # heuristic, will pick largest magmom as reference - # useful for creating collinear approximations of - # e.g. slightly canted magnetic structures - # for fully collinear structures, will return expected - # result - - magmoms = [Magmom(magmom) for magmom in magmoms] - # filter only non-zero magmoms - magmoms = [magmom for magmom in magmoms if abs(magmom)] - magmoms.sort(reverse=True) - if len(magmoms) > 0: - return magmoms[0].get_00t_magmom_with_xyz_saxis().saxis + # Heuristic, will pick largest magmom as the reference. + # Useful for creating collinear approximations of + # e.g. slightly canted magnetic structures. + # For fully collinear structures, will return expected result. + + _magmoms: list[Magmom] = [Magmom(magmom) for magmom in magmoms] + # Keep non-zero magmoms only + _magmoms = [magmom for magmom in _magmoms if abs(magmom)] # type: ignore[arg-type] + _magmoms.sort(reverse=True) + + if _magmoms: + return _magmoms[0].get_00t_magmom_with_xyz_saxis().saxis return np.array([0, 0, 1], dtype="d") @staticmethod - def are_collinear(magmoms) -> bool: - """Check if a set of magnetic moments are collinear with each other. + def are_collinear(magmoms: Sequence[MagMomentLike]) -> bool: + """Check if a list of magnetic moments are collinear with each other. Args: - magmoms: list of magmoms (Magmoms, scalars or vectors). + magmoms (Sequence[MagMomentLike]): Magmoms, floats or vectors. Returns: bool. @@ -366,110 +434,57 @@ def are_collinear(magmoms) -> bool: if not Magmom.have_consistent_saxis(magmoms): magmoms = Magmom.get_consistent_set_and_saxis(magmoms)[0] - # convert to numpy array for convenience - magmoms = np.array([list(magmom) for magmom in magmoms]) + # Convert to numpy array for convenience + magmoms = np.array([list(cast(Magmom, magmom)) for magmom in magmoms]) magmoms = magmoms[np.any(magmoms, axis=1)] # remove zero magmoms if len(magmoms) == 0: return True - # use first moment as reference to compare against + # Use first moment as reference to compare against ref_magmom = magmoms[0] - # magnitude of cross products != 0 if non-collinear with reference + # Magnitude of cross products != 0 if non-collinear with reference num_ncl = np.count_nonzero(np.linalg.norm(np.cross(ref_magmom, magmoms), axis=1)) return num_ncl == 0 @classmethod - def from_moment_relative_to_crystal_axes(cls, moment: list[float], lattice: Lattice) -> Self: - """Obtaining a Magmom object from a magnetic moment provided + def from_moment_relative_to_crystal_axes( + cls, + moment: Vector3D, + lattice: Lattice, + ) -> Self: + """Obtain a Magmom object from a magnetic moment provided relative to crystal axes. Used for obtaining moments from magCIF file. Args: - moment: list of floats specifying vector magmom - lattice: Lattice + moment (Vector3D): Magnetic moment. + lattice (Lattice): Lattice. Returns: Magmom """ - # get matrix representing unit lattice vectors + # Get matrix representing unit lattice vectors unit_m = lattice.matrix / np.linalg.norm(lattice.matrix, axis=1)[:, None] - moment = np.matmul(list(moment), unit_m) - # round small values to zero - moment[np.abs(moment) < 1e-8] = 0 - return cls(moment) + _moment: NDArray = np.matmul(list(moment), unit_m) + # Round small values to zero + _moment[np.abs(_moment) < 1e-8] = 0 + return cls(_moment) - def get_moment_relative_to_crystal_axes(self, lattice): + def get_moment_relative_to_crystal_axes(self, lattice: Lattice) -> Vector3D: """If scalar magmoms, moments will be given arbitrarily along z. + Used for writing moments to magCIF file. Args: - lattice: Lattice + lattice (Lattice): The lattice. Returns: - vector as list of floats + Vector3D """ - # get matrix representing unit lattice vectors + # Get matrix representing unit lattice vectors unit_m = lattice.matrix / np.linalg.norm(lattice.matrix, axis=1)[:, None] - # note np.matmul() requires numpy version >= 1.10 moment = np.matmul(self.global_moment, np.linalg.inv(unit_m)) - # round small values to zero + # Round small values to zero moment[np.abs(moment) < 1e-8] = 0 return moment - - def __getitem__(self, key): - return self.moment[key] - - def __iter__(self): - return iter(self.moment) - - def __abs__(self): - return np.linalg.norm(self.moment) - - def __eq__(self, other: object) -> bool: - """Equal if 'global' magnetic moments are the same, saxis can differ.""" - try: - other_magmom = Magmom(other) - except (TypeError, ValueError): - return NotImplemented - - return np.allclose(self.global_moment, other_magmom.global_moment) - - def __lt__(self, other): - return abs(self) < abs(other) - - def __neg__(self): - return Magmom(-self.moment, saxis=self.saxis) - - def __hash__(self) -> int: - return hash(tuple(self.moment) + tuple(self.saxis)) - - def __float__(self) -> float: - """Get magnitude of magnetic moment with a sign with respect to - an arbitrary direction. - - Should give unsurprising output if Magmom is treated like a - scalar or if a set of Magmoms describes a collinear structure. - - Implemented this way rather than simpler abs(self) so that - moments will have a consistent sign in case of e.g. - antiferromagnetic collinear structures without additional - user intervention. - - However, should be used with caution for non-collinear - structures and might give nonsensical results except in the case - of only slightly non-collinear structures (e.g. small canting). - - This approach is also used to obtain "diff" VolumetricDensity - in pymatgen.io.vasp.outputs.VolumetricDensity when processing - Chgcars from SOC calculations. - """ - return float(self.get_00t_magmom_with_xyz_saxis()[2]) - - def __str__(self) -> str: - return str(float(self)) - - def __repr__(self) -> str: - if np.allclose(self.saxis, (0, 0, 1)): - return f"Magnetic moment {self.moment}" - return f"Magnetic moment {self.moment} (spin axis = {self.saxis})" diff --git a/src/pymatgen/electronic_structure/dos.py b/src/pymatgen/electronic_structure/dos.py index 26662562e7c..a5abd47d472 100644 --- a/src/pymatgen/electronic_structure/dos.py +++ b/src/pymatgen/electronic_structure/dos.py @@ -1,68 +1,92 @@ -"""This module defines classes to represent the density of states, etc.""" +"""This module defines classes to represent the density of states (DOS), etc.""" from __future__ import annotations import functools import warnings -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, cast import numpy as np from monty.json import MSONable +from scipy.constants import value as _constant +from scipy.ndimage import gaussian_filter1d +from scipy.signal import hilbert + from pymatgen.core import Structure, get_el_sp from pymatgen.core.spectrum import Spectrum from pymatgen.electronic_structure.core import Orbital, OrbitalType, Spin from pymatgen.util.coord import get_linear_interpolated_value -from scipy.constants import value as _cd -from scipy.ndimage import gaussian_filter1d -from scipy.signal import hilbert if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Sequence + from typing import Any, Literal + + from numpy.typing import NDArray + from typing_extensions import Self - from numpy.typing import ArrayLike, NDArray from pymatgen.core.sites import PeriodicSite from pymatgen.util.typing import SpeciesLike, Tuple3Floats - from typing_extensions import Self class DOS(Spectrum): - """Replacement basic DOS object. All other DOS objects are extended versions - of this object. Work in progress. + """(Work in progress) Replacement of basic DOS object. + All other DOS objects are extended versions of this. Attributes: - energies (Sequence[float]): The sequence of energies. - densities (dict[Spin, Sequence[float]]): A dict of spin densities, e.g. {Spin.up: [...], Spin.down: [...]}. - efermi (float): Fermi level. + energies (Sequence[float]): Energies. + densities (dict[Spin, NDArray]): Spin densities, + e.g. {Spin.up: DOS_up, Spin.down: DOS_down}. + efermi (float): The Fermi level. """ XLABEL = "Energy" YLABEL = "Density" - def __init__(self, energies: ArrayLike, densities: ArrayLike, efermi: float) -> None: + def __init__(self, energies: Sequence[float], densities: NDArray, efermi: float) -> None: """ Args: - energies: A sequence of energies - densities (ndarray): Either a Nx1 or a Nx2 array. If former, it is + energies (Sequence[float]): The Energies. + densities (NDArray): A Nx1 or Nx2 array. If former, it is interpreted as a Spin.up only density. Otherwise, the first column - is interpreted as Spin.up and the other is Spin.down. - efermi: Fermi level energy. + is interpreted as Spin.up and the other Spin.down. + efermi (float): The Fermi level. """ super().__init__(energies, densities, efermi) self.efermi = efermi - def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None): - """Expects a DOS object and finds the gap. + def __str__(self) -> str: + """A string which can be easily plotted.""" + if Spin.down in self.densities: + str_arr = [f"#{'Energy':30s} {'DensityUp':30s} {'DensityDown':30s}"] + for idx, energy in enumerate(self.energies): + str_arr.append(f"{energy:.5f} {self.densities[Spin.up][idx]:.5f} {self.densities[Spin.down][idx]:.5f}") + + else: + str_arr = [f"#{'Energy':30s} {'DensityUp':30s}"] + for idx, energy in enumerate(self.energies): + str_arr.append(f"{energy:.5f} {self.densities[Spin.up][idx]:.5f}") + + return "\n".join(str_arr) + + def get_interpolated_gap( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> tuple[float, float, float]: + """Find the interpolated band gap. Args: - tol: tolerance in occupations for determining the gap - abs_tol: Set to True for an absolute tolerance and False for a - relative one. - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the gap: + - None: In the summed DOS. + - Up: In the spin up channel. + - Down: In the spin down channel. Returns: - tuple[float, float, float]: Energies in eV corresponding to the band gap, cbm and vbm. + tuple[float, float, float]: Energies in eV corresponding to the + band gap, CBM and VBM. """ if spin is None: tdos = self.y if len(self.ydim) == 1 else np.sum(self.y, axis=1) @@ -73,6 +97,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: if not abs_tol: tol = tol * tdos.sum() / tdos.shape[0] + energies = self.x below_fermi = [i for i in range(len(energies)) if energies[i] < self.efermi and tdos[i] > tol] above_fermi = [i for i in range(len(energies)) if energies[i] > self.efermi and tdos[i] > tol] @@ -80,6 +105,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: cbm_start = min(above_fermi) if vbm_start == cbm_start: return 0.0, self.efermi, self.efermi + # Interpolate between adjacent values terminal_dens = tdos[vbm_start : vbm_start + 2][::-1] terminal_energies = energies[vbm_start : vbm_start + 2][::-1] @@ -89,20 +115,26 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: end = get_linear_interpolated_value(terminal_dens, terminal_energies, tol) return end - start, end, start - def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin=None) -> tuple[float, float]: - """Expects a DOS object and finds the cbm and vbm. + def get_cbm_vbm( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> tuple[float, float]: + """Find the CBM and VBM. Args: - tol: tolerance in occupations for determining the gap - abs_tol: An absolute tolerance (True) and a relative one (False) - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the gap: + - None: In the summed DOS. + - Up: In the spin up channel. + - Down: In the spin down channel. Returns: - tuple[float, float]: Energies in eV corresponding to the cbm and vbm. + tuple[float, float]: Energies in eV corresponding to the CBM and VBM. """ - # determine tolerance + # Determine tolerance if spin is None: tdos = self.y if len(self.ydim) == 1 else np.sum(self.y, axis=1) elif spin == Spin.up: @@ -113,73 +145,73 @@ def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin=None) -> t if not abs_tol: tol = tol * tdos.sum() / tdos.shape[0] - # find index of fermi energy + # Find index of Fermi level i_fermi = 0 while self.x[i_fermi] <= self.efermi: i_fermi += 1 - # work backwards until tolerance is reached + # Work backwards until tolerance is reached i_gap_start = i_fermi - while i_gap_start - 1 >= 0 and tdos[i_gap_start - 1] <= tol: + while i_gap_start >= 1 and tdos[i_gap_start - 1] <= tol: i_gap_start -= 1 - # work forwards until tolerance is reached + # Work forwards until tolerance is reached i_gap_end = i_gap_start while i_gap_end < tdos.shape[0] and tdos[i_gap_end] <= tol: i_gap_end += 1 i_gap_end -= 1 + return self.x[i_gap_end], self.x[i_gap_start] - def get_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None): - """Expects a DOS object and finds the gap. + def get_gap( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> float: + """Find the band gap. Args: - tol: tolerance in occupations for determining the gap - abs_tol: An absolute tolerance (True) and a relative one (False) - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the gap: + - None: In the summed DOS. + - Up: In the spin up channel. + - Down: In the spin down channel. Returns: - gap in eV + float: Gap in eV. """ cbm, vbm = self.get_cbm_vbm(tol, abs_tol, spin) return max(cbm - vbm, 0.0) - def __str__(self) -> str: - """Get a string which can be easily plotted (using gnuplot).""" - if Spin.down in self.densities: - str_arr = [f"#{'Energy':30s} {'DensityUp':30s} {'DensityDown':30s}"] - for i, energy in enumerate(self.energies): - str_arr.append(f"{energy:.5f} {self.densities[Spin.up][i]:.5f} {self.densities[Spin.down][i]:.5f}") - else: - str_arr = [f"#{'Energy':30s} {'DensityUp':30s}"] - for i, energy in enumerate(self.energies): - str_arr.append(f"{energy:.5f} {self.densities[Spin.up][i]:.5f}") - return "\n".join(str_arr) - class Dos(MSONable): - """Basic DOS object. All other DOS objects are extended versions of this - object. + """Basic DOS object. All other DOS objects are extended versions of this. Attributes: - energies (Sequence[float]): The sequence of energies. - densities (dict[Spin, Sequence[float]]): A dict of spin densities, e.g. {Spin.up: [...], Spin.down: [...]}. - efermi (float): Fermi level. + energies (Sequence[float]): Energies. + densities (dict[Spin, NDArray): Spin densities, + e.g. {Spin.up: DOS_up, Spin.down: DOS_down}. + efermi (float): The Fermi level. """ def __init__( - self, efermi: float, energies: ArrayLike, densities: Mapping[Spin, ArrayLike], norm_vol: float | None = None + self, + efermi: float, + energies: Sequence[float], + densities: dict[Spin, NDArray], + norm_vol: float | None = None, ) -> None: """ Args: - efermi: Fermi level energy - energies: A sequences of energies - densities (dict[Spin: np.array]): representing the density of states for each Spin. - norm_vol: The volume used to normalize the densities. Defaults to 1 if None which will not perform any - normalization. If not None, the resulting density will have units of states/eV/Angstrom^3, otherwise - the density will be in states/eV. + efermi (float): The Fermi level. + energies (Sequence[float]): Energies. + densities (dict[Spin, NDArray]): The density of states for each Spin. + norm_vol (float | None): The volume used to normalize the DOS. + Defaults to 1 if None which will not perform any normalization. + If None, the result will be in unit of states/eV, + otherwise will be in states/eV/Angstrom^3. """ self.efermi = efermi self.energies = np.array(energies) @@ -187,58 +219,71 @@ def __init__( vol = norm_vol or 1 self.densities = {k: np.array(d) / vol for k, d in densities.items()} - def get_densities(self, spin: Spin | None = None): - """Get the density of states for a particular spin. + def __add__(self, other): + """Add two Dos. Args: - spin: Spin + other (Dos): Another Dos object. + + Raises: + ValueError: If energy scales are different. Returns: - Returns the density of states for a particular spin. If Spin is - None, the sum of all spins is returned. + Sum of the two Dos. """ - if self.densities is None: - result = None - elif spin is None: - if Spin.down in self.densities: - result = self.densities[Spin.up] + self.densities[Spin.down] - else: - result = self.densities[Spin.up] + if not all(np.equal(self.energies, other.energies)): + raise ValueError("Energies of both DOS are not compatible!") + + densities = {spin: self.densities[spin] + other.densities[spin] for spin in self.densities} + return type(self)(self.efermi, self.energies, densities) + + def __str__(self) -> str: + """A string which can be easily plotted.""" + if Spin.down in self.densities: + str_arr = [f"#{'Energy':30s} {'DensityUp':30s} {'DensityDown':30s}"] + for idx, energy in enumerate(self.energies): + str_arr.append(f"{energy:.5f} {self.densities[Spin.up][idx]:.5f} {self.densities[Spin.down][idx]:.5f}") + else: - result = self.densities[spin] - return result + str_arr = [f"#{'Energy':30s} {'DensityUp':30s}"] + for idx, energy in enumerate(self.energies): + str_arr.append(f"{energy:.5f} {self.densities[Spin.up][idx]:.5f}") + + return "\n".join(str_arr) - def get_smeared_densities(self, sigma: float): - """Get the Dict representation of the densities, {Spin: densities}, - but with a Gaussian smearing of std dev sigma. + def get_densities(self, spin: Spin | None = None) -> None | NDArray: + """Get the DOS for a particular spin. Args: - sigma: Std dev of Gaussian smearing function. + spin (Spin): Spin. Returns: - Dict of Gaussian-smeared densities. + NDArray: The DOS for the particular spin. Or the sum of both spins + if Spin is None. """ - smeared_dens = {} - diff = [self.energies[i + 1] - self.energies[i] for i in range(len(self.energies) - 1)] - avg_diff = sum(diff) / len(diff) - for spin, dens in self.densities.items(): - smeared_dens[spin] = gaussian_filter1d(dens, sigma / avg_diff) - return smeared_dens + if self.densities is None: + return None - def __add__(self, other): - """Add two DOS together. Checks that energy scales are the same. - Otherwise, a ValueError is thrown. + if spin is not None: + return self.densities[spin] + + if Spin.down in self.densities: + return self.densities[Spin.up] + self.densities[Spin.down] + + return self.densities[Spin.up] + + def get_smeared_densities(self, sigma: float) -> dict[Spin, NDArray]: + """Get the the DOS with a Gaussian smearing. Args: - other: Another DOS object. + sigma (float): Standard deviation of Gaussian smearing. Returns: - Sum of the two DOSs. + {Spin: NDArray}: Gaussian-smeared DOS by spin. """ - if not all(np.equal(self.energies, other.energies)): - raise ValueError("Energies of both DOS are not compatible!") - densities = {spin: self.densities[spin] + other.densities[spin] for spin in self.densities} - return Dos(self.efermi, self.energies, densities) + diff = [self.energies[idx + 1] - self.energies[idx] for idx in range(len(self.energies) - 1)] + avg_diff = sum(diff) / len(diff) + return {spin: gaussian_filter1d(dens, sigma / avg_diff) for spin, dens in self.densities.items()} def get_interpolated_value(self, energy: float) -> dict[Spin, float]: """Get interpolated density for a particular energy. @@ -254,23 +299,31 @@ def get_interpolated_value(self, energy: float) -> dict[Spin, float]: energies[spin] = get_linear_interpolated_value(self.energies, self.densities[spin], energy) return energies - def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None) -> Tuple3Floats: - """Expects a DOS object and finds the gap. + def get_interpolated_gap( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> Tuple3Floats: + """Find the interpolated band gap. Args: - tol: tolerance in occupations for determining the gap - abs_tol: Set to True for an absolute tolerance and False for a - relative one. - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the band gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the gap: + None - In the summed DOS. + Up - In the spin up channel. + Down - In the spin down channel. Returns: - tuple[float, float, float]: Energies in eV corresponding to the band gap, cbm and vbm. + tuple[float, float, float]: Energies in eV corresponding to the + band gap, CBM and VBM. """ tdos = self.get_densities(spin) + assert tdos is not None if not abs_tol: tol = tol * tdos.sum() / tdos.shape[0] + energies = self.energies below_fermi = [i for i in range(len(energies)) if energies[i] < self.efermi and tdos[i] > tol] above_fermi = [i for i in range(len(energies)) if energies[i] > self.efermi and tdos[i] > tol] @@ -283,84 +336,88 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: terminal_dens = tdos[vbm_start : vbm_start + 2][::-1] terminal_energies = energies[vbm_start : vbm_start + 2][::-1] start = get_linear_interpolated_value(terminal_dens, terminal_energies, tol) + terminal_dens = tdos[cbm_start - 1 : cbm_start + 1] terminal_energies = energies[cbm_start - 1 : cbm_start + 1] end = get_linear_interpolated_value(terminal_dens, terminal_energies, tol) + return end - start, end, start - def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None) -> tuple[float, float]: - """Expects a DOS object and finds the cbm and vbm. + def get_cbm_vbm( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> tuple[float, float]: + """Find the conduction band minimum (CBM) and valence band maximum (VBM). Args: - tol: tolerance in occupations for determining the gap - abs_tol: An absolute tolerance (True) and a relative one (False) - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the gap: + None - In the summed DOS. + Up - In the spin up channel. + Down - In the spin down channel. Returns: - tuple[float, float]: Energies in eV corresponding to the cbm and vbm. + tuple[float, float]: Energies in eV corresponding to the CBM and VBM. """ - # determine tolerance + # Determine tolerance tdos = self.get_densities(spin) + assert tdos is not None if not abs_tol: tol = tol * tdos.sum() / tdos.shape[0] - # find index of fermi energy + # Find index of Fermi level i_fermi = 0 while self.energies[i_fermi] <= self.efermi: i_fermi += 1 - # work backwards until tolerance is reached + # Work backwards until tolerance is reached i_gap_start = i_fermi - while i_gap_start - 1 >= 0 and tdos[i_gap_start - 1] <= tol: + while i_gap_start >= 1 and tdos[i_gap_start - 1] <= tol: i_gap_start -= 1 - # work forwards until tolerance is reached + # Work forwards until tolerance is reached i_gap_end = i_gap_start while i_gap_end < tdos.shape[0] and tdos[i_gap_end] <= tol: i_gap_end += 1 i_gap_end -= 1 + return self.energies[i_gap_end], self.energies[i_gap_start] - def get_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None): - """Expects a DOS object and finds the gap. + def get_gap( + self, + tol: float = 0.001, + abs_tol: bool = False, + spin: Spin | None = None, + ) -> float: + """Find the band gap. Args: - tol: tolerance in occupations for determining the gap - abs_tol: An absolute tolerance (True) and a relative one (False) - spin: Possible values are None - finds the gap in the summed - densities, Up - finds the gap in the up spin channel, - Down - finds the gap in the down spin channel. + tol (float): Tolerance in occupations for determining the band gap. + abs_tol (bool): Use absolute (True) or relative (False) tolerance. + spin (Spin | None): Find the band gap: + None - In the summed DOS. + Up - In the spin up channel. + Down - In the spin down channel. Returns: - gap in eV + float: Band gap in eV. """ cbm, vbm = self.get_cbm_vbm(tol, abs_tol, spin) return max(cbm - vbm, 0.0) - def __str__(self) -> str: - """Get a string which can be easily plotted (using gnuplot).""" - if Spin.down in self.densities: - str_arr = [f"#{'Energy':30s} {'DensityUp':30s} {'DensityDown':30s}"] - for i, energy in enumerate(self.energies): - str_arr.append(f"{energy:.5f} {self.densities[Spin.up][i]:.5f} {self.densities[Spin.down][i]:.5f}") - else: - str_arr = [f"#{'Energy':30s} {'DensityUp':30s}"] - for i, energy in enumerate(self.energies): - str_arr.append(f"{energy:.5f} {self.densities[Spin.up][i]:.5f}") - return "\n".join(str_arr) - @classmethod def from_dict(cls, dct: dict) -> Self: - """Get Dos object from dict representation of Dos.""" + """Get Dos from a dict representation.""" return cls( dct["efermi"], dct["energies"], {Spin(int(k)): v for k, v in dct["densities"].items()}, ) - def as_dict(self) -> dict: + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of Dos.""" return { "@module": type(self).__module__, @@ -372,11 +429,12 @@ def as_dict(self) -> dict: class FermiDos(Dos, MSONable): - """This wrapper class helps relate the density of states, doping levels - (i.e. carrier concentrations) and corresponding fermi levels. A negative - doping concentration indicates the majority carriers are electrons - (n-type doping); a positive doping concentration indicates holes are the - majority carriers (p-type doping). + """Relate the density of states, doping levels + (i.e. carrier concentrations) and corresponding Fermi levels. + + A negative doping concentration indicates the majority carriers are + electrons (N-type); a positive doping concentration indicates holes + are the majority carriers (P-type). """ def __init__( @@ -388,15 +446,15 @@ def __init__( ) -> None: """ Args: - dos: Pymatgen Dos object. - structure: A structure. If not provided, the structure - of the dos object will be used. If the dos does not have an - associated structure object, an error will be thrown. - nelecs: The number of electrons included in the energy range of - dos. It is used for normalizing the densities. Default is the total - number of electrons in the structure. - bandgap: If set, the energy values are scissored so that the electronic - band gap matches this value. + dos (Dos): Pymatgen Dos object. + structure (Structure): A structure. If None, the Structure + of the Dos will be used. If the Dos does not have an + associated Structure, an ValueError will be raised. + nelecs (float): The number of electrons included in the energy range of + Dos. It is used for normalizing the DOS. Default None to + the total number of electrons in the structure. + bandgap (float): If not None, the energy values are scissored so that + the electronic band gap matches this value. """ super().__init__( dos.efermi, @@ -417,7 +475,7 @@ def __init__( self.energies = np.array(dos.energies) self.de = np.hstack((self.energies[1:], self.energies[-1])) - self.energies - # normalize total density of states based on integral at 0K + # Normalize total density of states based on integral at 0 K tdos = np.array(self.get_densities()) self.tdos = tdos * self.nelecs / (tdos * self.de)[self.energies <= self.efermi].sum() @@ -432,7 +490,7 @@ def __init__( idx_fermi = int(np.argmin(abs(self.energies - eref))) if idx_fermi == self.idx_vbm: - # Fermi level and vbm should be different indices + # Fermi level and VBM should have different indices idx_fermi += 1 self.energies[:idx_fermi] -= (bandgap - (ecbm - evbm)) / 2.0 @@ -440,19 +498,19 @@ def __init__( def get_doping(self, fermi_level: float, temperature: float) -> float: """Calculate the doping (majority carrier concentration) at a given - Fermi level and temperature. A simple Left Riemann sum is used for + Fermi level and temperature. A simple Left Riemann sum is used for integrating the density of states over energy & equilibrium Fermi-Dirac distribution. Args: - fermi_level: The fermi_level level in eV. - temperature: The temperature in Kelvin. + fermi_level (float): The Fermi level in eV. + temperature (float): The temperature in Kelvin. Returns: - The doping concentration in units of 1/cm^3. Negative values - indicate that the majority carriers are electrons (n-type doping) - whereas positive values indicates the majority carriers are holes - (p-type doping). + float: The doping concentration in units of 1/cm^3. Negative values + indicate that the majority carriers are electrons (N-type), + whereas positive values indicates the majority carriers are holes + (P-type). """ cb_integral = np.sum( self.tdos[self.idx_cbm :] @@ -468,27 +526,74 @@ def get_doping(self, fermi_level: float, temperature: float) -> float: ) return (vb_integral - cb_integral) / (self.volume * self.A_to_cm**3) + def get_fermi( + self, + concentration: float, + temperature: float, + rtol: float = 0.01, + nstep: int = 50, + step: float = 0.1, + precision: int = 8, + ) -> float: + """Find the Fermi level at which the doping concentration at the given + temperature (T) is equal to concentration. An algorithm is used + where the relative error is minimized by calculating the doping at a + grid which continually becomes finer. + + Args: + concentration (float): The doping concentration in 1/cm^3. Negative + values represent N-type doping and positive values represent P-type. + temperature (float): The temperature in Kelvin. + rtol (float): The maximum acceptable relative error. + nstep (int): The number of steps checked around a given Fermi level. + step (float): The initial Energy step length when searching. + precision (int): The decimal places of calculated Fermi level. + + Raises: + ValueError: If the Fermi level cannot be found. + + Returns: + float: The Fermi level in eV. Note that this is different from + the default Dos.efermi. + """ + fermi = self.efermi # initialize target Fermi + relative_error = [float("inf")] + for _ in range(precision): + fermi_range = np.arange(-nstep, nstep + 1) * step + fermi + calc_doping = np.array([self.get_doping(fermi_lvl, temperature) for fermi_lvl in fermi_range]) + relative_error = np.abs(calc_doping / concentration - 1.0) + fermi = fermi_range[np.argmin(relative_error)] + step /= 10.0 + + if min(relative_error) > rtol: + raise ValueError(f"Could not find fermi within {rtol:.1%} of {concentration=}") + return fermi + def get_fermi_interextrapolated( - self, concentration: float, temperature: float, warn: bool = True, c_ref: float = 1e10, **kwargs + self, + concentration: float, + temperature: float, + warn: bool = True, + c_ref: float = 1e10, + **kwargs, ) -> float: - """Similar to get_fermi except that when get_fermi fails to converge, - an interpolated or extrapolated fermi is returned with the assumption - that the Fermi level changes linearly with log(abs(concentration)). + """Similar to get_fermi method except that when it fails to converge, an + interpolated or extrapolated Fermi level is returned, with the assumption + that the Fermi level changes linearly with log(abs(concentration)), + and therefore must be used with caution. Args: - concentration: The doping concentration in 1/cm^3. Negative values - represent n-type doping and positive values represent p-type - doping. - temperature: The temperature in Kelvin. - warn: Whether to give a warning the first time the fermi cannot be - found. - c_ref: A doping concentration where get_fermi returns a + concentration (float): The doping concentration in 1/cm^3. Negative + value represents N-type doping and positive value represents P-type. + temperature (float): The temperature in Kelvin. + warn (bool): Whether to give a warning the first time the Fermi level + cannot be found. + c_ref (float): A doping concentration where get_fermi returns a value without error for both c_ref and -c_ref. **kwargs: Keyword arguments passed to the get_fermi function. Returns: - The Fermi level. Note, the value is possibly interpolated or - extrapolated and must be used with caution. + float: The possibly interpolated or extrapolated Fermi level. """ try: return self.get_fermi(concentration, temperature, **kwargs) @@ -500,8 +605,7 @@ def get_fermi_interextrapolated( if abs(concentration) < 1e-10: concentration = 1e-10 - # max(10, ) is to avoid log(0 float: - """Find the Fermi level at which the doping concentration at the given - temperature (T) is equal to concentration. A greedy algorithm is used - where the relative error is minimized by calculating the doping at a - grid which continually becomes finer. - - Args: - concentration: The doping concentration in 1/cm^3. Negative values - represent n-type doping and positive values represent p-type - doping. - temperature: The temperature in Kelvin. - rtol: The maximum acceptable relative error. - nstep: The number of steps checked around a given Fermi level. - step: Initial step in energy when searching for the Fermi level. - precision: Essentially the decimal places of calculated Fermi level. - - Raises: - ValueError: If the Fermi level cannot be found. - - Returns: - The Fermi level in eV. Note that this is different from the default - dos.efermi. - """ - fermi = self.efermi # initialize target fermi - relative_error = [float("inf")] - for _ in range(precision): - fermi_range = np.arange(-nstep, nstep + 1) * step + fermi - calc_doping = np.array([self.get_doping(fermi_lvl, temperature) for fermi_lvl in fermi_range]) - relative_error = np.abs(calc_doping / concentration - 1.0) - fermi = fermi_range[np.argmin(relative_error)] - step /= 10.0 - - if min(relative_error) > rtol: - raise ValueError(f"Could not find fermi within {rtol:.1%} of {concentration=}") - return fermi - @classmethod def from_dict(cls, dct: dict) -> Self: - """Get Dos object from dict representation of Dos.""" + """Get FermiDos object from a dict representation.""" dos = Dos( dct["efermi"], dct["energies"], @@ -574,8 +634,8 @@ def from_dict(cls, dct: dict) -> Self: ) return cls(dos, structure=Structure.from_dict(dct["structure"]), nelecs=dct["nelecs"]) - def as_dict(self) -> dict: - """JSON-serializable dict representation of Dos.""" + def as_dict(self) -> dict[str, Any]: + """JSON-serializable dict representation of FermiDos.""" return { "@module": type(self).__module__, "@class": type(self).__name__, @@ -587,32 +647,42 @@ def as_dict(self) -> dict: } +class FingerPrint(NamedTuple): + """The DOS fingerprint.""" + + energies: NDArray + densities: NDArray + type: str + n_bins: int + bin_width: float + + class CompleteDos(Dos): - """This wrapper class defines a total dos, and also provides a list of PDos. - Mainly used by pymatgen.io.vasp.Vasprun to create a complete Dos from - a vasprun.xml file. You are unlikely to try to generate this object - manually. + """Define total DOS, and projected DOS (PDOS). + + Mainly used by pymatgen.io.vasp.Vasprun to create a complete DOS from + a vasprun.xml file. You are unlikely to generate this object manually. Attributes: structure (Structure): Structure associated with the CompleteDos. - pdos (dict): Dict of partial densities of the form {Site:{Orbital:{Spin:Densities}}}. + pdos (dict[PeriodicSite, dict[Orbital, dict[Spin, NDArray]]]): PDOS. """ def __init__( self, structure: Structure, total_dos: Dos, - pdoss: Mapping[PeriodicSite, Mapping[Orbital, Mapping[Spin, ArrayLike]]], + pdoss: dict[PeriodicSite, dict[Orbital, dict[Spin, NDArray]]], normalize: bool = False, ) -> None: """ Args: - structure: Structure associated with this particular DOS. - total_dos: total Dos for structure - pdoss: The pdoss are supplied as an {Site: {Orbital: {Spin:Densities}}} - normalize: Whether to normalize the densities by the volume of the structure. - If True, the units of the densities are states/eV/Angstrom^3. Otherwise, - the units are states/eV. + structure (Structure): Structure associated with this DOS. + total_dos (Dos): Total DOS for the structure. + pdoss (dict): The PDOSs supplied as {Site: {Orbital: {Spin: Densities}}}. + normalize (bool): Whether to normalize the DOS by the volume of + the structure. If True, the units of the DOS are states/eV/Angstrom^3. + Otherwise, the units are states/eV. """ vol = structure.volume if normalize else None super().__init__( @@ -624,11 +694,15 @@ def __init__( self.pdos = pdoss self.structure = structure - def get_normalized(self) -> CompleteDos: - """Get a normalized version of the CompleteDos.""" + def __str__(self) -> str: + return f"Complete DOS for {self.structure}" + + def get_normalized(self) -> Self: + """Get normalized CompleteDos.""" if self.norm_vol is not None: return self - return CompleteDos( + + return type(self)( structure=self.structure, total_dos=self, pdoss=self.pdos, @@ -643,30 +717,30 @@ def get_site_orbital_dos(self, site: PeriodicSite, orbital: Orbital) -> Dos: orbital: Orbital in the site. Returns: - Dos containing densities for orbital of site. + Dos: for a particular orbital of a particular site. """ return Dos(self.efermi, self.energies, self.pdos[site][orbital]) def get_site_dos(self, site: PeriodicSite) -> Dos: - """Get the total Dos for a site (all orbitals). + """Get the total DOS for a site with all orbitals. Args: - site: Site in Structure associated with CompleteDos. + site (PeriodicSite): Site in Structure associated with CompleteDos. Returns: - Dos containing summed orbital densities for site. + Dos: Total DOS for a site with all orbitals. """ site_dos = functools.reduce(add_densities, self.pdos[site].values()) return Dos(self.efermi, self.energies, site_dos) def get_site_spd_dos(self, site: PeriodicSite) -> dict[OrbitalType, Dos]: - """Get orbital projected Dos of a particular site. + """Get orbital projected DOS of a particular site. Args: - site: Site in Structure associated with CompleteDos. + site (PeriodicSite): Site in Structure associated with CompleteDos. Returns: - dict of {OrbitalType: Dos}, e.g. {OrbitalType.s: Dos object, ...} + dict[OrbitalType, Dos] """ spd_dos: dict[OrbitalType, dict[Spin, np.ndarray]] = {} for orb, pdos in self.pdos[site].items(): @@ -677,14 +751,17 @@ def get_site_spd_dos(self, site: PeriodicSite) -> dict[OrbitalType, Dos]: spd_dos[orbital_type] = pdos # type: ignore[assignment] return {orb: Dos(self.efermi, self.energies, densities) for orb, densities in spd_dos.items()} - def get_site_t2g_eg_resolved_dos(self, site: PeriodicSite) -> dict[str, Dos]: - """Get the t2g, eg projected DOS for a particular site. + def get_site_t2g_eg_resolved_dos( + self, + site: PeriodicSite, + ) -> dict[Literal["e_g", "t2g"], Dos]: + """Get the t2g/e_g projected DOS for a particular site. Args: - site: Site in Structure associated with CompleteDos. + site (PeriodicSite): Site in Structure associated with CompleteDos. Returns: - dict[str, Dos]: A dict {"e_g": Dos, "t2g": Dos} containing summed e_g and t2g DOS for the site. + dict[Literal["e_g", "t2g"], Dos]: Summed e_g and t2g DOS for the site. """ t2g_dos = [] eg_dos = [] @@ -701,10 +778,10 @@ def get_site_t2g_eg_resolved_dos(self, site: PeriodicSite) -> dict[str, Dos]: } def get_spd_dos(self) -> dict[OrbitalType, Dos]: - """Get orbital projected Dos. + """Get orbital projected DOS. Returns: - dict[OrbitalType, Dos]: e.g. {OrbitalType.s: Dos object, ...} + dict[OrbitalType, Dos] """ spd_dos = {} for atom_dos in self.pdos.values(): @@ -717,29 +794,27 @@ def get_spd_dos(self) -> dict[OrbitalType, Dos]: return {orb: Dos(self.efermi, self.energies, densities) for orb, densities in spd_dos.items()} def get_element_dos(self) -> dict[SpeciesLike, Dos]: - """Get element projected Dos. + """Get element projected DOS. Returns: dict[Element, Dos] """ - el_dos = {} + el_dos: dict[SpeciesLike, dict[Spin, NDArray]] = {} for site, atom_dos in self.pdos.items(): el = site.specie for pdos in atom_dos.values(): - if el not in el_dos: - el_dos[el] = pdos - else: - el_dos[el] = add_densities(el_dos[el], pdos) + el_dos[el] = add_densities(el_dos[el], pdos) if el in el_dos else pdos + return {el: Dos(self.efermi, self.energies, densities) for el, densities in el_dos.items()} def get_element_spd_dos(self, el: SpeciesLike) -> dict[OrbitalType, Dos]: - """Get element and spd projected Dos. + """Get element and orbital (spd) projected DOS. Args: - el: Element in Structure.composition associated with CompleteDos + el (SpeciesLike): Element associated with CompleteDos. Returns: - dict[OrbitalType, Dos]: e.g. {OrbitalType.s: Dos object, ...} + dict[OrbitalType, Dos] """ el = get_el_sp(el) el_dos = {} @@ -756,24 +831,25 @@ def get_element_spd_dos(self, el: SpeciesLike) -> dict[OrbitalType, Dos]: @property def spin_polarization(self) -> float | None: - """Calculate spin polarization at Fermi level. If the - calculation is not spin-polarized, None will be returned. + """Spin polarization at Fermi level. - See Sanvito et al., doi: 10.1126/sciadv.1602241 for an example usage. + Examples: + See Sanvito et al., DOI: 10.1126/sciadv.1602241 for an example usage. Returns: - float: spin polarization in range [0, 1], will also return NaN if spin - polarization ill-defined (e.g. for insulator). + float | None: Spin polarization in range [0, 1], will return NaN + if spin polarization is ill-defined (e.g. for insulator). + None if the calculation is not spin-polarized. """ n_F = self.get_interpolated_value(self.efermi) n_F_up = n_F[Spin.up] if Spin.down not in n_F: return None - n_F_down = n_F[Spin.down] + n_F_down = n_F[Spin.down] + # Only well defined for metals or half-metals if (n_F_up + n_F_down) == 0: - # only well defined for metals or half-metals return float("NaN") spin_polarization = (n_F_up - n_F_down) / (n_F_up + n_F_down) @@ -790,37 +866,42 @@ def get_band_filling( """Compute the orbital-projected band filling, defined as the zeroth moment up to the Fermi level. + "elements" and "sites" cannot be used together. + Args: - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. + band (OrbitalType): Orbital to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. If None, both spin channels will be combined. Returns: - float: band filling in eV, often denoted f_d for the d-band + float: Band filling in eV, often denoted f_d for the d-band. """ # Get the projected DOS if elements and sites: raise ValueError("Both element and site cannot be specified.") - densities: dict[Spin, ArrayLike] = {} + densities: dict[Spin, NDArray] = {} if elements: for idx, el in enumerate(elements): spd_dos = self.get_element_spd_dos(el)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + elif sites: for idx, site in enumerate(sites): spd_dos = self.get_site_spd_dos(site)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + else: dos = self.get_spd_dos()[band] energies = dos.energies - dos.efermi dos_densities = dos.get_densities(spin=spin) + assert dos_densities is not None - # Only consider up to Fermi level in numerator + # Only integrate up to Fermi level energies = dos.energies - dos.efermi return np.trapz(dos_densities[energies < 0], x=energies[energies < 0]) / np.trapz(dos_densities, x=energies) @@ -830,26 +911,29 @@ def get_band_center( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, ) -> float: """Compute the orbital-projected band center, defined as the first moment - relative to the Fermi level - int_{-inf}^{+inf} rho(E)*E dE/int_{-inf}^{+inf} rho(E) dE - based on the work of Hammer and Norskov, Surf. Sci., 343 (1995) where the - limits of the integration can be modified by erange and E is the set - of energies taken with respect to the Fermi level. Note that the band center - is often highly sensitive to the selected erange. + relative to the Fermi level as: + int_{-inf}^{+inf} rho(E)*E dE/int_{-inf}^{+inf} rho(E) dE. + + Note that the band center is often highly sensitive to the selected energy range. + + "elements" and "sites" cannot be used together. + + References: + Hammer and Norskov, Surf. Sci., 343 (1995). Args: - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. + band (OrbitalType): Orbital to get the band center of (default is d-band) + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. If None, both spin channels will be combined. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. Returns: - float: band center in eV, often denoted epsilon_d for the d-band center + float: The band center in eV, often denoted epsilon_d for the d-band center. """ return self.get_n_moment(1, elements=elements, sites=sites, band=band, spin=spin, erange=erange, center=False) @@ -859,24 +943,27 @@ def get_band_width( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, ) -> float: - """Get the orbital-projected band width, defined as the square root of the second moment + """Get the orbital-projected band width, defined as the square root + of the second moment: sqrt(int_{-inf}^{+inf} rho(E)*(E-E_center)^2 dE/int_{-inf}^{+inf} rho(E) dE) - where E_center is the orbital-projected band center, the limits of the integration can be - modified by erange, and E is the set of energies taken with respect to the Fermi level. - Note that the band width is often highly sensitive to the selected erange. + where E_center is the orbital-projected band center. + + Note that the band width is often highly sensitive to the selected energy range. + + "elements" and "sites" cannot be used together. Args: - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. + band (OrbitalType): Orbital to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. By default, both spin channels will be combined. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. Returns: - float: Orbital-projected band width in eV + float: Orbital-projected band width in eV. """ return np.sqrt(self.get_n_moment(2, elements=elements, sites=sites, band=band, spin=spin, erange=erange)) @@ -886,26 +973,28 @@ def get_band_skewness( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, ) -> float: - """Get the orbital-projected skewness, defined as the third standardized moment + """Get the orbital-projected skewness, defined as the third standardized moment: int_{-inf}^{+inf} rho(E)*(E-E_center)^3 dE/int_{-inf}^{+inf} rho(E) dE) / (int_{-inf}^{+inf} rho(E)*(E-E_center)^2 dE/int_{-inf}^{+inf} rho(E) dE))^(3/2) - where E_center is the orbital-projected band center, the limits of the integration can be - modified by erange, and E is the set of energies taken with respect to the Fermi level. - Note that the skewness is often highly sensitive to the selected erange. + where E_center is the orbital-projected band center. + + Note that the skewness is often highly sensitive to the selected energy range. + + "elements" and "sites" cannot be used together. Args: - band: Orbitals to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. + band (OrbitalType): Orbitals to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. By default, both spin channels will be combined. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. Returns: - float: orbital-projected skewness (dimensionless) + float: The orbital-projected skewness (dimensionless). """ kwds: dict = dict(elements=elements, sites=sites, band=band, spin=spin, erange=erange) return self.get_n_moment(3, **kwds) / self.get_n_moment(2, **kwds) ** (3 / 2) @@ -916,26 +1005,28 @@ def get_band_kurtosis( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, ) -> float: """Get the orbital-projected kurtosis, defined as the fourth standardized moment int_{-inf}^{+inf} rho(E)*(E-E_center)^4 dE/int_{-inf}^{+inf} rho(E) dE) / (int_{-inf}^{+inf} rho(E)*(E-E_center)^2 dE/int_{-inf}^{+inf} rho(E) dE))^2 - where E_center is the orbital-projected band center, the limits of the integration can be - modified by erange, and E is the set of energies taken with respect to the Fermi level. - Note that the skewness is often highly sensitive to the selected erange. + where E_center is the orbital-projected band center. + + Note that the kurtosis is often highly sensitive to the selected energy range. + + "elements" and "sites" cannot be used together. Args: - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. + band (OrbitalType): Orbitals to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. By default, both spin channels will be combined. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. Returns: - float: orbital-projected kurtosis (dimensionless) + float: The orbital-projected kurtosis (dimensionless). """ kwds: dict = dict(elements=elements, sites=sites, band=band, spin=spin, erange=erange) return self.get_n_moment(4, **kwds) / self.get_n_moment(2, **kwds) ** 2 @@ -947,24 +1038,26 @@ def get_n_moment( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, center: bool = True, ) -> float: - """Get the nth moment of the DOS centered around the orbital-projected band center, defined as + """Get the nth moment of the DOS centered around the orbital-projected + band center, defined as: int_{-inf}^{+inf} rho(E)*(E-E_center)^n dE/int_{-inf}^{+inf} rho(E) dE - where n is the order, E_center is the orbital-projected band center, the limits of the integration can be - modified by erange, and E is the set of energies taken with respect to the Fermi level. If center is False, - then the E_center reference is not used. + where n is the order, E_center is the orbital-projected band center, and + E is the set of energies taken with respect to the Fermi level. + + "elements" and "sites" cannot be used together. Args: - n: The order for the moment - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. - center: Take moments with respect to the band center + n (int): The order for the moment. + band (OrbitalType): Orbital to get the band center of (default is d-band). + elements (list[PeriodicSite]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. By default, both spin channels will be combined. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. + center (bool): Take moments with respect to the band center. Returns: Orbital-projected nth moment in eV @@ -973,24 +1066,27 @@ def get_n_moment( if elements and sites: raise ValueError("Both element and site cannot be specified.") - densities: Mapping[Spin, ArrayLike] = {} + densities: dict[Spin, NDArray] = {} if elements: for idx, el in enumerate(elements): spd_dos = self.get_element_spd_dos(el)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + elif sites: for idx, site in enumerate(sites): spd_dos = self.get_site_spd_dos(site)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + else: dos = self.get_spd_dos()[band] energies = dos.energies - dos.efermi dos_densities = dos.get_densities(spin=spin) + assert dos_densities is not None - # Only consider a given erange, if desired + # Only consider a given energy range if erange: dos_densities = dos_densities[(energies >= erange[0]) & (energies <= erange[1])] energies = energies[(energies >= erange[0]) & (energies <= erange[1])] @@ -1011,32 +1107,36 @@ def get_hilbert_transform( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, ) -> Dos: - """Return the Hilbert transform of the orbital-projected density of states, + """Get the Hilbert transform of the orbital-projected DOS, often plotted for a Newns-Anderson analysis. + "elements" and "sites" cannot be used together. + Args: - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - band: Orbitals to get the band center of (default is d-band) + band (OrbitalType): Orbital to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. Returns: - Hilbert transformation of the projected DOS. + Dos: Hilbert transformation of the projected DOS. """ # Get the projected DOS if elements and sites: raise ValueError("Both element and site cannot be specified.") - densities: Mapping[Spin, ArrayLike] = {} + densities: dict[Spin, NDArray] = {} if elements: for idx, el in enumerate(elements): spd_dos = self.get_element_spd_dos(el)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + elif sites: for idx, site in enumerate(sites): spd_dos = self.get_site_spd_dos(site)[band] densities = spd_dos.densities if idx == 0 else add_densities(densities, spd_dos.densities) dos = Dos(self.efermi, self.energies, densities) + else: dos = self.get_spd_dos()[band] @@ -1053,30 +1153,35 @@ def get_upper_band_edge( elements: list[SpeciesLike] | None = None, sites: list[PeriodicSite] | None = None, spin: Spin | None = None, - erange: list[float] | None = None, + erange: tuple[float, float] | None = None, ) -> float: - """Get the orbital-projected upper band edge. The definition by Xin et al. - Phys. Rev. B, 89, 115114 (2014) is used, which is the highest peak position of the - Hilbert transform of the orbital-projected DOS. + """Get the orbital-projected upper band edge. + + The definition by Xin et al. Phys. Rev. B, 89, 115114 (2014) is used, + which is the highest peak position of the Hilbert transform of + the orbital-projected DOS. + + "elements" and "sites" cannot be used together. Args: - band: Orbital type to get the band center of (default is d-band) - elements: Elements to get the band center of (cannot be used in conjunction with site) - sites: Sites to get the band center of (cannot be used in conjunction with el) - spin: Spin channel to use. By default, the spin channels will be combined. - erange: [min, max] energy range to consider, with respect to the Fermi level. - Default is None, which means all energies are considered. + band (OrbitalType): Orbital to get the band center of (default is d-band). + elements (list[SpeciesLike]): Elements to get the band center of. + sites (list[PeriodicSite]): Sites to get the band center of. + spin (Spin): Spin channel to use. Both spin channels will be combined by default. + erange (tuple(min, max)): The energy range to consider, with respect to the + Fermi level. Default to None for all energies. Returns: - Upper band edge in eV, often denoted epsilon_u + float: Upper band edge in eV, often denoted epsilon_u. """ # Get the Hilbert-transformed DOS transformed_dos = self.get_hilbert_transform(elements=elements, sites=sites, band=band) energies = transformed_dos.energies - transformed_dos.efermi densities = transformed_dos.get_densities(spin=spin) + assert densities is not None - # Only consider a given erange, if specified + # Only consider a given energy range, if specified if erange: densities = densities[(energies >= erange[0]) & (energies <= erange[1])] energies = energies[(energies >= erange[0]) & (energies <= erange[1])] @@ -1092,40 +1197,30 @@ def get_dos_fp( max_e: float | None = None, n_bins: int = 256, normalize: bool = True, - ) -> NamedTuple: - """Generate the DOS fingerprint. - - Based on work of: + ) -> FingerPrint: + """Generate the DOS FingerPrint. - F. Knoop, T. A. r Purcell, M. Scheffler, C. Carbogno, J. Open Source Softw. 2020, 5, 2671. - Source - https://gitlab.com/vibes-developers/vibes/-/tree/master/vibes/materials_fp - Copyright (c) 2020 Florian Knoop, Thomas A.R.Purcell, Matthias Scheffler, Christian Carbogno. + Based on the work of: + F. Knoop, T. A. r Purcell, M. Scheffler, C. Carbogno, J. Open Source Softw. 2020, 5, 2671. + Source - https://gitlab.com/vibes-developers/vibes/-/tree/master/vibes/materials_fp + Copyright (c) 2020 Florian Knoop, Thomas A.R.Purcell, Matthias Scheffler, Christian Carbogno. Args: - type (str): Specify fingerprint type needed can accept '{s/p/d/f/}summed_{pdos/tdos}' - (default is summed_pdos) - binning (bool): If true, the DOS fingerprint is binned using np.linspace and n_bins. + type (str): The FingerPrint type, can be "{s/p/d/f/summed}_{pdos/tdos}" + (default is summed_pdos). + binning (bool): Whether to bin the DOS FingerPrint using np.linspace and n_bins. Default is True. - min_e (float): The minimum mode energy to include in the fingerprint (default is None) - max_e (float): The maximum mode energy to include in the fingerprint (default is None) - n_bins (int): Number of bins to be used in the fingerprint (default is 256) - normalize (bool): If true, normalizes the area under fp to equal to 1. Default is True. + min_e (float): The minimum energy to include (default is None). + max_e (float): The maximum energy to include (default is None). + n_bins (int): Number of bins to be used, if binning (default is 256). + normalize (bool): Whether to normalize the integrated DOS to 1. Default is True. Raises: - ValueError: If type is not one of the accepted values {s/p/d/f/}summed_{pdos/tdos}. + ValueError: If "type" is not one of the accepted values. Returns: - NamedTuple: The electronic density of states fingerprint - of format (energies, densities, type, n_bins) + FingerPrint(energies, densities, type, n_bins): The DOS fingerprint. """ - - class fingerprint(NamedTuple): - energies: NDArray - densities: NDArray - type: str - n_bins: int - bin_width: float - energies = self.energies - self.efermi if max_e is None: @@ -1136,20 +1231,18 @@ class fingerprint(NamedTuple): pdos_obj = self.get_spd_dos() - pdos = {} - for key in pdos_obj: - dens = pdos_obj[key].get_densities() - - pdos[key.name] = dens + pdos = {key.name: pdos_obj[key].get_densities() for key in pdos_obj} pdos["summed_pdos"] = np.sum(list(pdos.values()), axis=0) pdos["tdos"] = self.get_densities() try: densities = pdos[type] + assert densities is not None + if len(energies) < n_bins: inds = np.where((energies >= min_e) & (energies <= max_e)) - return fingerprint(energies[inds], densities[inds], type, len(energies), np.diff(energies)[0]) + return FingerPrint(energies[inds], densities[inds], type, len(energies), np.diff(energies)[0]) if binning: ener_bounds = np.linspace(min_e, max_e, n_bins + 1) @@ -1166,63 +1259,63 @@ class fingerprint(NamedTuple): for ii, e1, e2 in zip(range(len(ener)), ener_bounds[:-1], ener_bounds[1:]): inds = np.where((energies >= e1) & (energies < e2)) dos_rebin[ii] = np.sum(densities[inds]) - if normalize: # scale DOS bins to make area under histogram equal 1 + + # Scale DOS bins to make area under histogram equal 1 + if normalize: area = np.sum(dos_rebin * bin_width) dos_rebin_sc = dos_rebin / area else: dos_rebin_sc = dos_rebin - return fingerprint(np.array([ener]), dos_rebin_sc, type, n_bins, bin_width) + return FingerPrint(np.array([ener]), dos_rebin_sc, type, n_bins, bin_width) - except KeyError: + except KeyError as exc: raise ValueError( "Please recheck type requested, either the orbital projections unavailable in input DOS or " "there's a typo in type." - ) + ) from exc @staticmethod - def fp_to_dict(fp: NamedTuple) -> dict: - """Convert a fingerprint into a dictionary. + def fp_to_dict(fp: FingerPrint) -> dict[str, NDArray]: + """Convert a DOS FingerPrint into a dict. Args: - fp: The DOS fingerprint to be converted into a dictionary + fp (FingerPrint): The DOS FingerPrint to convert. Returns: - dict: A dict of the fingerprint Keys=type, Values=np.ndarray(energies, densities) + dict(Keys=type, Values=np.array(energies, densities)): The FingerPrint as dict. """ - fp_dict = {} - fp_dict[fp[2]] = np.array([fp[0], fp[1]], dtype="object").T - - return fp_dict + return {fp[2]: np.array([fp[0], fp[1]], dtype="object").T} @staticmethod def get_dos_fp_similarity( - fp1: NamedTuple, - fp2: NamedTuple, + fp1: FingerPrint, + fp2: FingerPrint, col: int = 1, - pt: int | str = "All", + pt: int | Literal["All"] = "All", normalize: bool = False, tanimoto: bool = False, ) -> float: - """Calculate the similarity index (dot product) of two fingerprints. + """Calculate the similarity index (dot product) of two DOS FingerPrints. Args: - fp1 (NamedTuple): The 1st dos fingerprint object - fp2 (NamedTuple): The 2nd dos fingerprint object - col (int): The item in the fingerprints (0:energies,1: densities) to take the dot product of (default is 1) - pt (int or str) : The index of the point that the dot product is to be taken (default is All) - normalize (bool): If True normalize the scalar product to 1 (default is False) - tanimoto (bool): If True will compute Tanimoto index (default is False) + fp1 (FingerPrint): The 1st DOS FingerPrint. + fp2 (FingerPrint): The 2nd DOS FingerPrint. + col (int): The item in the fingerprints (0: energies, 1: densities) + to take the dot product of (default is 1). + pt (int | "ALL") : The index of the point that the dot product is + to be taken (default is All). + normalize (bool): Whether to normalize the scalar product to 1 (default is False). + tanimoto (bool): Whether to compute Tanimoto index (default is False). Raises: - ValueError: If both tanimoto and normalize are set to True. + ValueError: If both tanimoto and normalize are True. Returns: - float: Similarity index given by the dot product + float: Similarity index given by the dot product. """ - fp1_dict = CompleteDos.fp_to_dict(fp1) if not isinstance(fp1, dict) else fp1 - - fp2_dict = CompleteDos.fp_to_dict(fp2) if not isinstance(fp2, dict) else fp2 + fp1_dict = fp1 if isinstance(fp1, dict) else CompleteDos.fp_to_dict(fp1) + fp2_dict = fp2 if isinstance(fp2, dict) else CompleteDos.fp_to_dict(fp2) if pt == "All": vec1 = np.array([pt[col] for pt in fp1_dict.values()]).flatten() @@ -1231,16 +1324,13 @@ def get_dos_fp_similarity( vec1 = fp1_dict[fp1[2][pt]][col] vec2 = fp2_dict[fp2[2][pt]][col] - if not normalize and tanimoto: - rescale = np.linalg.norm(vec1) ** 2 + np.linalg.norm(vec2) ** 2 - np.dot(vec1, vec2) - return np.dot(vec1, vec2) / rescale + if not normalize: + rescale = np.linalg.norm(vec1) ** 2 + np.linalg.norm(vec2) ** 2 - np.dot(vec1, vec2) if tanimoto else 1.0 - if not tanimoto and normalize: - rescale = np.linalg.norm(vec1) * np.linalg.norm(vec2) return np.dot(vec1, vec2) / rescale - if not tanimoto and not normalize: - rescale = 1.0 + if not tanimoto: + rescale = np.linalg.norm(vec1) * np.linalg.norm(vec2) return np.dot(vec1, vec2) / rescale raise ValueError( @@ -1262,7 +1352,7 @@ def from_dict(cls, dct: dict) -> Self: pdoss[at] = orb_dos return cls(struct, tdos, pdoss) - def as_dict(self) -> dict: + def as_dict(self) -> dict[str, Any]: """JSON-serializable dict representation of CompleteDos.""" dct = { "@module": type(self).__module__, @@ -1283,86 +1373,95 @@ def as_dict(self) -> dict: dct["spd_dos"] = {str(orb): dos.as_dict() for orb, dos in self.get_spd_dos().items()} return dct - def __str__(self) -> str: - return f"Complete DOS for {self.structure}" + +_lobster_orb_labs = ( + "s", + "p_y", + "p_z", + "p_x", + "d_xy", + "d_yz", + "d_z^2", + "d_xz", + "d_x^2-y^2", + "f_y(3x^2-y^2)", + "f_xyz", + "f_yz^2", + "f_z^3", + "f_xz^2", + "f_z(x^2-y^2)", + "f_x(x^2-3y^2)", +) class LobsterCompleteDos(CompleteDos): - """Extended CompleteDOS for Lobster.""" + """Extended CompleteDos for LOBSTER.""" def get_site_orbital_dos(self, site: PeriodicSite, orbital: str) -> Dos: # type: ignore[override] - """Get the Dos for a particular orbital of a particular site. + """Get the DOS for a particular orbital of a particular site. Args: - site: Site in Structure associated with CompleteDos. - orbital: principal quantum number and orbital in string format, e.g. "4s". - possible orbitals are: "s", "p_y", "p_z", "p_x", "d_xy", "d_yz", "d_z^2", - "d_xz", "d_x^2-y^2", "f_y(3x^2-y^2)", "f_xyz", - "f_yz^2", "f_z^3", "f_xz^2", "f_z(x^2-y^2)", "f_x(x^2-3y^2)" - In contrast to the Cohpcar and the Cohplist objects, the strings from the Lobster files are used + site (PeriodicSite): Site in Structure associated with LobsterCompleteDos. + orbital (str): Principal quantum number and orbital, e.g. "4s". + Possible orbitals are: "s", "p_y", "p_z", "p_x", "d_xy", "d_yz", "d_z^2", + "d_xz", "d_x^2-y^2", "f_y(3x^2-y^2)", "f_xyz", + "f_yz^2", "f_z^3", "f_xz^2", "f_z(x^2-y^2)", "f_x(x^2-3y^2)". + In contrast to the Cohpcar and the Cohplist objects, + the strings from the LOBSTER files are used. Returns: - Dos containing densities of an orbital of a specific site. + Dos: DOS of an orbital of a specific site. """ - if orbital[1:] not in { - "s", - "p_y", - "p_z", - "p_x", - "d_xy", - "d_yz", - "d_z^2", - "d_xz", - "d_x^2-y^2", - "f_y(3x^2-y^2)", - "f_xyz", - "f_yz^2", - "f_z^3", - "f_xz^2", - "f_z(x^2-y^2)", - "f_x(x^2-3y^2)", - }: + if orbital[1:] not in _lobster_orb_labs: raise ValueError("orbital is not correct") + return Dos(self.efermi, self.energies, self.pdos[site][orbital]) # type: ignore[index] - def get_site_t2g_eg_resolved_dos(self, site: PeriodicSite) -> dict[str, Dos]: - """Get the t2g, eg projected DOS for a particular site. + def get_site_t2g_eg_resolved_dos( + self, + site: PeriodicSite, + ) -> dict[Literal["e_g", "t2g"], Dos]: + """Get the t2g/e_g projected DOS for a particular site. Args: - site: Site in Structure associated with CompleteDos. + site (PeriodicSite): Site in Structure associated with LobsterCompleteDos. Returns: - A dict {"e_g": Dos, "t2g": Dos} containing summed e_g and t2g DOS - for the site. + dict[Literal["e_g", "t2g"], Dos]: Summed e_g and t2g DOS for the site. """ warnings.warn("Are the orbitals correctly oriented? Are you sure?") + t2g_dos = [] eg_dos = [] for s, atom_dos in self.pdos.items(): if s == site: for orb, pdos in atom_dos.items(): - if _get_orb_lobster(orb) in (Orbital.dxy, Orbital.dxz, Orbital.dyz): + orbital = _get_orb_lobster(str(orb)) + assert orbital is not None + + if orbital in (Orbital.dxy, Orbital.dxz, Orbital.dyz): t2g_dos.append(pdos) - elif _get_orb_lobster(orb) in (Orbital.dx2, Orbital.dz2): + elif orbital in (Orbital.dx2, Orbital.dz2): eg_dos.append(pdos) return { "t2g": Dos(self.efermi, self.energies, functools.reduce(add_densities, t2g_dos)), "e_g": Dos(self.efermi, self.energies, functools.reduce(add_densities, eg_dos)), } - def get_spd_dos(self) -> dict[OrbitalType, Dos]: - """Get orbital projected Dos. - For example, if 3s and 4s are included in the basis of some element, they will be both summed in the orbital - projected DOS. + def get_spd_dos(self) -> dict[str, Dos]: # type: ignore[override] + """Get orbital projected DOS. + + For example, if 3s and 4s are included in the basis of some element, + they will be both summed in the orbital projected DOS. Returns: - dict of {orbital: Dos}, e.g. {"s": Dos object, ...} + {orbital: Dos} """ spd_dos = {} orb = None for atom_dos in self.pdos.values(): for orb, pdos in atom_dos.items(): - orbital_type = _get_orb_type_lobster(orb) + orbital_type = _get_orb_type_lobster(str(orb)) if orbital_type not in spd_dos: spd_dos[orbital_type] = pdos else: @@ -1370,21 +1469,21 @@ def get_spd_dos(self) -> dict[OrbitalType, Dos]: return {orb: Dos(self.efermi, self.energies, densities) for orb, densities in spd_dos.items()} # type: ignore[misc] - def get_element_spd_dos(self, el: SpeciesLike) -> dict[OrbitalType, Dos]: - """Get element and spd projected Dos. + def get_element_spd_dos(self, el: SpeciesLike) -> dict[str, Dos]: # type: ignore[override] + """Get element and s/p/d projected DOS. Args: - el: Element in Structure.composition associated with LobsterCompleteDos + el (SpeciesLike): Element associated with LobsterCompleteDos. Returns: - dict of {OrbitalType.s: densities, OrbitalType.p: densities, OrbitalType.d: densities} + dict of {OrbitalType.s: Dos, OrbitalType.p: Dos, OrbitalType.d: Dos} """ el = get_el_sp(el) el_dos = {} for site, atom_dos in self.pdos.items(): if site.specie == el: for orb, pdos in atom_dos.items(): - orbital_type = _get_orb_type_lobster(orb) + orbital_type = _get_orb_type_lobster(str(orb)) if orbital_type not in el_dos: el_dos[orbital_type] = pdos else: @@ -1394,102 +1493,86 @@ def get_element_spd_dos(self, el: SpeciesLike) -> dict[OrbitalType, Dos]: @classmethod def from_dict(cls, dct: dict) -> Self: - """Hydrate CompleteDos object from dict representation.""" + """Get LobsterCompleteDos from a dict representation.""" tdos = Dos.from_dict(dct) struct = Structure.from_dict(dct["structure"]) - pdoss = {} - for i in range(len(dct["pdos"])): - at = struct[i] - orb_dos = {} - for orb_str, odos in dct["pdos"][i].items(): - orb = orb_str - orb_dos[orb] = {Spin(int(k)): v for k, v in odos["densities"].items()} - pdoss[at] = orb_dos - return cls(struct, tdos, pdoss) + pdos = {} + for idx in range(len(dct["pdos"])): + pdos[struct[idx]] = { + orb_str: {Spin(int(k)): v for k, v in odos["densities"].items()} + for orb_str, odos in dct["pdos"][idx].items() + } + return cls(struct, tdos, pdos) -def add_densities(density1: Mapping[Spin, ArrayLike], density2: Mapping[Spin, ArrayLike]) -> dict[Spin, np.ndarray]: - """Sum two densities. +def add_densities( + density1: dict[Spin, NDArray], + density2: dict[Spin, NDArray], +) -> dict[Spin, NDArray]: + """Sum two DOS along each spin channel. Args: - density1: First density. - density2: Second density. + density1 (dict[Spin, NDArray]): First DOS. + density2 (dict[Spin, NDArray]): Second DOS. Returns: - dict[Spin, np.ndarray] + dict[Spin, NDArray] """ return {spin: np.array(density1[spin]) + np.array(density2[spin]) for spin in density1} -def _get_orb_type(orb) -> OrbitalType: +def _get_orb_type(orb: Orbital | OrbitalType) -> OrbitalType: + """Get OrbitalType.""" try: - return orb.orbital_type + return cast(Orbital, orb).orbital_type except AttributeError: - return orb + return cast(OrbitalType, orb) -def f0(E, fermi, T) -> float: +def f0(E: float, fermi: float, T: float) -> float: """Fermi-Dirac distribution function. Args: - E (float): energy in eV - fermi (float): the Fermi level in eV - T (float): the temperature in kelvin + E (float): Energy in eV. + fermi (float): The Fermi level in eV. + T (float): The temperature in kelvin. Returns: - float: the Fermi-Dirac occupation probability at energy E + float: The Fermi-Dirac occupation probability at energy E. """ - return 1.0 / (1.0 + np.exp((E - fermi) / (_cd("Boltzmann constant in eV/K") * T))) + return 1.0 / (1.0 + np.exp((E - fermi) / (_constant("Boltzmann constant in eV/K") * T))) -def _get_orb_type_lobster(orb) -> OrbitalType | None: - """ +def _get_orb_type_lobster(orb: str) -> OrbitalType | None: + """Get OrbitalType from str representation of the orbital. + Args: - orb: string representation of orbital. + orb (str): String representation of the orbital. Returns: OrbitalType """ - orb_labs = ["s", "p_y", "p_z", "p_x", "d_xy", "d_yz", "d_z^2", "d_xz", "d_x^2-y^2"] - orb_labs += ["f_y(3x^2-y^2)", "f_xyz", "f_yz^2", "f_z^3", "f_xz^2", "f_z(x^2-y^2)", "f_x(x^2-3y^2)"] - try: - orbital = Orbital(orb_labs.index(orb[1:])) + orbital = Orbital(_lobster_orb_labs.index(orb[1:])) return orbital.orbital_type + except AttributeError: print("Orb not in list") return None -def _get_orb_lobster(orb): - """ +def _get_orb_lobster(orb: str) -> Orbital | None: + """Get Orbital from str representation of the orbital. + Args: - orb: string representation of orbital. + orb (str): String representation of the orbital. Returns: Orbital. """ - orb_labs = [ - "s", - "p_y", - "p_z", - "p_x", - "d_xy", - "d_yz", - "d_z^2", - "d_xz", - "d_x^2-y^2", - "f_y(3x^2-y^2)", - "f_xyz", - "f_yz^2", - "f_z^3", - "f_xz^2", - "f_z(x^2-y^2)", - "f_x(x^2-3y^2)", - ] - try: - return Orbital(orb_labs.index(orb[1:])) + return Orbital(_lobster_orb_labs.index(orb[1:])) + except AttributeError: print("Orb not in list") - return None + return None diff --git a/src/pymatgen/electronic_structure/plotter.py b/src/pymatgen/electronic_structure/plotter.py index b11a7c264ad..d17f8710740 100644 --- a/src/pymatgen/electronic_structure/plotter.py +++ b/src/pymatgen/electronic_structure/plotter.py @@ -20,6 +20,7 @@ from matplotlib.gridspec import GridSpec from monty.dev import requires from monty.json import jsanitize + from pymatgen.core import Element from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.boltztrap import BoltztrapError @@ -36,6 +37,7 @@ from typing import Literal from numpy.typing import ArrayLike + from pymatgen.electronic_structure.dos import CompleteDos, Dos logger = logging.getLogger(__name__) @@ -49,22 +51,27 @@ class DosPlotter: - """Plot DOSs. The interface is extremely flexible given there are many + """Plot DOS. The interface is extremely flexible given there are many different ways in which people want to view DOS. Typical usage is: - # Initializes plotter with some optional args. Defaults are usually fine + # Initialize plotter with some optional args. Defaults are usually fine plotter = PhononDosPlotter(). # Add DOS with a label plotter.add_dos("Total DOS", dos) - # Alternatively, you can add a dict of DOSes. This is the typical form + # Alternatively, you can add a dict of DOS. This is the typical form # returned by CompletePhononDos.get_element_dos(). plotter.add_dos_dict({"dos1": dos1, "dos2": dos2}) plotter.add_dos_dict(complete_dos.get_spd_dos()) """ - def __init__(self, zero_at_efermi: bool = True, stack: bool = False, sigma: float | None = None) -> None: + def __init__( + self, + zero_at_efermi: bool = True, + stack: bool = False, + sigma: float | None = None, + ) -> None: """ Args: zero_at_efermi (bool): Whether to shift all Dos to have zero energy at the @@ -3920,8 +3927,7 @@ def plot_fermi_surface( By default 0 eV correspond to the VBM, as in the plot of band structure along symmetry line. Default: One surface, with max energy value + 0.01 eV - cbm (bool): Boolean value to specify if the considered band is a - conduction band or not + cbm (bool): True if the considered band is a conduction band or not. multiple_figure (bool): If True a figure for each energy level will be shown. If False all the surfaces will be shown in the same figure. In this last case, tune the transparency factor. diff --git a/src/pymatgen/entries/__init__.py b/src/pymatgen/entries/__init__.py index fd0471b8416..51e6d6cee5a 100644 --- a/src/pymatgen/entries/__init__.py +++ b/src/pymatgen/entries/__init__.py @@ -12,6 +12,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.core.composition import Composition if TYPE_CHECKING: diff --git a/src/pymatgen/entries/compatibility.py b/src/pymatgen/entries/compatibility.py index ba2ef68dd07..8f5e082d479 100644 --- a/src/pymatgen/entries/compatibility.py +++ b/src/pymatgen/entries/compatibility.py @@ -12,9 +12,13 @@ from typing import TYPE_CHECKING, Union, cast import numpy as np +from joblib import Parallel, delayed from monty.design_patterns import cached_class from monty.json import MSONable from monty.serialization import loadfn +from tqdm import tqdm +from uncertainties import ufloat + from pymatgen.analysis.structure_analyzer import oxide_type, sulfide_type from pymatgen.core import SETTINGS, Composition, Element from pymatgen.entries.computed_entries import ( @@ -27,8 +31,7 @@ ) from pymatgen.io.vasp.sets import MITRelaxSet, MPRelaxSet, VaspInputSet from pymatgen.util.due import Doi, due -from tqdm import tqdm -from uncertainties import ufloat +from pymatgen.util.joblib import set_python_warnings, tqdm_joblib if TYPE_CHECKING: from collections.abc import Sequence @@ -537,28 +540,86 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]: """ raise NotImplementedError - def process_entry(self, entry: ComputedEntry, **kwargs) -> ComputedEntry | None: + def process_entry(self, entry: ComputedEntry, inplace: bool = True, **kwargs) -> ComputedEntry | None: """Process a single entry with the chosen Corrections. Note that this method will change the data of the original entry. Args: entry: A ComputedEntry object. + inplace (bool): Whether to adjust the entry in place. Defaults to True. **kwargs: Will be passed to process_entries(). Returns: An adjusted entry if entry is compatible, else None. """ - try: - return self.process_entries(entry, **kwargs)[0] - except IndexError: + if not inplace: + entry = copy.deepcopy(entry) + + entry = self._process_entry_inplace(entry, **kwargs) + + return entry[0] if entry is not None else None + + def _process_entry_inplace( + self, + entry: AnyComputedEntry, + clean: bool = True, + on_error: Literal["ignore", "warn", "raise"] = "ignore", + ) -> ComputedEntry | None: + """Process a single entry with the chosen Corrections. Note + that this method will change the data of the original entry. + + Args: + entry: A ComputedEntry object. + clean (bool): Whether to remove any previously-applied energy adjustments. + If True, all EnergyAdjustment are removed prior to processing the Entry. + Defaults to True. + on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry) + raises CompatibilityError. Defaults to 'ignore'. + + Returns: + An adjusted entry if entry is compatible, else None. + """ + ignore_entry = False + # if clean is True, remove all previous adjustments from the entry + if clean: + entry.energy_adjustments = [] + + try: # get the energy adjustments + adjustments = self.get_adjustments(entry) + except CompatibilityError as exc: + if on_error == "raise": + raise + if on_error == "warn": + warnings.warn(str(exc)) return None + for ea in adjustments: + # Has this correction already been applied? + if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]: + # we already applied this exact correction. Do nothing. + pass + elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]: + # we already applied a correction with the same name + # but a different value. Something is wrong. + ignore_entry = True + warnings.warn( + f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its " + f"value differs from the value of {ea.value:.3f} calculated here. This " + "Entry will be discarded." + ) + else: + # Add the correction to the energy_adjustments list + entry.energy_adjustments.append(ea) + + return entry, ignore_entry + def process_entries( self, entries: AnyComputedEntry | list[AnyComputedEntry], clean: bool = True, verbose: bool = False, inplace: bool = True, + n_workers: int = 1, on_error: Literal["ignore", "warn", "raise"] = "ignore", ) -> list[AnyComputedEntry]: """Process a sequence of entries with the chosen Compatibility scheme. @@ -575,6 +636,7 @@ def process_entries( verbose (bool): Whether to display progress bar for processing multiple entries. Defaults to False. inplace (bool): Whether to adjust input entries in place. Defaults to True. + n_workers (int): Number of workers to use for parallel processing. Defaults to 1. on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry) raises CompatibilityError. Defaults to 'ignore'. @@ -592,41 +654,28 @@ def process_entries( if not inplace: entries = copy.deepcopy(entries) - for entry in tqdm(entries, disable=not verbose): - ignore_entry = False - # if clean is True, remove all previous adjustments from the entry - if clean: - entry.energy_adjustments = [] - - try: # get the energy adjustments - adjustments = self.get_adjustments(entry) - except CompatibilityError as exc: - if on_error == "raise": - raise - if on_error == "warn": - warnings.warn(str(exc)) - continue - - for ea in adjustments: - # Has this correction already been applied? - if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]: - # we already applied this exact correction. Do nothing. - pass - elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]: - # we already applied a correction with the same name - # but a different value. Something is wrong. - ignore_entry = True - warnings.warn( - f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its " - f"value differs from the value of {ea.value:.3f} calculated here. This " - "Entry will be discarded." - ) - else: - # Add the correction to the energy_adjustments list - entry.energy_adjustments.append(ea) - - if not ignore_entry: - processed_entry_list.append(entry) + if n_workers == 1: + for entry in tqdm(entries, disable=not verbose): + result = self._process_entry_inplace(entry, clean, on_error) + if result is None: + continue + entry, ignore_entry = result + if not ignore_entry: + processed_entry_list.append(entry) + elif not inplace: + # set python warnings to ignore otherwise warnings will be printed multiple times + with tqdm_joblib(tqdm(total=len(entries), disable=not verbose)), set_python_warnings("ignore"): + results = Parallel(n_jobs=n_workers)( + delayed(self._process_entry_inplace)(entry, clean, on_error) for entry in entries + ) + for result in results: + if result is None: + continue + entry, ignore_entry = result + if not ignore_entry: + processed_entry_list.append(entry) + else: + raise ValueError("Parallel processing is not possible with for 'inplace=True'") return processed_entry_list @@ -1132,7 +1181,9 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]: expected_u = float(u_settings.get(symbol, 0)) actual_u = float(calc_u.get(symbol, 0)) if actual_u != expected_u: - raise CompatibilityError(f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3}") + raise CompatibilityError( + f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3} for {entry.as_dict()}" + ) if symbol in u_corrections: adjustments.append( CompositionEnergyAdjustment( @@ -1449,6 +1500,7 @@ def process_entries( clean: bool = False, verbose: bool = False, inplace: bool = True, + n_workers: int = 1, on_error: Literal["ignore", "warn", "raise"] = "ignore", ) -> list[AnyComputedEntry]: """Process a sequence of entries with the chosen Compatibility scheme. @@ -1462,6 +1514,7 @@ def process_entries( Default is False. inplace (bool): Whether to modify the entries in place. If False, a copy of the entries is made and processed. Default is True. + n_workers (int): Number of workers to use for parallel processing. Default is 1. on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry) raises CompatibilityError. Defaults to 'ignore'. @@ -1479,7 +1532,8 @@ def process_entries( # pre-process entries with the given solid compatibility class if self.solid_compat: - entries = self.solid_compat.process_entries(entries, clean=True) + entries = self.solid_compat.process_entries(entries, clean=True, inplace=inplace, n_workers=n_workers) + return [entries] # when processing single entries, all H2 polymorphs will get assigned the # same energy @@ -1513,7 +1567,9 @@ def process_entries( h2_entries = sorted(h2_entries, key=lambda e: e.energy_per_atom) self.h2_energy = h2_entries[0].energy_per_atom # type: ignore[assignment] - return super().process_entries(entries, clean=clean, verbose=verbose, inplace=inplace, on_error=on_error) + return super().process_entries( + entries, clean=clean, verbose=verbose, inplace=inplace, n_workers=n_workers, on_error=on_error + ) def needs_u_correction( diff --git a/src/pymatgen/entries/computed_entries.py b/src/pymatgen/entries/computed_entries.py index b80a45f8b29..c1d89545828 100644 --- a/src/pymatgen/entries/computed_entries.py +++ b/src/pymatgen/entries/computed_entries.py @@ -17,18 +17,20 @@ import numpy as np from monty.json import MontyDecoder, MontyEncoder, MSONable +from scipy.interpolate import interp1d +from uncertainties import ufloat + from pymatgen.core.composition import Composition from pymatgen.entries import Entry from pymatgen.util.due import Doi, due -from scipy.interpolate import interp1d -from uncertainties import ufloat if TYPE_CHECKING: from typing import Literal - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + __author__ = "Ryan Kingsbury, Matt McDermott, Shyue Ping Ong, Anubhav Jain" __copyright__ = "Copyright 2011-2020, The Materials Project" __version__ = "1.1" @@ -506,15 +508,13 @@ def from_dict(cls, dct: dict) -> Self: def as_dict(self) -> dict: """MSONable dict.""" return_dict = super().as_dict() - return_dict.update( - { - "entry_id": self.entry_id, - "correction": self.correction, - "energy_adjustments": json.loads(json.dumps(self.energy_adjustments, cls=MontyEncoder)), - "parameters": json.loads(json.dumps(self.parameters, cls=MontyEncoder)), - "data": json.loads(json.dumps(self.data, cls=MontyEncoder)), - } - ) + return_dict |= { + "entry_id": self.entry_id, + "correction": self.correction, + "energy_adjustments": json.loads(json.dumps(self.energy_adjustments, cls=MontyEncoder)), + "parameters": json.loads(json.dumps(self.parameters, cls=MontyEncoder)), + "data": json.loads(json.dumps(self.data, cls=MontyEncoder)), + } return return_dict def __hash__(self) -> int: diff --git a/src/pymatgen/entries/correction_calculator.py b/src/pymatgen/entries/correction_calculator.py index 778c3885752..e832a0576c0 100644 --- a/src/pymatgen/entries/correction_calculator.py +++ b/src/pymatgen/entries/correction_calculator.py @@ -10,11 +10,12 @@ import numpy as np import plotly.graph_objects as go from monty.serialization import loadfn +from ruamel import yaml +from scipy.optimize import curve_fit + from pymatgen.analysis.reaction_calculator import ComputedReaction from pymatgen.analysis.structure_analyzer import sulfide_type from pymatgen.core import Composition, Element -from ruamel import yaml -from scipy.optimize import curve_fit class CorrectionCalculator: diff --git a/src/pymatgen/entries/entry_tools.py b/src/pymatgen/entries/entry_tools.py index fc6eb821891..71581cc22e9 100644 --- a/src/pymatgen/entries/entry_tools.py +++ b/src/pymatgen/entries/entry_tools.py @@ -6,16 +6,17 @@ import collections import csv -import datetime import itertools import json import logging import multiprocessing as mp import re from collections import defaultdict +from datetime import datetime, timezone from typing import TYPE_CHECKING from monty.json import MontyDecoder, MontyEncoder, MSONable + from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.structure_matcher import SpeciesComparator, StructureMatcher from pymatgen.core import Composition, Element @@ -24,9 +25,10 @@ from collections.abc import Iterable from typing import Literal + from typing_extensions import Self + from pymatgen.entries import Entry from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry - from typing_extensions import Self logger = logging.getLogger(__name__) @@ -110,7 +112,7 @@ def group_entries_by_structure( """ if comparator is None: comparator = SpeciesComparator() - start = datetime.datetime.now(tz=datetime.timezone.utc) + start = datetime.now(tz=timezone.utc) logger.info(f"Started at {start}") entries_host = [(entry, _get_host(entry.structure, species_to_remove)) for entry in entries] if ncpus: @@ -159,8 +161,8 @@ def group_entries_by_structure( entry_groups = [] for g in groups: entry_groups.append(json.loads(g, cls=MontyDecoder)) - logging.info(f"Finished at {datetime.datetime.now(tz=datetime.timezone.utc)}") - logging.info(f"Took {datetime.datetime.now(tz=datetime.timezone.utc) - start}") + logging.info(f"Finished at {datetime.now(tz=timezone.utc)}") + logging.info(f"Took {datetime.now(tz=timezone.utc) - start}") return entry_groups diff --git a/src/pymatgen/entries/exp_entries.py b/src/pymatgen/entries/exp_entries.py index 78adad7a6ef..e012ed7a9fb 100644 --- a/src/pymatgen/entries/exp_entries.py +++ b/src/pymatgen/entries/exp_entries.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from monty.json import MSONable + from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.thermochemistry import ThermoData from pymatgen.core.composition import Composition diff --git a/src/pymatgen/entries/mixing_scheme.py b/src/pymatgen/entries/mixing_scheme.py index 8a4a50fb214..8f8fb637ab5 100644 --- a/src/pymatgen/entries/mixing_scheme.py +++ b/src/pymatgen/entries/mixing_scheme.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd + from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.entries.compatibility import ( diff --git a/src/pymatgen/ext/cod.py b/src/pymatgen/ext/cod.py index 0bc9966c2dc..21983b6891f 100644 --- a/src/pymatgen/ext/cod.py +++ b/src/pymatgen/ext/cod.py @@ -34,6 +34,7 @@ import requests from monty.dev import requires + from pymatgen.core.composition import Composition from pymatgen.core.structure import Structure diff --git a/src/pymatgen/ext/matproj.py b/src/pymatgen/ext/matproj.py index 23d867529ff..25f7229a3e5 100644 --- a/src/pymatgen/ext/matproj.py +++ b/src/pymatgen/ext/matproj.py @@ -20,6 +20,7 @@ import requests from monty.json import MontyDecoder + from pymatgen.core import SETTINGS from pymatgen.core import __version__ as PMG_VERSION from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -28,10 +29,11 @@ from typing import Callable from mp_api.client import MPRester as _MPResterNew + from typing_extensions import Self + from pymatgen.core.structure import Structure from pymatgen.entries.computed_entries import ComputedStructureEntry from pymatgen.ext.matproj_legacy import _MPResterLegacy - from typing_extensions import Self logger = logging.getLogger(__name__) @@ -332,7 +334,7 @@ def get_entries_in_chemsys(self, elements, *args, **kwargs): Li, Fe and O phases. Extremely useful for creating phase diagrams of entire chemical systems. Args: - elements (str or [str]): Chemical system string comprising element + elements (str | list[str]): Chemical system string comprising element symbols separated by dashes, e.g. "Li-Fe-O" or List of element symbols, e.g. ["Li", "Fe", "O"]. *args: Pass-through to get_entries. diff --git a/src/pymatgen/ext/matproj_legacy.py b/src/pymatgen/ext/matproj_legacy.py index ff3ea075ec3..f256207d521 100644 --- a/src/pymatgen/ext/matproj_legacy.py +++ b/src/pymatgen/ext/matproj_legacy.py @@ -20,6 +20,9 @@ import requests from monty.json import MontyDecoder, MontyEncoder +from ruamel.yaml import YAML +from tqdm import tqdm + from pymatgen.core import SETTINGS, Composition, Element, Structure from pymatgen.core import __version__ as PMG_VERSION from pymatgen.core.surface import get_symmetrically_equivalent_miller_indices @@ -28,16 +31,15 @@ from pymatgen.entries.exp_entries import ExpEntry from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.due import Doi, due -from ruamel.yaml import YAML -from tqdm import tqdm if TYPE_CHECKING: from collections.abc import Sequence from typing import Any, Literal + from typing_extensions import Self + from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos - from typing_extensions import Self logger = logging.getLogger(__name__) MP_LOG_FILE = os.path.join(os.path.expanduser("~"), ".mprester.log.yaml") @@ -533,7 +535,7 @@ def get_entries( ] data = {"oxide_type": d["oxide_type"]} if property_data: - data.update({k: d[k] for k in property_data}) + data |= {k: d[k] for k in property_data} if not inc_structure: e = ComputedEntry( d["unit_cell_formula"], @@ -573,7 +575,7 @@ def get_pourbaix_entries(self, chemsys, solid_compat="MaterialsProject2020Compat a Pourbaix diagram from the rest interface. Args: - chemsys (str or [str]): Chemical system string comprising element + chemsys (str | list[str]): Chemical system string comprising element symbols separated by dashes, e.g. "Li-Fe-O" or List of element symbols, e.g. ["Li", "Fe", "O"]. solid_compat: Compatibility scheme used to pre-process solid DFT energies prior to applying aqueous @@ -779,7 +781,7 @@ def get_bandstructure_by_material_id(self, material_id, line_mode=True): (default). If False, return the uniform band structure. Returns: - A BandStructure object. + BandStructure """ prop = "bandstructure" if line_mode else "bandstructure_uniform" data = self.get_data(material_id, prop=prop) @@ -834,7 +836,7 @@ def get_entries_in_chemsys( phases. Extremely useful for creating phase diagrams of entire chemical systems. Args: - elements (str or [str]): Chemical system string comprising element + elements (str | list[str]): Chemical system string comprising element symbols separated by dashes, e.g. "Li-Fe-O" or List of element symbols, e.g. ["Li", "Fe", "O"]. compatible_only (bool): Whether to return only "compatible" @@ -999,7 +1001,7 @@ def query( progress_bar = tqdm(total=len(mids), disable=not show_progress_bar) for chunk in chunks: chunk_criteria = criteria.copy() - chunk_criteria.update({"material_id": {"$in": chunk}}) + chunk_criteria |= {"material_id": {"$in": chunk}} n_tries = 0 while n_tries < max_tries_per_chunk: try: @@ -1575,7 +1577,7 @@ def _check_nomad_exist(url) -> bool: return content["pagination"]["total"] != 0 @staticmethod - def parse_criteria(criteria_string): + def parse_criteria(criteria_string) -> dict: """Parse a powerful and simple string criteria and generates a proper mongo syntax criteria. diff --git a/src/pymatgen/ext/optimade.py b/src/pymatgen/ext/optimade.py index fa7e9fb3d3c..5221e0bc0d1 100644 --- a/src/pymatgen/ext/optimade.py +++ b/src/pymatgen/ext/optimade.py @@ -8,10 +8,11 @@ from urllib.parse import urljoin, urlparse import requests +from tqdm import tqdm + from pymatgen.core import DummySpecies, Structure from pymatgen.util.due import Doi, due from pymatgen.util.provenance import StructureNL -from tqdm import tqdm if TYPE_CHECKING: from typing import ClassVar @@ -499,7 +500,7 @@ def _parse_provider(self, provider: str, provider_url: str) -> dict[str, Provide It does not raise exceptions but will instead _logger.warning and provide an empty dictionary in the case of invalid data. - In future, when the specification is sufficiently well adopted, + In future, when the specification is sufficiently well adopted, we might be more strict here. Args: @@ -507,8 +508,7 @@ def _parse_provider(self, provider: str, provider_url: str) -> dict[str, Provide provider_url: An OPTIMADE provider URL Returns: - A dictionary of keys (in format of "provider.database") to - Provider objects. + dict: keys (in format of "provider.database") mapped to Provider objects. """ # Add trailing slash to all URLs if missing; prevents urljoin from scrubbing if urlparse(provider_url).path is not None and not provider_url.endswith("/"): diff --git a/src/pymatgen/io/abinit/abiobjects.py b/src/pymatgen/io/abinit/abiobjects.py index 46482f44b3a..76d99e289fa 100644 --- a/src/pymatgen/io/abinit/abiobjects.py +++ b/src/pymatgen/io/abinit/abiobjects.py @@ -13,6 +13,7 @@ from monty.collections import AttrDict from monty.design_patterns import singleton from monty.json import MontyDecoder, MontyEncoder, MSONable + from pymatgen.core import ArrayWithUnit, Lattice, Species, Structure, units if TYPE_CHECKING: @@ -265,10 +266,7 @@ def structure_to_abivars( # One should make sure that the orientation is preserved (see Curtarolo's settings) if geo_mode == "rprim": - dct.update( - acell=3 * [1.0], - rprim=r_prim, - ) + dct.update(acell=3 * [1.0], rprim=r_prim) elif geo_mode == "angdeg": dct.update( @@ -375,7 +373,7 @@ def to_abivars(self): def as_dict(self): """JSON-friendly dict representation of SpinMode.""" out = {k: getattr(self, k) for k in self._fields} - out.update({"@module": type(self).__module__, "@class": type(self).__name__}) + out |= {"@module": type(self).__module__, "@class": type(self).__name__} return out @classmethod @@ -384,7 +382,6 @@ def from_dict(cls, dct: dict) -> Self: return cls(**{key: dct[key] for key in dct if key in cls._fields}) -# An handy Multiton _mode_to_spin_vars = { "unpolarized": SpinMode("unpolarized", 1, 1, 1), "polarized": SpinMode("polarized", 2, 1, 2), @@ -631,13 +628,7 @@ def to_abivars(self): """Return dictionary with Abinit variables.""" abivars = self.spin_mode.to_abivars() - abivars.update( - { - "nband": self.nband, - "fband": self.fband, - "charge": self.charge, - } - ) + abivars |= {"nband": self.nband, "fband": self.fband, "charge": self.charge} if self.smearing: abivars.update(self.smearing.to_abivars()) @@ -740,15 +731,13 @@ def __init__( else: # use_symmetries and not use_time_reversal kptopt = 4 - abivars.update( - { - "ngkpt": ngkpt, - "shiftk": shiftk, - "nshiftk": len(shiftk), - "kptopt": kptopt, - "chksymbreak": chksymbreak, - } - ) + abivars |= { + "ngkpt": ngkpt, + "shiftk": shiftk, + "nshiftk": len(shiftk), + "kptopt": kptopt, + "chksymbreak": chksymbreak, + } elif mode == KSamplingModes.path: if num_kpts <= 0: @@ -756,29 +745,21 @@ def __init__( kptbounds = np.reshape(kpts, (-1, 3)) - abivars.update( - { - "ndivsm": num_kpts, - "kptbounds": kptbounds, - "kptopt": -len(kptbounds) + 1, - } - ) + abivars |= {"ndivsm": num_kpts, "kptbounds": kptbounds, "kptopt": -len(kptbounds) + 1} elif mode == KSamplingModes.automatic: kpts = np.reshape(kpts, (-1, 3)) if len(kpts) != num_kpts: raise ValueError("For Automatic mode, num_kpts must be specified.") - abivars.update( - { - "kptopt": 0, - "kpt": kpts, - "nkpt": num_kpts, - "kptnrm": np.ones(num_kpts), - "wtk": kpts_weights, # for iscf/=-2, wtk. - "chksymbreak": chksymbreak, - } - ) + abivars |= { + "kptopt": 0, + "kpt": kpts, + "nkpt": num_kpts, + "kptnrm": np.ones(num_kpts), + "wtk": kpts_weights, # for iscf/=-2, wtk. + "chksymbreak": chksymbreak, + } else: raise ValueError(f"Unknown {mode=}") @@ -1115,11 +1096,7 @@ def to_abivars(self): # Atom relaxation. if self.move_atoms: - out_vars.update( - { - "tolmxf": self.abivars.tolmxf, - } - ) + out_vars["tolmxf"] = self.abivars.tolmxf if self.abivars.atoms_constraints: # Add input variables for constrained relaxation. diff --git a/src/pymatgen/io/abinit/abitimer.py b/src/pymatgen/io/abinit/abitimer.py index ceb86f03207..036d8ef1911 100644 --- a/src/pymatgen/io/abinit/abitimer.py +++ b/src/pymatgen/io/abinit/abitimer.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd from matplotlib.gridspec import GridSpec + from pymatgen.io.core import ParseError from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig diff --git a/src/pymatgen/io/abinit/inputs.py b/src/pymatgen/io/abinit/inputs.py index 1382f819d7a..180fabff915 100644 --- a/src/pymatgen/io/abinit/inputs.py +++ b/src/pymatgen/io/abinit/inputs.py @@ -18,6 +18,7 @@ import numpy as np from monty.collections import AttrDict from monty.json import MSONable + from pymatgen.core.structure import Structure from pymatgen.io.abinit import abiobjects as aobj from pymatgen.io.abinit.pseudos import Pseudo, PseudoTable diff --git a/src/pymatgen/io/abinit/netcdf.py b/src/pymatgen/io/abinit/netcdf.py index dac91a9dcfa..5e4dfd75f84 100644 --- a/src/pymatgen/io/abinit/netcdf.py +++ b/src/pymatgen/io/abinit/netcdf.py @@ -12,6 +12,7 @@ from monty.dev import requires from monty.functools import lazy_property from monty.string import marquee + from pymatgen.core.structure import Structure from pymatgen.core.units import ArrayWithUnit from pymatgen.core.xcfunc import XcFunc diff --git a/src/pymatgen/io/abinit/pseudos.py b/src/pymatgen/io/abinit/pseudos.py index df6c2f9264e..345f6dc652c 100644 --- a/src/pymatgen/io/abinit/pseudos.py +++ b/src/pymatgen/io/abinit/pseudos.py @@ -24,11 +24,12 @@ from monty.itertools import iterator_from_slice from monty.json import MontyDecoder, MSONable from monty.os.path import find_exts +from tabulate import tabulate + from pymatgen.core import Element from pymatgen.core.xcfunc import XcFunc from pymatgen.io.core import ParseError from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig -from tabulate import tabulate if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -36,9 +37,10 @@ import matplotlib.pyplot as plt from numpy.typing import NDArray - from pymatgen.core import Structure from typing_extensions import Self + from pymatgen.core import Structure + logger = logging.getLogger(__name__) diff --git a/src/pymatgen/io/adf.py b/src/pymatgen/io/adf.py index 59a3d8bfabc..865493b697f 100644 --- a/src/pymatgen/io/adf.py +++ b/src/pymatgen/io/adf.py @@ -10,6 +10,7 @@ from monty.itertools import chunks from monty.json import MSONable from monty.serialization import zopen + from pymatgen.core.structure import Molecule if TYPE_CHECKING: diff --git a/src/pymatgen/io/aims/inputs.py b/src/pymatgen/io/aims/inputs.py index e03d0355f75..9ca84366b28 100644 --- a/src/pymatgen/io/aims/inputs.py +++ b/src/pymatgen/io/aims/inputs.py @@ -17,15 +17,17 @@ from monty.io import zopen from monty.json import MontyDecoder, MSONable from monty.os.path import zpath + from pymatgen.core import SETTINGS, Element, Lattice, Molecule, Structure if TYPE_CHECKING: from collections.abc import Sequence from typing import Any - from pymatgen.util.typing import Tuple3Floats, Tuple3Ints from typing_extensions import Self + from pymatgen.util.typing import Tuple3Floats, Tuple3Ints + __author__ = "Thomas A. R. Purcell" __version__ = "1.0" __email__ = "purcellt@arizona.edu" diff --git a/src/pymatgen/io/aims/outputs.py b/src/pymatgen/io/aims/outputs.py index f6012472ac5..45f680f6d84 100644 --- a/src/pymatgen/io/aims/outputs.py +++ b/src/pymatgen/io/aims/outputs.py @@ -6,6 +6,7 @@ import numpy as np from monty.json import MontyDecoder, MSONable + from pymatgen.io.aims.parsers import ( read_aims_header_info, read_aims_header_info_from_content, @@ -18,9 +19,10 @@ from pathlib import Path from typing import Any + from typing_extensions import Self + from pymatgen.core import Molecule, Structure from pymatgen.util.typing import Matrix3D, Vector3D - from typing_extensions import Self __author__ = "Andrey Sobolev and Thomas A. R. Purcell" __version__ = "1.0" diff --git a/src/pymatgen/io/aims/parsers.py b/src/pymatgen/io/aims/parsers.py index 0a70723e57e..288d735682a 100644 --- a/src/pymatgen/io/aims/parsers.py +++ b/src/pymatgen/io/aims/parsers.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, cast import numpy as np + from pymatgen.core import Lattice, Molecule, Structure from pymatgen.core.tensors import Tensor from pymatgen.util.typing import Tuple3Floats @@ -319,24 +320,14 @@ def _parse_k_points(self) -> None: """Parse the list of k-points used in the calculation.""" n_kpts = self.parse_scalar("n_kpts") if n_kpts is None: - self._cache.update( - { - "k_points": None, - "k_point_weights": None, - } - ) + self._cache |= {"k_points": None, "k_point_weights": None} return n_kpts = int(n_kpts) line_start = self.reverse_search_for(["| K-points in task"]) line_end = self.reverse_search_for(["| k-point:"]) if LINE_NOT_FOUND in {line_start, line_end} or (line_end - line_start != n_kpts): - self._cache.update( - { - "k_points": None, - "k_point_weights": None, - } - ) + self._cache |= {"k_points": None, "k_point_weights": None} return k_points = np.zeros((n_kpts, 3)) @@ -345,12 +336,7 @@ def _parse_k_points(self) -> None: k_points[kk] = [float(inp) for inp in line.split()[4:7]] k_point_weights[kk] = float(line.split()[-1]) - self._cache.update( - { - "k_points": k_points, - "k_point_weights": k_point_weights, - } - ) + self._cache |= {"k_points": k_points, "k_point_weights": k_point_weights} @property def n_atoms(self) -> int: @@ -611,7 +597,7 @@ def lattice(self) -> Lattice: return self._cache["lattice"] @property - def forces(self) -> np.array[Vector3D] | None: + def forces(self) -> np.ndarray | None: """The forces from the aims.out file.""" line_start = self.reverse_search_for(["Total atomic forces"]) if line_start == LINE_NOT_FOUND: @@ -624,8 +610,8 @@ def forces(self) -> np.array[Vector3D] | None: ) @property - def stresses(self) -> np.array[Matrix3D] | None: - """The stresses from the aims.out file and convert to kbar.""" + def stresses(self) -> np.ndarray | None: + """The stresses from the aims.out file and convert to kBar.""" line_start = self.reverse_search_for(["Per atom stress (eV) used for heat flux calculation"]) if line_start == LINE_NOT_FOUND: return None @@ -639,12 +625,9 @@ def stresses(self) -> np.array[Matrix3D] | None: @property def stress(self) -> Matrix3D | None: - """The stress from the aims.out file and convert to kbar.""" + """The stress from the aims.out file and convert to kBar.""" line_start = self.reverse_search_for( - [ - "Analytical stress tensor - Symmetrized", - "Numerical stress tensor", - ] + ["Analytical stress tensor - Symmetrized", "Numerical stress tensor"] ) # Offset to relevant lines if line_start == LINE_NOT_FOUND: return None @@ -738,14 +721,12 @@ def _parse_hirshfeld( """Parse the Hirshfled charges volumes, and dipole moments.""" line_start = self.reverse_search_for(["Performing Hirshfeld analysis of fragment charges and moments."]) if line_start == LINE_NOT_FOUND: - self._cache.update( - { - "hirshfeld_charges": None, - "hirshfeld_volumes": None, - "hirshfeld_atomic_dipoles": None, - "hirshfeld_dipole": None, - } - ) + self._cache |= { + "hirshfeld_charges": None, + "hirshfeld_volumes": None, + "hirshfeld_atomic_dipoles": None, + "hirshfeld_dipole": None, + } return line_inds = self.search_for_all("Hirshfeld charge", line_start, -1) @@ -767,14 +748,12 @@ def _parse_hirshfeld( else: hirshfeld_dipole = None - self._cache.update( - { - "hirshfeld_charges": hirshfeld_charges, - "hirshfeld_volumes": hirshfeld_volumes, - "hirshfeld_atomic_dipoles": hirshfeld_atomic_dipoles, - "hirshfeld_dipole": hirshfeld_dipole, - } - ) + self._cache |= { + "hirshfeld_charges": hirshfeld_charges, + "hirshfeld_volumes": hirshfeld_volumes, + "hirshfeld_atomic_dipoles": hirshfeld_atomic_dipoles, + "hirshfeld_dipole": hirshfeld_dipole, + } @property def structure(self) -> Structure | Molecule: diff --git a/src/pymatgen/io/aims/sets/base.py b/src/pymatgen/io/aims/sets/base.py index c15385d9a2e..5f266a21a0c 100644 --- a/src/pymatgen/io/aims/sets/base.py +++ b/src/pymatgen/io/aims/sets/base.py @@ -5,20 +5,20 @@ import copy import json import logging -from collections.abc import Iterable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from warnings import warn import numpy as np from monty.json import MontyDecoder, MontyEncoder + from pymatgen.core import Molecule, Structure from pymatgen.io.aims.inputs import AimsControlIn, AimsGeometryIn from pymatgen.io.aims.parsers import AimsParseError, read_aims_output from pymatgen.io.core import InputFile, InputGenerator, InputSet if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from pymatgen.util.typing import PathLike @@ -233,7 +233,7 @@ def _read_previous( prev_dir (str or Path): The previous directory for the calculation """ prev_structure: Structure | Molecule | None = None - prev_parameters = {} + prev_params = {} prev_results: dict[str, Any] = {} if prev_dir: @@ -242,7 +242,7 @@ def _read_previous( # jobflow_remote) split_prev_dir = str(prev_dir).split(":")[-1] with open(f"{split_prev_dir}/parameters.json") as param_file: - prev_parameters = json.load(param_file, cls=MontyDecoder) + prev_params = json.load(param_file, cls=MontyDecoder) try: aims_output: Sequence[Structure | Molecule] = read_aims_output( @@ -255,7 +255,7 @@ def _read_previous( except (IndexError, AimsParseError): pass - return prev_structure, prev_parameters, prev_results + return prev_structure, prev_params, prev_results @staticmethod def _get_properties( @@ -307,12 +307,9 @@ def _get_input_parameters( Returns: dict: The input object """ - # Get the default configuration - # FHI-aims recommends using their defaults so bare-bones default parameters - parameters: dict[str, Any] = { - "xc": "pbe", - "relativistic": "atomic_zora scalar", - } + # Get the default config + # FHI-aims recommends using their defaults so bare-bones default params + params: dict[str, Any] = {"xc": "pbe", "relativistic": "atomic_zora scalar"} # Override default parameters with previous parameters prev_parameters = {} if prev_parameters is None else copy.deepcopy(prev_parameters) @@ -326,25 +323,25 @@ def _get_input_parameters( kpt_settings["density"] = density parameter_updates = self.get_parameter_updates(structure, prev_parameters) - parameters = recursive_update(parameters, parameter_updates) + params = recursive_update(params, parameter_updates) # Override default parameters with user_params - parameters = recursive_update(parameters, self.user_params) - if ("k_grid" in parameters) and ("density" in kpt_settings): + params = recursive_update(params, self.user_params) + if ("k_grid" in params) and ("density" in kpt_settings): warn( "WARNING: the k_grid is set in user_params and in the kpt_settings," " using the one passed in user_params.", stacklevel=1, ) - elif isinstance(structure, Structure) and ("k_grid" not in parameters): + elif isinstance(structure, Structure) and ("k_grid" not in params): density = kpt_settings.get("density", 5.0) even = kpt_settings.get("even", True) - parameters["k_grid"] = self.d2k(structure, density, even) - elif isinstance(structure, Molecule) and "k_grid" in parameters: + params["k_grid"] = self.d2k(structure, density, even) + elif isinstance(structure, Molecule) and "k_grid" in params: warn("WARNING: removing unnecessary k_grid information", stacklevel=1) - del parameters["k_grid"] + del params["k_grid"] - return parameters + return params def get_parameter_updates( self, @@ -365,7 +362,7 @@ def get_parameter_updates( def d2k( self, structure: Structure, - kptdensity: float | list[float] = 5.0, + kpt_density: float | tuple[float, float, float] = 5.0, even: bool = True, ) -> Iterable[float]: """Convert k-point density to Monkhorst-Pack grid size. @@ -375,15 +372,15 @@ def d2k( Args: structure (Structure): Contains unit cell and information about boundary conditions. - kptdensity (float | list[float]): Required k-point + kpt_density (float | list[float]): Required k-point density. Default value is 5.0 point per Ang^-1. even (bool): Round up to even numbers. Returns: dict: Monkhorst-Pack grid size in all directions """ - recipcell = structure.lattice.inv_matrix - return self.d2k_recipcell(recipcell, structure.lattice.pbc, kptdensity, even) + recip_cell = structure.lattice.inv_matrix.transpose() + return self.d2k_recip_cell(recip_cell, structure.lattice.pbc, kpt_density, even) def k2d(self, structure: Structure, k_grid: np.ndarray[int]): """Generate the kpoint density in each direction from given k_grid. @@ -397,36 +394,36 @@ def k2d(self, structure: Structure, k_grid: np.ndarray[int]): Returns: dict: Density of kpoints in each direction. result.mean() computes average density """ - recipcell = structure.lattice.inv_matrix - densities = k_grid / (2 * np.pi * np.sqrt((recipcell**2).sum(axis=1))) + recip_cell = structure.lattice.inv_matrix.transpose() + densities = k_grid / (2 * np.pi * np.sqrt((recip_cell**2).sum(axis=1))) return np.array(densities) @staticmethod - def d2k_recipcell( - recipcell: np.ndarray, + def d2k_recip_cell( + recip_cell: np.ndarray, pbc: Sequence[bool], - kptdensity: float | Sequence[float] = 5.0, + kpt_density: float | tuple[float, float, float] = 5.0, even: bool = True, ) -> Sequence[int]: """Convert k-point density to Monkhorst-Pack grid size. Args: - recipcell (Cell): The reciprocal cell + recip_cell (Cell): The reciprocal cell pbc (Sequence[bool]): If element of pbc is True then system is periodic in that direction - kptdensity (float or list[floats]): Required k-point - density. Default value is 3.5 point per Ang^-1. + kpt_density (float or list[floats]): Required k-point + density. Default value is 5 points per Ang^-1. even(bool): Round up to even numbers. Returns: dict: Monkhorst-Pack grid size in all directions """ - if not isinstance(kptdensity, Iterable): - kptdensity = 3 * [float(kptdensity)] + if isinstance(kpt_density, float): + kpt_density = (kpt_density, kpt_density, kpt_density) kpts: list[int] = [] for i in range(3): if pbc[i]: - k = 2 * np.pi * np.sqrt((recipcell[i] ** 2).sum()) * float(kptdensity[i]) + k = 2 * np.pi * np.sqrt((recip_cell[i] ** 2).sum()) * float(kpt_density[i]) if even: kpts.append(2 * int(np.ceil(k / 2))) else: diff --git a/src/pymatgen/io/ase.py b/src/pymatgen/io/ase.py index bae2dbc5a62..7058a9d2eee 100644 --- a/src/pymatgen/io/ase.py +++ b/src/pymatgen/io/ase.py @@ -12,6 +12,7 @@ import numpy as np from monty.json import MontyDecoder, MSONable, jsanitize + from pymatgen.core.structure import Molecule, Structure try: @@ -35,9 +36,10 @@ def __init__(self, *args, **kwargs): from typing import Any from numpy.typing import ArrayLike - from pymatgen.core.structure import SiteCollection from typing_extensions import Self + from pymatgen.core.structure import SiteCollection + __author__ = "Shyue Ping Ong, Andrew S. Rosen" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.0" diff --git a/src/pymatgen/io/atat.py b/src/pymatgen/io/atat.py index 98660051d6b..7badcab828c 100644 --- a/src/pymatgen/io/atat.py +++ b/src/pymatgen/io/atat.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np + from pymatgen.core import Lattice, Structure, get_el_sp __author__ = "Matthew Horton" diff --git a/src/pymatgen/io/babel.py b/src/pymatgen/io/babel.py index a648a783831..0d0e19871c4 100644 --- a/src/pymatgen/io/babel.py +++ b/src/pymatgen/io/babel.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING from monty.dev import requires + from pymatgen.core.structure import IMolecule, Molecule try: @@ -19,9 +20,10 @@ openbabel = pybel = None if TYPE_CHECKING: - from pymatgen.analysis.graphs import MoleculeGraph from typing_extensions import Self + from pymatgen.analysis.graphs import MoleculeGraph + __author__ = "Shyue Ping Ong, Qi Wang" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/src/pymatgen/io/cif.py b/src/pymatgen/io/cif.py index 5020d66d83b..6b0a2453a71 100644 --- a/src/pymatgen/io/cif.py +++ b/src/pymatgen/io/cif.py @@ -19,6 +19,7 @@ from monty.dev import deprecated from monty.io import zopen from monty.serialization import loadfn + from pymatgen.core import Composition, DummySpecies, Element, Lattice, PeriodicSite, Species, Structure, get_el_sp from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.electronic_structure.core import Magmom @@ -32,9 +33,10 @@ from typing import Any from numpy.typing import NDArray - from pymatgen.util.typing import PathLike, Vector3D from typing_extensions import Self + from pymatgen.util.typing import PathLike, Vector3D + __author__ = "Shyue Ping Ong, Will Richards, Matthew Horton" diff --git a/src/pymatgen/io/common.py b/src/pymatgen/io/common.py index 0949bf08306..00a0cac7932 100644 --- a/src/pymatgen/io/common.py +++ b/src/pymatgen/io/common.py @@ -11,10 +11,11 @@ import numpy as np from monty.io import zopen from monty.json import MSONable +from scipy.interpolate import RegularGridInterpolator + from pymatgen.core import Element, Site, Structure from pymatgen.core.units import ang_to_bohr, bohr_to_angstrom from pymatgen.electronic_structure.core import Spin -from scipy.interpolate import RegularGridInterpolator if TYPE_CHECKING: from pathlib import Path diff --git a/src/pymatgen/io/core.py b/src/pymatgen/io/core.py index c4c944706eb..acdbea51b96 100644 --- a/src/pymatgen/io/core.py +++ b/src/pymatgen/io/core.py @@ -37,9 +37,10 @@ from monty.json import MSONable if TYPE_CHECKING: - from pymatgen.util.typing import PathLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + __author__ = "Ryan Kingsbury" __email__ = "RKingsbury@lbl.gov" @@ -176,6 +177,14 @@ def __setitem__(self, key: PathLike, value: str | InputFile) -> None: def __delitem__(self, key: PathLike) -> None: del self.inputs[key] + def __or__(self, other: dict | Self) -> Self: + # enable dict merge operator | for InputSet + if isinstance(other, dict): + other = type(self)(other) + if not isinstance(other, type(self)): + return NotImplemented + return type(self)({**self.inputs, **other.inputs}, **self._kwargs) + def write_input( self, directory: PathLike, diff --git a/src/pymatgen/io/cp2k/inputs.py b/src/pymatgen/io/cp2k/inputs.py index 83054d4380b..ae0c020c2a2 100644 --- a/src/pymatgen/io/cp2k/inputs.py +++ b/src/pymatgen/io/cp2k/inputs.py @@ -37,6 +37,7 @@ from monty.dev import deprecated from monty.io import zopen from monty.json import MSONable + from pymatgen.core import Element from pymatgen.io.cp2k.utils import chunk, postprocessor, preprocessor from pymatgen.io.vasp.inputs import Kpoints as VaspKpoints @@ -47,10 +48,11 @@ from collections.abc import Sequence from typing import Any, Literal + from typing_extensions import Self + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Molecule, Structure from pymatgen.util.typing import Kpoint, Tuple3Ints - from typing_extensions import Self __author__ = "Nicholas Winner" __version__ = "2.0" @@ -332,6 +334,10 @@ def __add__(self, other) -> Section: def __setitem__(self, key, value): self.setitem(key, value) + # dict merge + def __or__(self, other: dict) -> Section: + return self.update(other) + def setitem(self, key, value, strict=False): """ Helper function for setting items. Kept separate from the double-underscore function so that @@ -416,7 +422,7 @@ def get_keyword(self, d, default=None): return v return default - def update(self, dct: dict, strict=False): + def update(self, dct: dict, strict=False) -> Section: """ Update the Section according to a dictionary argument. This is most useful for providing user-override settings to default parameters. As you pass a @@ -441,6 +447,7 @@ def update(self, dct: dict, strict=False): new sections and keywords. Default: False """ Section._update(self, dct, strict=strict) + return self @staticmethod def _update(d1, d2, strict=False): @@ -1971,18 +1978,16 @@ def __init__( else: raise ValueError("No k-points provided!") - keywords.update( - { - "SCHEME": Keyword("SCHEME", scheme), - "EPS_GEO": Keyword("EPS_GEO", eps_geo), - "FULL_GRID": Keyword("FULL_GRID", full_grid), - "PARALLEL_GROUP_SIZE": Keyword("PARALLEL_GROUP_SIZE", parallel_group_size), - "SYMMETRY": Keyword("SYMMETRY", symmetry), - "UNITS": Keyword("UNITS", units), - "VERBOSE": Keyword("VERBOSE", verbose), - "WAVEFUNCTIONS": Keyword("WAVEFUNCTIONS", wavefunctions), - } - ) + keywords |= { + "SCHEME": Keyword("SCHEME", scheme), + "EPS_GEO": Keyword("EPS_GEO", eps_geo), + "FULL_GRID": Keyword("FULL_GRID", full_grid), + "PARALLEL_GROUP_SIZE": Keyword("PARALLEL_GROUP_SIZE", parallel_group_size), + "SYMMETRY": Keyword("SYMMETRY", symmetry), + "UNITS": Keyword("UNITS", units), + "VERBOSE": Keyword("VERBOSE", verbose), + "WAVEFUNCTIONS": Keyword("WAVEFUNCTIONS", wavefunctions), + } super().__init__( name="KPOINTS", diff --git a/src/pymatgen/io/cp2k/outputs.py b/src/pymatgen/io/cp2k/outputs.py index 0fecc2bd98b..19c63761eba 100644 --- a/src/pymatgen/io/cp2k/outputs.py +++ b/src/pymatgen/io/cp2k/outputs.py @@ -17,6 +17,7 @@ from monty.io import zopen from monty.json import MSONable, jsanitize from monty.re import regrep + from pymatgen.core.structure import Molecule, Structure from pymatgen.core.units import Ha_to_eV from pymatgen.electronic_structure.bandstructure import BandStructure, BandStructureSymmLine diff --git a/src/pymatgen/io/cp2k/sets.py b/src/pymatgen/io/cp2k/sets.py index 9af72d01594..16a79842efa 100644 --- a/src/pymatgen/io/cp2k/sets.py +++ b/src/pymatgen/io/cp2k/sets.py @@ -24,6 +24,8 @@ import warnings import numpy as np +from ruamel.yaml import YAML + from pymatgen.core import SETTINGS from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Element, Molecule, Structure @@ -65,7 +67,6 @@ from pymatgen.io.cp2k.utils import get_truncated_coulomb_cutoff, get_unique_site_indices from pymatgen.io.vasp.inputs import Kpoints as VaspKpoints from pymatgen.io.vasp.inputs import KpointsSupportedModes -from ruamel.yaml import YAML __author__ = "Nicholas Winner" __version__ = "2.0" @@ -828,7 +829,7 @@ def activate_hybrid( " distance between atoms. I hope you know what you're doing." ) - ip_keywords = {} + ip_keywords: dict[str, Keyword] = {} if hybrid_functional == "HSE06": pbe = PBE("ORIG", scale_c=1, scale_x=0) xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) @@ -845,13 +846,11 @@ def activate_hybrid( }, ) ) - ip_keywords.update( - { - "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), - "OMEGA": Keyword("OMEGA", 0.11), - "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), - } - ) + ip_keywords |= { + "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), + "OMEGA": Keyword("OMEGA", 0.11), + "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), + } elif hybrid_functional == "PBE0": pbe = PBE("ORIG", scale_c=1, scale_x=1 - hf_fraction) xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) @@ -884,16 +883,14 @@ def activate_hybrid( potential_type = potential_type or "MIX_CL_TRUNC" hf_fraction = 1 - ip_keywords.update( - { - "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), - "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), - "T_C_G_DATA": Keyword("T_C_G_DATA", "t_c_g.dat"), - "OMEGA": Keyword("OMEGA", omega), - "SCALE_COULOMB": Keyword("SCALE_COULOMB", scale_coulomb), - "SCALE_LONGRANGE": Keyword("SCALE_LONGRANGE", scale_longrange - scale_coulomb), - } - ) + ip_keywords |= { + "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), + "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), + "T_C_G_DATA": Keyword("T_C_G_DATA", "t_c_g.dat"), + "OMEGA": Keyword("OMEGA", omega), + "SCALE_COULOMB": Keyword("SCALE_COULOMB", scale_coulomb), + "SCALE_LONGRANGE": Keyword("SCALE_LONGRANGE", scale_longrange - scale_coulomb), + } xc_functional.insert( Section( "XWPBE", @@ -923,17 +920,15 @@ def activate_hybrid( pbe = PBE("ORIG", scale_c=gga_c_fraction, scale_x=gga_x_fraction) xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) - ip_keywords.update( - { - "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), - "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), - "T_C_G_DATA": Keyword("T_C_G_DATA", "t_c_g.dat"), - "SCALE_COULOMB": Keyword("SCALE_COULOMB", scale_coulomb), - "SCALE_GAUSSIAN": Keyword("SCALE_GAUSSIAN", scale_gaussian), - "SCALE_LONGRANGE": Keyword("SCALE_LONGRANGE", scale_longrange), - "OMEGA": Keyword("OMEGA", omega), - } - ) + ip_keywords |= { + "POTENTIAL_TYPE": Keyword("POTENTIAL_TYPE", potential_type), + "CUTOFF_RADIUS": Keyword("CUTOFF_RADIUS", cutoff_radius), + "T_C_G_DATA": Keyword("T_C_G_DATA", "t_c_g.dat"), + "SCALE_COULOMB": Keyword("SCALE_COULOMB", scale_coulomb), + "SCALE_GAUSSIAN": Keyword("SCALE_GAUSSIAN", scale_gaussian), + "SCALE_LONGRANGE": Keyword("SCALE_LONGRANGE", scale_longrange), + "OMEGA": Keyword("OMEGA", omega), + } interaction_potential = Section("INTERACTION_POTENTIAL", subsections={}, keywords=ip_keywords) @@ -1097,14 +1092,14 @@ def activate_epr(self, **kwargs) -> None: if not self.check("force_eval/properties/linres/localize"): self.activate_localize() self["FORCE_EVAL"]["PROPERTIES"]["LINRES"].insert(Section("EPR", **kwargs)) - self["FORCE_EVAL"]["PROPERTIES"]["LINRES"]["EPR"].update({"PRINT": {"G_TENSOR": {}}}) + self["FORCE_EVAL"]["PROPERTIES"]["LINRES"]["EPR"] |= {"PRINT": {"G_TENSOR": {}}} def activate_nmr(self, **kwargs) -> None: """Calculate nmr shifts. Requires localize. Suggested with GAPW.""" if not self.check("force_eval/properties/linres/localize"): self.activate_localize() self["FORCE_EVAL"]["PROPERTIES"]["LINRES"].insert(Section("NMR", **kwargs)) - self["FORCE_EVAL"]["PROPERTIES"]["LINRES"]["NMR"].update({"PRINT": {"CHI_TENSOR": {}, "SHIELDING_TENSOR": {}}}) + self["FORCE_EVAL"]["PROPERTIES"]["LINRES"]["NMR"] |= {"PRINT": {"CHI_TENSOR": {}, "SHIELDING_TENSOR": {}}} def activate_spinspin(self, **kwargs) -> None: """Calculate spin-spin coupling tensor. Requires localize.""" @@ -1185,7 +1180,7 @@ def activate_fast_minimization(self, on) -> None: algorithm="IRAC", linesearch="2PNT", ) - self.update({"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}}) + self |= {"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}} # type: ignore[assignment] def activate_robust_minimization(self) -> None: """Modify the set to use more robust SCF minimization technique.""" @@ -1195,7 +1190,7 @@ def activate_robust_minimization(self) -> None: algorithm="STRICT", linesearch="3PNT", ) - self.update({"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}}) + self |= {"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}} # type: ignore[assignment] def activate_very_strict_minimization(self) -> None: """Method to modify the set to use very strict SCF minimization scheme.""" @@ -1205,7 +1200,7 @@ def activate_very_strict_minimization(self) -> None: algorithm="STRICT", linesearch="GOLD", ) - self.update({"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}}) + self |= {"FORCE_EVAL": {"DFT": {"SCF": {"OT": ot}}}} # type: ignore[assignment] def activate_nonperiodic(self, solver="ANALYTIC") -> None: """ diff --git a/src/pymatgen/io/cssr.py b/src/pymatgen/io/cssr.py index 52d02996815..112292cf6b8 100644 --- a/src/pymatgen/io/cssr.py +++ b/src/pymatgen/io/cssr.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from monty.io import zopen + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure @@ -31,7 +32,7 @@ class Cssr: def __init__(self, structure: Structure): """ Args: - structure (Structure/IStructure): A structure to create the Cssr object. + structure (Structure | IStructure): A structure to create the Cssr object. """ if not structure.is_ordered: raise ValueError("Cssr file can only be constructed from ordered structure") diff --git a/src/pymatgen/io/exciting/inputs.py b/src/pymatgen/io/exciting/inputs.py index 2a61fa77184..13f280c8fa0 100644 --- a/src/pymatgen/io/exciting/inputs.py +++ b/src/pymatgen/io/exciting/inputs.py @@ -10,6 +10,7 @@ import scipy.constants as const from monty.io import zopen from monty.json import MSONable + from pymatgen.core import Element, Lattice, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.symmetry.bandstructure import HighSymmKpath diff --git a/src/pymatgen/io/feff/inputs.py b/src/pymatgen/io/feff/inputs.py index 52e95a14521..f4316fe121c 100644 --- a/src/pymatgen/io/feff/inputs.py +++ b/src/pymatgen/io/feff/inputs.py @@ -15,13 +15,14 @@ import numpy as np from monty.io import zopen from monty.json import MSONable +from tabulate import tabulate + from pymatgen.core import Element, Lattice, Molecule, Structure from pymatgen.io.cif import CifParser from pymatgen.io.core import ParseError from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.io_utils import clean_lines from pymatgen.util.string import str_delimited -from tabulate import tabulate if TYPE_CHECKING: from typing_extensions import Self @@ -374,7 +375,7 @@ def __init__(self, struct, absorbing_atom, radius): """ Args: struct (Structure): input structure - absorbing_atom (str/int): Symbol for absorbing atom or site index + absorbing_atom (str | int): Symbol for absorbing atom or site index radius (float): radius of the atom cluster in Angstroms. """ if not struct.is_ordered: @@ -777,7 +778,7 @@ def __init__(self, struct, absorbing_atom): """ Args: struct (Structure): Structure object. - absorbing_atom (str/int): Absorbing atom symbol or site index. + absorbing_atom (str | int): Absorbing atom symbol or site index. """ if not struct.is_ordered: raise ValueError("Structure with partial occupancies cannot be converted into atomic coordinates!") @@ -983,7 +984,7 @@ def get_absorbing_atom_symbol_index(absorbing_atom, structure): """Get the absorbing atom symbol and site index in the given structure. Args: - absorbing_atom (str/int): symbol or site index + absorbing_atom (str | int): symbol or site index structure (Structure) Returns: diff --git a/src/pymatgen/io/feff/outputs.py b/src/pymatgen/io/feff/outputs.py index 4337ee78741..9d5d00137d9 100644 --- a/src/pymatgen/io/feff/outputs.py +++ b/src/pymatgen/io/feff/outputs.py @@ -13,6 +13,7 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.core import Element from pymatgen.electronic_structure.core import Orbital, Spin from pymatgen.electronic_structure.dos import CompleteDos, Dos @@ -280,7 +281,7 @@ def __init__(self, header, parameters, absorbing_atom, data): Args: header: Header object parameters: Tags object - absorbing_atom (str/int): absorbing atom symbol or index + absorbing_atom (str | int): absorbing atom symbol or index data (numpy.ndarray, Nx6): cross_sections. """ self.header = header @@ -379,12 +380,12 @@ def as_dict(self): class Eels(MSONable): - """Parse'eels.dat' file.""" + """Parse eels.dat file.""" def __init__(self, data): """ Args: - data (): Eels data. + data (numpy.ndarray): data from eels.dat file """ self.data = np.array(data) diff --git a/src/pymatgen/io/feff/sets.py b/src/pymatgen/io/feff/sets.py index 4231f1772ad..83a3945fd65 100644 --- a/src/pymatgen/io/feff/sets.py +++ b/src/pymatgen/io/feff/sets.py @@ -20,6 +20,7 @@ from monty.json import MSONable from monty.os.path import zpath from monty.serialization import loadfn + from pymatgen.core.structure import Molecule, Structure from pymatgen.io.feff.inputs import Atoms, Header, Potential, Tags @@ -80,7 +81,7 @@ def all_input(self): dct = {"HEADER": self.header(), "PARAMETERS": self.tags} if "RECIPROCAL" not in self.tags: - dct.update({"POTENTIALS": self.potential, "ATOMS": self.atoms}) + dct |= {"POTENTIALS": self.potential, "ATOMS": self.atoms} return dct @@ -133,7 +134,7 @@ def __init__( ): """ Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure: Structure or Molecule object. If a Structure, SpaceGroupAnalyzer is used to determine symmetrically-equivalent sites. If a Molecule, there is no symmetry checking. @@ -365,7 +366,7 @@ def __init__( ): r""" Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure (Structure): input edge (str): absorption edge radius (float): cluster radius in Angstroms. @@ -404,7 +405,7 @@ def __init__( ): r""" Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure (Structure): input structure edge (str): absorption edge radius (float): cluster radius in Angstroms. @@ -448,7 +449,7 @@ def __init__( ): """ Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure (Structure): input structure edge (str): absorption edge spectrum (str): ELNES or EXELFS @@ -519,7 +520,7 @@ def __init__( ): r""" Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure (Structure): input structure edge (str): absorption edge radius (float): cluster radius in Angstroms. @@ -575,7 +576,7 @@ def __init__( ): r""" Args: - absorbing_atom (str/int): absorbing atom symbol or site index + absorbing_atom (str | int): absorbing atom symbol or site index structure (Structure): input structure edge (str): absorption edge radius (float): cluster radius in Angstroms. diff --git a/src/pymatgen/io/fiesta.py b/src/pymatgen/io/fiesta.py index dc75893c37a..1c17ced40fa 100644 --- a/src/pymatgen/io/fiesta.py +++ b/src/pymatgen/io/fiesta.py @@ -17,14 +17,16 @@ from monty.io import zopen from monty.json import MSONable + from pymatgen.core.structure import Molecule if TYPE_CHECKING: from pathlib import Path - from pymatgen.util.typing import Tuple3Ints from typing_extensions import Self + from pymatgen.util.typing import Tuple3Ints + __author__ = "ndardenne" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -37,7 +39,7 @@ class Nwchem2Fiesta(MSONable): If nwchem.nw is the input, nwchem.out the output, and structure.movecs the "movecs" file, the syntax to run NWCHEM2FIESTA is: NWCHEM2FIESTA - nwchem.nw nwchem.nwout structure.movecs > log_n2f + nwchem.nw nwchem.nwout structure.movecs > log_n2f """ def __init__(self, folder, filename="nwchem", log_file="log_n2f"): diff --git a/src/pymatgen/io/gaussian.py b/src/pymatgen/io/gaussian.py index 36102e6c334..39e7417beec 100644 --- a/src/pymatgen/io/gaussian.py +++ b/src/pymatgen/io/gaussian.py @@ -9,13 +9,14 @@ import numpy as np import scipy.constants as cst from monty.io import zopen +from scipy.stats import norm + from pymatgen.core import Composition, Element, Molecule from pymatgen.core.operations import SymmOp from pymatgen.core.units import Ha_to_eV from pymatgen.electronic_structure.core import Spin from pymatgen.util.coord import get_angle from pymatgen.util.plotting import pretty_plot -from scipy.stats import norm if TYPE_CHECKING: from pathlib import Path @@ -1002,7 +1003,7 @@ def _parse_hessian(self, file, structure): structure: structure in the output file """ # read Hessian matrix under "Force constants in Cartesian coordinates" - # Hessian matrix is in the input orientation framework + # Hessian matrix is in the input orientation framework # WARNING : need #P in the route line ndf = 3 * len(structure) diff --git a/src/pymatgen/io/lammps/data.py b/src/pymatgen/io/lammps/data.py index 46ac9b97765..5fb50b053d3 100644 --- a/src/pymatgen/io/lammps/data.py +++ b/src/pymatgen/io/lammps/data.py @@ -27,18 +27,20 @@ from monty.io import zopen from monty.json import MSONable from monty.serialization import loadfn +from ruamel.yaml import YAML + from pymatgen.core import Element, Lattice, Molecule, Structure from pymatgen.core.operations import SymmOp from pymatgen.util.io_utils import clean_lines -from ruamel.yaml import YAML if TYPE_CHECKING: from collections.abc import Sequence from typing import Any, Literal + from typing_extensions import Self + from pymatgen.core.sites import Site from pymatgen.core.structure import SiteCollection - from typing_extensions import Self __author__ = "Kiran Mathew, Zhi Deng, Tingzheng Hou" __copyright__ = "Copyright 2018, The Materials Virtual Lab" @@ -553,7 +555,7 @@ def disassemble( for t in ff_df.itertuples(index=True, name=None): coeffs_dict = {"coeffs": list(t[1:]), "types": []} if class2_coeffs: - coeffs_dict.update({k: list(v[t[0] - 1]) for k, v in class2_coeffs.items()}) + coeffs_dict |= {k: list(v[t[0] - 1]) for k, v in class2_coeffs.items()} topo_coeffs[kw].append(coeffs_dict) if self.topology: @@ -797,7 +799,7 @@ def from_ff_and_topologies( topology[key] = df[SECTION_HEADERS[key]] topology = {key: values for key, values in topology.items() if not values.empty} - items.update({"atoms": atoms, "velocities": velocities, "topology": topology}) + items |= {"atoms": atoms, "velocities": velocities, "topology": topology} return cls(**items) @classmethod @@ -1160,7 +1162,7 @@ def process_data(data) -> pd.DataFrame: all_data = {kw: process_data(main_data)} if class2_data: - all_data.update({k: process_data(v) for k, v in class2_data.items()}) + all_data |= {k: process_data(v) for k, v in class2_data.items()} return all_data, {f"{kw[:-7]}s": mapper} def to_file(self, filename: str) -> None: diff --git a/src/pymatgen/io/lammps/generators.py b/src/pymatgen/io/lammps/generators.py index 5e1c0a30950..afc2f39b5dc 100644 --- a/src/pymatgen/io/lammps/generators.py +++ b/src/pymatgen/io/lammps/generators.py @@ -15,6 +15,7 @@ from string import Template from monty.io import zopen + from pymatgen.core import Structure from pymatgen.io.core import InputGenerator from pymatgen.io.lammps.data import CombinedData, LammpsData diff --git a/src/pymatgen/io/lammps/inputs.py b/src/pymatgen/io/lammps/inputs.py index 5ea8477d5fd..953f0978b90 100644 --- a/src/pymatgen/io/lammps/inputs.py +++ b/src/pymatgen/io/lammps/inputs.py @@ -19,6 +19,7 @@ from monty.dev import deprecated from monty.io import zopen from monty.json import MSONable + from pymatgen.core import __version__ as CURRENT_VER from pymatgen.io.core import InputFile from pymatgen.io.lammps.data import CombinedData, LammpsData @@ -27,9 +28,10 @@ if TYPE_CHECKING: from os import PathLike - from pymatgen.io.core import InputSet from typing_extensions import Self + from pymatgen.io.core import InputSet + __author__ = "Kiran Mathew, Brandon Wood, Zhi Deng, Manas Likhit, Guillaume Brunin (Matgenix)" __copyright__ = "Copyright 2018, The Materials Virtual Lab" __version__ = "2.0" @@ -909,7 +911,7 @@ def md( with open(template_path, encoding="utf-8") as file: script_template = file.read() settings = other_settings.copy() if other_settings else {} - settings.update({"force_field": force_field, "temperature": temperature, "nsteps": nsteps}) + settings |= {"force_field": force_field, "temperature": temperature, "nsteps": nsteps} script_filename = "in.md" return cls( script_template=script_template, @@ -957,7 +959,7 @@ def get_input_set( # type: ignore[override] input_set = super().get_input_set(template=script_template, variables=settings, filename=script_filename) if data: - input_set.update({data_filename: data}) + input_set |= {data_filename: data} return input_set diff --git a/src/pymatgen/io/lammps/outputs.py b/src/pymatgen/io/lammps/outputs.py index b8e9688c06d..0c1b182ed27 100644 --- a/src/pymatgen/io/lammps/outputs.py +++ b/src/pymatgen/io/lammps/outputs.py @@ -14,6 +14,7 @@ import pandas as pd from monty.io import zopen from monty.json import MSONable + from pymatgen.io.lammps.data import LammpsBox if TYPE_CHECKING: @@ -163,19 +164,19 @@ def _parse_thermo(lines: list[str]) -> pd.DataFrame: # multi line thermo data if re.match(multi_pattern, lines[0]): timestep_marks = [idx for idx, line in enumerate(lines) if re.match(multi_pattern, line)] - timesteps = np.split(lines, timestep_marks)[1:] + time_steps = np.split(lines, timestep_marks)[1:] dicts = [] kv_pattern = r"([0-9A-Za-z_\[\]]+)\s+=\s+([0-9eE\.+-]+)" - for ts in timesteps: + for ts in time_steps: data = {} step = re.match(multi_pattern, ts[0]) assert step is not None data["Step"] = int(step[1]) - data.update({k: float(v) for k, v in re.findall(kv_pattern, "".join(ts[1:]))}) + data |= {k: float(v) for k, v in re.findall(kv_pattern, "".join(ts[1:]))} dicts.append(data) df = pd.DataFrame(dicts) # rearrange the sequence of columns - columns = ["Step"] + [k for k, v in re.findall(kv_pattern, "".join(timesteps[0][1:]))] + columns = ["Step"] + [k for k, v in re.findall(kv_pattern, "".join(time_steps[0][1:]))] df = df[columns] # one line thermo data else: diff --git a/src/pymatgen/io/lammps/sets.py b/src/pymatgen/io/lammps/sets.py index 54e63afd5f2..d61510f3d67 100644 --- a/src/pymatgen/io/lammps/sets.py +++ b/src/pymatgen/io/lammps/sets.py @@ -18,9 +18,10 @@ from pymatgen.io.lammps.inputs import LammpsInputFile if TYPE_CHECKING: - from pymatgen.util.typing import PathLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + __author__ = "Ryan Kingsbury, Guillaume Brunin (Matgenix)" __copyright__ = "Copyright 2021, The Materials Project" __version__ = "0.2" diff --git a/src/pymatgen/io/lammps/utils.py b/src/pymatgen/io/lammps/utils.py index cc8537d8af6..2ccca0ea2e0 100644 --- a/src/pymatgen/io/lammps/utils.py +++ b/src/pymatgen/io/lammps/utils.py @@ -11,6 +11,7 @@ import numpy as np from monty.dev import deprecated from monty.tempfile import ScratchDir + from pymatgen.core.operations import SymmOp from pymatgen.core.structure import Molecule from pymatgen.io.babel import BabelMolAdaptor diff --git a/src/pymatgen/io/lmto.py b/src/pymatgen/io/lmto.py index 5c6057d8190..b8c36bafc6c 100644 --- a/src/pymatgen/io/lmto.py +++ b/src/pymatgen/io/lmto.py @@ -11,6 +11,7 @@ import numpy as np from monty.io import zopen + from pymatgen.core.structure import Structure from pymatgen.core.units import Ry_to_eV, bohr_to_angstrom from pymatgen.electronic_structure.core import Spin @@ -280,7 +281,7 @@ class LMTOCopl: "length": bond length} efermi (float): The Fermi energy in Ry or eV. energies (list): Sequence of energies in Ry or eV. - is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized. + is_spin_polarized (bool): True if the calculation is spin-polarized. """ def __init__(self, filename="COPL", to_eV=False): diff --git a/src/pymatgen/io/lobster/__init__.py b/src/pymatgen/io/lobster/__init__.py index f5e4a5f4234..e4b1100a2d6 100644 --- a/src/pymatgen/io/lobster/__init__.py +++ b/src/pymatgen/io/lobster/__init__.py @@ -1,5 +1,5 @@ """ -This package implements modules for input and output to and from Lobster. It +This package implements modules for input and output to and from LOBSTER. It imports the key classes form both lobster.inputs and lobster_outputs to allow most classes to be simply called as pymatgen.io.lobster.Lobsterin for example, to retain backwards compatibility. diff --git a/src/pymatgen/io/lobster/inputs.py b/src/pymatgen/io/lobster/inputs.py index f3759a38b93..9eef7c654a0 100644 --- a/src/pymatgen/io/lobster/inputs.py +++ b/src/pymatgen/io/lobster/inputs.py @@ -22,6 +22,7 @@ from monty.io import zopen from monty.json import MSONable from monty.serialization import loadfn + from pymatgen.core.structure import Structure from pymatgen.io.vasp import Vasprun from pymatgen.io.vasp.inputs import Incar, Kpoints, Potcar @@ -31,9 +32,10 @@ if TYPE_CHECKING: from typing import Any, ClassVar, Literal + from typing_extensions import Self + from pymatgen.core.composition import Composition from pymatgen.util.typing import PathLike, Tuple3Ints - from typing_extensions import Self MODULE_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/src/pymatgen/io/lobster/lobsterenv.py b/src/pymatgen/io/lobster/lobsterenv.py index 7982b025c99..919ea6bb678 100644 --- a/src/pymatgen/io/lobster/lobsterenv.py +++ b/src/pymatgen/io/lobster/lobsterenv.py @@ -20,6 +20,7 @@ import numpy as np from monty.dev import deprecated + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments @@ -31,9 +32,10 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure from pymatgen.core.periodic_table import Element - from typing_extensions import Self __author__ = "Janine George" __copyright__ = "Copyright 2021, The Materials Project" diff --git a/src/pymatgen/io/lobster/outputs.py b/src/pymatgen/io/lobster/outputs.py index d17c550ba9b..e671fa03fe2 100644 --- a/src/pymatgen/io/lobster/outputs.py +++ b/src/pymatgen/io/lobster/outputs.py @@ -22,6 +22,7 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.core.structure import Structure from pymatgen.electronic_structure.bandstructure import LobsterBandStructureSymmLine from pymatgen.electronic_structure.core import Orbital, Spin @@ -64,7 +65,7 @@ class Cohpcar: efermi (float): The Fermi energy in eV. energies (Sequence[float]): Sequence of energies in eV. Note that LOBSTER shifts the energies so that the Fermi energy is at zero. - is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized. + is_spin_polarized (bool): True if the calculation is spin polarized. orb_cohp (dict[str, Dict[str, Dict[str, Any]]]): A dictionary containing the orbital-resolved COHPs of the form: orb_cohp[label] = {bond_data["orb_label"]: { "COHP": {Spin.up: cohps, Spin.down:cohps}, @@ -165,18 +166,16 @@ def __init__( } elif label in orb_cohp: - orb_cohp[label].update( - { - bond_data["orb_label"]: { - "COHP": cohp, - "ICOHP": icohp, - "orbitals": orbs, - "length": bond_data["length"], - "sites": bond_data["sites"], - "cells": bond_data["cells"], - } + orb_cohp[label] |= { + bond_data["orb_label"]: { + "COHP": cohp, + "ICOHP": icohp, + "orbitals": orbs, + "length": bond_data["length"], + "sites": bond_data["sites"], + "cells": bond_data["cells"], } - ) + } else: # present for Lobster versions older than Lobster 2.2.0 if bond_num == 0: @@ -218,17 +217,15 @@ def __init__( } elif label in orb_cohp: - orb_cohp[label].update( - { - bond_data["orb_label"]: { - "COHP": cohp, - "ICOHP": icohp, - "orbitals": orbs, - "length": bond_data["length"], - "sites": bond_data["sites"], - } + orb_cohp[label] |= { + bond_data["orb_label"]: { + "COHP": cohp, + "ICOHP": icohp, + "orbitals": orbs, + "length": bond_data["length"], + "sites": bond_data["sites"], } - ) + } else: # present for Lobster versions older than Lobster 2.2.0 if bond_num == 0: @@ -278,6 +275,7 @@ def _get_bond_data(line: str, are_multi_center_cobis: bool = False) -> dict: indices, a tuple containing the orbitals (if orbital-resolved), and a label for the orbitals (if orbital-resolved). """ + if not are_multi_center_cobis: line_new = line.rsplit("(", 1) length = float(line_new[-1][:-1]) @@ -327,8 +325,8 @@ class Icohplist(MSONable): """Read ICOHPLIST/ICOOPLIST files generated by LOBSTER. Attributes: - are_coops (bool): Indicates whether the object is consisting of COOPs. - is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized. + are_coops (bool): Indicates whether the object consists of COOPs. + is_spin_polarized (bool): True if the calculation is spin polarized. Icohplist (dict[str, Dict[str, Union[float, int, Dict[Spin, float]]]]): Dict containing the listfile data of the form: { bond: "length": bond length, @@ -356,7 +354,7 @@ def __init__( filename: Name of the ICOHPLIST file. If it is None, the default file name will be chosen, depending on the value of are_coops is_spin_polarized: Boolean to indicate if the calculation is spin polarized - icohpcollection: IcohpCollection Object. + icohpcollection: IcohpCollection Object """ self._filename = filename @@ -383,14 +381,14 @@ def __init__( if len(data) == 0: raise RuntimeError("ICOHPLIST file contains no data.") - # Which Lobster version? + # Determine LOBSTER version if len(data[0].split()) == 8: version = "3.1.1" elif len(data[0].split()) == 6: version = "2.2.1" - warnings.warn("Please consider using the new Lobster version. See www.cohp.de.") + warnings.warn("Please consider using the new LOBSTER version. See www.cohp.de.") else: - raise ValueError + raise ValueError("Unsupported LOBSTER version.") # If the calculation is spin polarized, the line in the middle # of the file will be another header line. @@ -407,9 +405,9 @@ def __init__( data_orbitals = [] for line in data: if "_" not in line.split()[1]: - data_without_orbitals += [line] + data_without_orbitals.append(line) else: - data_orbitals += [line] + data_orbitals.append(line) else: data_without_orbitals = data @@ -422,49 +420,45 @@ def __init__( else: n_bonds = len(data_without_orbitals) - labels, atoms1, atoms2, lens, translations, nums, icohps = [], [], [], [], [], [], [] - - # initialize static variables - label = "" - atom1 = "" - atom2 = "" - length = None - num = None - translation = [] + labels: list[str] = [] + atoms1: list[str] = [] + atoms2: list[str] = [] + lens: list[float] = [] + translations: list[tuple[int, int, int]] = [] + nums: list[int] = [] + icohps: list[dict[Spin, float]] = [] for bond in range(n_bonds): - line = data_without_orbitals[bond].split() - icohp = {} + line_parts = data_without_orbitals[bond].split() + + label = f"{line_parts[0]}" + atom1 = str(line_parts[1]) + atom2 = str(line_parts[2]) + length = float(line_parts[3]) + + icohp: dict[Spin, float] = {} if version == "2.2.1": - label = f"{line[0]}" - atom1 = str(line[1]) - atom2 = str(line[2]) - length = float(line[3]) - icohp[Spin.up] = float(line[4]) - num = int(line[5]) - translation = [0, 0, 0] + icohp[Spin.up] = float(line_parts[4]) + num = int(line_parts[5]) + translation = (0, 0, 0) if self.is_spin_polarized: icohp[Spin.down] = float(data_without_orbitals[bond + n_bonds + 1].split()[4]) - elif version == "3.1.1": - label = f"{line[0]}" - atom1 = str(line[1]) - atom2 = str(line[2]) - length = float(line[3]) - translation = [int(line[4]), int(line[5]), int(line[6])] - icohp[Spin.up] = float(line[7]) + else: # version == "3.1.1" + translation = (int(line_parts[4]), int(line_parts[5]), int(line_parts[6])) + icohp[Spin.up] = float(line_parts[7]) num = 1 if self.is_spin_polarized: icohp[Spin.down] = float(data_without_orbitals[bond + n_bonds + 1].split()[7]) - labels += [label] - atoms1 += [atom1] - atoms2 += [atom2] - lens += [length] - translations += [translation] - nums += [num] - icohps += [icohp] + labels.append(label) + atoms1.append(atom1) + atoms2.append(atom2) + lens.append(length) + translations.append(translation) + nums.append(num) + icohps.append(icohp) list_orb_icohp: list[dict] | None = None if self.orbitalwise: @@ -474,17 +468,17 @@ def __init__( for i_data_orb in range(n_orbs): data_orb = data_orbitals[i_data_orb] icohp = {} - line = data_orb.split() - label = f"{line[0]}" + line_parts = data_orb.split() + label = f"{line_parts[0]}" orbs = re.findall(r"_(.*?)(?=\s)", data_orb) orb_label, orbitals = get_orb_from_str(orbs) - icohp[Spin.up] = float(line[7]) + icohp[Spin.up] = float(line_parts[7]) if self.is_spin_polarized: icohp[Spin.down] = float(data_orbitals[n_orbs + i_data_orb].split()[7]) if len(list_orb_icohp) < int(label): - list_orb_icohp += [{orb_label: {"icohp": icohp, "orbitals": orbitals}}] + list_orb_icohp.append({orb_label: {"icohp": icohp, "orbitals": orbitals}}) else: list_orb_icohp[int(label) - 1][orb_label] = {"icohp": icohp, "orbitals": orbitals} @@ -498,7 +492,7 @@ def __init__( list_atom1=atoms1, list_atom2=atoms2, list_length=lens, - list_translation=translations, + list_translation=translations, # type: ignore[arg-type] list_num=nums, list_icohp=icohps, is_spin_polarized=self.is_spin_polarized, @@ -529,7 +523,7 @@ class NciCobiList: """Read NcICOBILIST (multi-center ICOBI) files generated by LOBSTER. Attributes: - is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized. + is_spin_polarized (bool): True if the calculation is spin polarized. NciCobiList (dict): Dict containing the listfile data of the form: {bond: "number_of_atoms": number of atoms involved in the multi-center interaction, "ncicobi": {Spin.up: Nc-ICOBI(Ef) spin up, Spin.down: ...}}, @@ -541,11 +535,12 @@ def __init__( filename: PathLike | None = "NcICOBILIST.lobster", ) -> None: """ - LOBSTER < 4.1.0: no COBI/ICOBI/NcICOBI. + LOBSTER < 4.1.0: no COBI/ICOBI/NcICOBI Args: filename: Name of the NcICOBILIST file. """ + # LOBSTER list files have an extra trailing blank line # and we don't need the header. with zopen(filename, mode="rt") as file: @@ -573,7 +568,7 @@ def __init__( data_without_orbitals = [] for line in data: if "_" not in str(line.split()[3:]) and "s]" not in str(line.split()[3:]): - data_without_orbitals += [line] + data_without_orbitals.append(line) else: data_without_orbitals = data @@ -604,17 +599,19 @@ def __init__( if self.is_spin_polarized: ncicobi[Spin.down] = float(data_without_orbitals[bond + n_bonds + 1].split()[2]) - self.list_labels += [label] - self.list_n_atoms += [n_atoms] - self.list_ncicobi += [ncicobi] - self.list_interaction_type += [interaction_type] - self.list_num += [num] + self.list_labels.append(label) + self.list_n_atoms.append(n_atoms) + self.list_ncicobi.append(ncicobi) + self.list_interaction_type.append(interaction_type) + self.list_num.append(num) # TODO: add functions to get orbital resolved NcICOBIs @property def ncicobi_list(self) -> dict[Any, dict[str, Any]]: - """Returns: ncicobilist.""" + """ + Returns: ncicobilist. + """ ncicobi_list = {} for idx in range(len(self.list_labels)): ncicobi_list[str(idx + 1)] = { @@ -645,7 +642,7 @@ class Doscar: the Spin.up contribution at each of the energies. itdensities[Spin.down]: numpy array of the total density of states for the Spin.down contribution at each of the energies. If is_spin_polarized=False, itdensities[Spin.up]: numpy array of the total density of states. - is_spin_polarized (bool): Boolean. Tells if the system is spin polarized. + is_spin_polarized (bool): Whether the system is spin polarized. """ def __init__( @@ -687,7 +684,7 @@ def _parse_doscar(self): for nd in range(1, ndos): line = file.readline().split() cdos[nd] = np.array(line) - dos += [cdos] + dos.append(cdos) doshere = np.array(dos[0]) if len(doshere[0, :]) == 5: self._is_spin_polarized = True @@ -709,7 +706,7 @@ def _parse_doscar(self): for orb_num, j in enumerate(range(1, ncol)): orb = orbitals[atom + 1][orb_num] pdos[orb][spin] = data[:, j] - pdoss += [pdos] + pdoss.append(pdos) else: tdensities[Spin.up] = doshere[:, 1] tdensities[Spin.down] = doshere[:, 2] @@ -727,7 +724,7 @@ def _parse_doscar(self): pdos[orb][spin] = data[:, j] if j % 2 == 0: orb_num += 1 - pdoss += [pdos] + pdoss.append(pdos) self._efermi = efermi self._pdos = pdoss @@ -743,32 +740,32 @@ def _parse_doscar(self): @property def completedos(self) -> LobsterCompleteDos: - """LobsterCompleteDos.""" + """LobsterCompleteDos""" return self._completedos @property def pdos(self) -> list: - """Projected DOS.""" + """Projected DOS""" return self._pdos @property def tdos(self) -> Dos: - """Total DOS.""" + """Total DOS""" return self._tdos @property def energies(self) -> np.ndarray: - """Energies.""" + """Energies""" return self._energies @property def tdensities(self) -> dict[Spin, np.ndarray]: - """Total densities as a np.ndarray.""" + """total densities as a np.ndarray""" return self._tdensities @property def itdensities(self) -> dict[Spin, np.ndarray]: - """Integrated total densities as a np.ndarray.""" + """integrated total densities as a np.ndarray""" return self._itdensities @property @@ -804,7 +801,7 @@ def __init__( atomlist: list of atoms in the structure types: list of unique species in the structure mulliken: list of Mulliken charges - loewdin: list of Loewdin charges. + loewdin: list of Loewdin charges """ self._filename = filename self.num_atoms = num_atoms @@ -820,15 +817,15 @@ def __init__( raise RuntimeError("CHARGE file contains no data.") self.num_atoms = len(data) - for atom in range(self.num_atoms): - line = data[atom].split() - self.atomlist += [line[1] + line[0]] - self.types += [line[1]] - self.mulliken += [float(line[2])] - self.loewdin += [float(line[3])] + for atom_idx in range(self.num_atoms): + line_parts = data[atom_idx].split() + self.atomlist.append(line_parts[1] + line_parts[0]) + self.types.append(line_parts[1]) + self.mulliken.append(float(line_parts[2])) + self.loewdin.append(float(line_parts[3])) def get_structure_with_charges(self, structure_filename: PathLike) -> Structure: - """Get a Structure with Mulliken and Loewdin charges as site properties. + """Get a Structure with Mulliken and Loewdin charges as site properties Args: structure_filename: filename of POSCAR @@ -877,8 +874,8 @@ class Lobsterout(MSONable): has_grosspopulation (bool): Whether GROSSPOP.lobster is present. info_lines (str): String with additional infos on the run. info_orthonormalization (str): String with infos on orthonormalization. - is_restart_from_projection (bool): Boolean that indicates that calculation was restarted - from existing projection file. + is_restart_from_projection (bool): Whether calculation was restarted from existing + projection file. lobster_version (str): String that indicates Lobster version. number_of_spins (int): Integer indicating the number of spins. number_of_threads (int): Integer that indicates how many threads were used. @@ -925,7 +922,7 @@ def __init__(self, filename: PathLike | None, **kwargs) -> None: """ Args: filename: The lobsterout file. - **kwargs: dict to initialize Lobsterout instance. + **kwargs: dict to initialize Lobsterout instance """ self.filename = filename if kwargs: @@ -1007,7 +1004,7 @@ def __init__(self, filename: PathLike | None, **kwargs) -> None: raise ValueError("must provide either filename or kwargs to initialize Lobsterout") def get_doc(self) -> dict[str, Any]: - """Get the LobsterDict with all the information stored in lobsterout.""" + """Get a dict with all the information in lobsterout.""" return { # Check if LOBSTER starts from a projection "restart_from_projection": self.is_restart_from_projection, @@ -1038,7 +1035,7 @@ def get_doc(self) -> dict[str, Any]: } def as_dict(self) -> dict: - """MSONable dict.""" + """MSONable dict""" dct = dict(vars(self)) dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ @@ -1089,9 +1086,9 @@ def _get_spillings(data, number_of_spins): splitrow = row.split() if len(splitrow) > 2 and splitrow[2] == "spilling:": if splitrow[1] == "charge": - charge_spilling += [np.float64(splitrow[3].replace("%", "")) / 100.0] + charge_spilling.append(np.float64(splitrow[3].replace("%", "")) / 100.0) if splitrow[1] == "total": - total_spilling += [np.float64(splitrow[3].replace("%", "")) / 100.0] + total_spilling.append(np.float64(splitrow[3].replace("%", "")) / 100.0) if len(charge_spilling) == number_of_spins and len(total_spilling) == number_of_spins: break @@ -1107,8 +1104,8 @@ def _get_elements_basistype_basisfunctions(data): basisfunctions = [] for row in data: if begin and not end: - splitrow = row.split() - if splitrow[0] not in [ + row_parts = row.split() + if row_parts[0] not in { "INFO:", "WARNING:", "setting", @@ -1117,11 +1114,11 @@ def _get_elements_basistype_basisfunctions(data): "saving", "spillings", "writing", - ]: - elements += [splitrow[0]] - basistype += [splitrow[1].replace("(", "").replace(")", "")] + }: + elements.append(row_parts[0]) + basistype.append(row_parts[1].replace("(", "").replace(")", "")) # last sign is a '' - basisfunctions += [splitrow[2:]] + basisfunctions += [row_parts[2:]] else: end = True if "setting up local basis functions..." in row: @@ -1158,7 +1155,7 @@ def _get_warning_orthonormalization(data): for row in data: splitrow = row.split() if "orthonormalized" in splitrow: - orthowarning += [" ".join(splitrow[1:])] + orthowarning.append(" ".join(splitrow[1:])) return orthowarning @staticmethod @@ -1167,7 +1164,7 @@ def _get_all_warning_lines(data): for row in data: splitrow = row.split() if len(splitrow) > 0 and splitrow[0] == "WARNING:": - ws += [" ".join(splitrow[1:])] + ws.append(" ".join(splitrow[1:])) return ws @staticmethod @@ -1176,7 +1173,7 @@ def _get_all_info_lines(data): for row in data: splitrow = row.split() if len(splitrow) > 0 and splitrow[0] == "INFO:": - infos += [" ".join(splitrow[1:])] + infos.append(" ".join(splitrow[1:])) return infos @@ -1189,7 +1186,7 @@ class Fatband: The first index of the array refers to the band and the second to the index of the kpoint. The kpoints are ordered according to the order of the kpoints_array attribute. If the band structure is not spin polarized, we only store one data set under Spin.up. - is_spin_polarized (bool): Boolean that tells you whether this was a spin-polarized calculation. + is_spin_polarized (bool): Whether this was a spin-polarized calculation. kpoints_array (list[np.ndarray]): List of kpoints as numpy arrays, in frac_coords of the given lattice by default. label_dict (dict[str, Union[str, np.ndarray]]): Dictionary that links a kpoint (in frac coords or Cartesian @@ -1217,10 +1214,10 @@ def __init__( "FATBAND_*" files will be read kpoints_file (PathLike): KPOINTS file for bandstructure calculation, typically "KPOINTS". vasprun_file (PathLike): Corresponding vasprun file. - Instead, the Fermi energy from the DFT run can be provided. Then, + Instead, the Fermi level from the DFT run can be provided. Then, this value should be set to None. structure (Structure): Structure object. - efermi (float): fermi energy in eV. + efermi (float): Fermi level in eV. """ warnings.warn("Make sure all relevant FATBAND files were generated and read in!") warnings.warn("Use Lobster 3.2.0 or newer for fatband calculations!") @@ -1259,7 +1256,7 @@ def __init__( filenames = "." for name in os.listdir(filenames): if fnmatch.fnmatch(name, "FATBAND_*.lobster"): - filenames_new += [os.path.join(filenames, name)] + filenames_new.append(os.path.join(filenames, name)) filenames = filenames_new if len(filenames) == 0: raise ValueError("No FATBAND files in folder or given") @@ -1267,17 +1264,17 @@ def __init__( with zopen(name, mode="rt") as file: contents = file.read().split("\n") - atom_names += [os.path.split(name)[1].split("_")[1].capitalize()] + atom_names.append(os.path.split(name)[1].split("_")[1].capitalize()) parameters = contents[0].split() - atom_type += [re.split(r"[0-9]+", parameters[3])[0].capitalize()] - orbital_names += [parameters[4]] + atom_type.append(re.split(r"[0-9]+", parameters[3])[0].capitalize()) + orbital_names.append(parameters[4]) # get atomtype orbital dict atom_orbital_dict = {} # type: dict - for iatom, atom in enumerate(atom_names): + for idx, atom in enumerate(atom_names): if atom not in atom_orbital_dict: atom_orbital_dict[atom] = [] - atom_orbital_dict[atom] += [orbital_names[iatom]] + atom_orbital_dict[atom].append(orbital_names[idx]) # test if there are the same orbitals twice or if two different formats were used or if all necessary orbitals # are there for items in atom_orbital_dict.values(): @@ -1285,9 +1282,9 @@ def __init__( raise ValueError("The are two FATBAND files for the same atom and orbital. The program will stop.") split = [] for item in items: - split += [item.split("_")[0]] + split.append(item.split("_")[0]) for number in collections.Counter(split).values(): - if number not in (1, 3, 5, 7): + if number not in {1, 3, 5, 7}: raise ValueError( "Make sure all relevant orbitals were generated and that no duplicates (2p and 2p_x) are " "present" @@ -1312,7 +1309,7 @@ def __init__( linenumbers = [] for iline, line in enumerate(contents[1 : self.nbands * 2 + 4]): if line.split()[0] == "#": - linenumbers += [iline] + linenumbers.append(iline) if ifilename == 0: self.is_spinpolarized = len(linenumbers) == 2 @@ -1362,7 +1359,7 @@ def __init__( ] ) if ifilename == 0: - kpoints_array += [KPOINT] + kpoints_array.append(KPOINT) linenumber = 0 iband = 0 @@ -1405,7 +1402,7 @@ def get_bandstructure(self) -> LobsterBandStructureSymmLine: kpoints=self.kpoints_array, eigenvals=self.eigenvals, lattice=self.lattice, - efermi=self.efermi, + efermi=self.efermi, # type: ignore[arg-type] labels_dict=self.label_dict, structure=self.structure, projections=self.p_eigenvals, @@ -1414,7 +1411,6 @@ def get_bandstructure(self) -> LobsterBandStructureSymmLine: class Bandoverlaps(MSONable): """Read in bandOverlaps.lobster files. These files are not created during every Lobster run. - Attributes: band_overlaps_dict (dict[Spin, Dict[str, Dict[str, Union[float, np.ndarray]]]]): A dictionary containing the band overlap data of the form: {spin: {"kpoint as string": {"maxDeviation": @@ -1454,7 +1450,7 @@ def __init__( def _read(self, contents: list, spin_numbers: list): """ - Will read in all contents of the file. + Will read in all contents of the file Args: contents: list of strings @@ -1474,7 +1470,7 @@ def _read(self, contents: list, spin_numbers: list): kpoint_array = [] for kpointel in kpoint: if kpointel not in {"at", "k-point", ""}: - kpoint_array += [float(kpointel)] + kpoint_array.append(float(kpointel)) elif "maxDeviation" in line: if spin not in self.band_overlaps_dict: @@ -1486,29 +1482,28 @@ def _read(self, contents: list, spin_numbers: list): if "matrices" not in self.band_overlaps_dict[spin]: self.band_overlaps_dict[spin]["matrices"] = [] maxdev = line.split(" ")[2] - self.band_overlaps_dict[spin]["max_deviations"] += [float(maxdev)] + self.band_overlaps_dict[spin]["max_deviations"].append(float(maxdev)) self.band_overlaps_dict[spin]["k_points"] += [kpoint_array] - self.max_deviation += [float(maxdev)] + self.max_deviation.append(float(maxdev)) overlaps = [] else: rows = [] for el in line.split(" "): if el != "": - rows += [float(el)] + rows.append(float(el)) overlaps += [rows] if len(overlaps) == len(rows): self.band_overlaps_dict[spin]["matrices"] += [np.matrix(overlaps)] def has_good_quality_maxDeviation(self, limit_maxDeviation: float = 0.1) -> bool: - """ - Will check if the maxDeviation from the ideal bandoverlap is smaller or equal to limit_maxDeviation. + """Will check if the maxDeviation from the ideal bandoverlap is smaller or equal to limit_maxDeviation Args: limit_maxDeviation: limit of the maxDeviation Returns: - Boolean that will give you information about the quality of the projection. + bool: True if the quality of the projection is good. """ return all(deviation <= limit_maxDeviation for deviation in self.max_deviation) @@ -1530,7 +1525,7 @@ def has_good_quality_check_occupied_bands( limit_deviation (float): limit of the maxDeviation Returns: - Boolean that will give you information about the quality of the projection + bool: True if the quality of the projection is good. """ for matrix in self.band_overlaps_dict[Spin.up]["matrices"]: for iband1, band1 in enumerate(matrix): @@ -1583,7 +1578,7 @@ def __init__(self, filename: str = "GROSSPOP.lobster", list_dict_grosspop: list[ """ Args: filename: filename of the "GROSSPOP.lobster" file - list_dict_grosspop: List of dictionaries including all information about the gross populations. + list_dict_grosspop: List of dictionaries including all information about the gross populations """ # opens file self._filename = filename @@ -1606,10 +1601,10 @@ def __init__(self, filename: str = "GROSSPOP.lobster", list_dict_grosspop: list[ small_dict["Mulliken GP"][cleanline[0]] = float(cleanline[1]) small_dict["Loewdin GP"][cleanline[0]] = float(cleanline[2]) if "total" in cleanline[0]: - self.list_dict_grosspop += [small_dict] + self.list_dict_grosspop.append(small_dict) def get_structure_with_total_grosspop(self, structure_filename: str) -> Structure: - """Get a Structure with Mulliken and Loewdin total grosspopulations as site properties. + """Get a Structure with Mulliken and Loewdin total grosspopulations as site properties Args: structure_filename (str): filename of POSCAR @@ -1618,7 +1613,6 @@ def get_structure_with_total_grosspop(self, structure_filename: str) -> Structur Structure Object with Mulliken and Loewdin total grosspopulations as site properties. """ struct = Structure.from_file(structure_filename) - # site_properties: dict[str, Any] = {} mullikengp = [] loewdingp = [] for grosspop in self.list_dict_grosspop: @@ -1667,9 +1661,9 @@ def _parse_file(filename): splitline = line.split() if len(splitline) >= 6: points += [[float(splitline[0]), float(splitline[1]), float(splitline[2])]] - distance += [float(splitline[3])] - real += [float(splitline[4])] - imaginary += [float(splitline[5])] + distance.append(float(splitline[3])) + real.append(float(splitline[4])) + imaginary.append(float(splitline[5])) if len(real) != grid[0] * grid[1] * grid[2] or len(imaginary) != grid[0] * grid[1] * grid[2]: raise ValueError("Something went wrong while reading the file") @@ -1715,9 +1709,9 @@ def set_volumetric_data(self, grid, structure): "coordinates 0.0 0.0 0.0 coordinates 1.0 1.0 1.0 box bandlist 1 " ) - new_x += [x_here] - new_y += [y_here] - new_z += [z_here] + new_x.append(x_here) + new_y.append(y_here) + new_z.append(z_here) new_real += [self.real[runner]] new_imaginary += [self.imaginary[runner]] @@ -1884,7 +1878,7 @@ def __init__( sitepotentials_loewdin: Loewdin site potential sitepotentials_mulliken: Mulliken site potential madelungenergies_loewdin: Madelung energy based on the Loewdin approach - madelungenergies_mulliken: Madelung energy based on the Mulliken approach. + madelungenergies_mulliken: Madelung energy based on the Mulliken approach """ self._filename = filename self.ewald_splitting = [] if ewald_splitting is None else ewald_splitting @@ -1909,17 +1903,17 @@ def __init__( data = data[5:-1] self.num_atoms = len(data) - 2 for atom in range(self.num_atoms): - line = data[atom].split() - self.atomlist += [line[1] + str(line[0])] - self.types += [line[1]] - self.sitepotentials_mulliken += [float(line[2])] - self.sitepotentials_loewdin += [float(line[3])] + line_parts = data[atom].split() + self.atomlist.append(line_parts[1] + str(line_parts[0])) + self.types.append(line_parts[1]) + self.sitepotentials_mulliken.append(float(line_parts[2])) + self.sitepotentials_loewdin.append(float(line_parts[3])) self.madelungenergies_mulliken = float(data[self.num_atoms + 1].split()[3]) self.madelungenergies_loewdin = float(data[self.num_atoms + 1].split()[4]) def get_structure_with_site_potentials(self, structure_filename): - """Get a Structure with Mulliken and Loewdin charges as site properties. + """Get a Structure with Mulliken and Loewdin charges as site properties Args: structure_filename: filename of POSCAR @@ -1982,7 +1976,7 @@ def get_orb_from_str(orbs): list of tw Orbital objects """ # TODO: also useful for plotting of DOS - orb_labs = [ + orb_labs = ( "s", "p_y", "p_z", @@ -1999,7 +1993,7 @@ def get_orb_from_str(orbs): "f_xz^2", "f_z(x^2-y^2)", "f_x(x^2-3y^2)", - ] + ) orbitals = [(int(orb[0]), Orbital(orb_labs.index(orb[1:]))) for orb in orbs] orb_label = "" @@ -2054,8 +2048,9 @@ def __init__(self, e_fermi=None, filename: str = "hamiltonMatrices.lobster"): Args: filename: filename for the hamiltonMatrices file, typically "hamiltonMatrices.lobster". e_fermi: fermi level in eV for the structure only - relevant if input file contains hamilton matrices data. + relevant if input file contains hamilton matrices data """ + self._filename = filename # hamiltonMatrices with zopen(self._filename, mode="rt") as file: @@ -2100,21 +2095,21 @@ def _parse_matrix(file_data, pattern, e_fermi): for idx, line in enumerate(file_data): line = line.strip() if "Real parts" in line: - start_inxs_real += [idx + 1] + start_inxs_real.append(idx + 1) if idx == 1: # ignore the first occurrence as files start with real matrices pass else: - end_inxs_imag += [idx - 1] + end_inxs_imag.append(idx - 1) matches = re.search(pattern, file_data[idx - 1]) if matches and len(matches.groups()) == 2: k_point = matches.group(2) complex_matrices[k_point] = {} if "Imag parts" in line: - end_inxs_real += [idx - 1] - start_inxs_imag += [idx + 1] + end_inxs_real.append(idx - 1) + start_inxs_imag.append(idx + 1) # explicitly add the last line as files end with imaginary matrix if idx == len(file_data) - 1: - end_inxs_imag += [len(file_data)] + end_inxs_imag.append(len(file_data)) # extract matrix data and store diagonal elements matrix_real = [] @@ -2135,13 +2130,13 @@ def _parse_matrix(file_data, pattern, e_fermi): matches = re.search(pattern, file_data[start_inx_real - 2]) if matches and len(matches.groups()) == 2: - spin = Spin.up if matches.group(1) == "1" else Spin.down - k_point = matches.group(2) - complex_matrices[k_point].update({spin: comp_matrix}) + spin = Spin.up if matches[1] == "1" else Spin.down + k_point = matches[2] + complex_matrices[k_point] |= {spin: comp_matrix} elif matches and len(matches.groups()) == 1: - k_point = matches.group(1) - complex_matrices.update({k_point: comp_matrix}) - matrix_diagonal_values += [comp_matrix.real.diagonal() - e_fermi] + k_point = matches[1] + complex_matrices |= {k_point: comp_matrix} + matrix_diagonal_values.append(comp_matrix.real.diagonal() - e_fermi) # extract elements basis functions as list elements_basis_functions = [ diff --git a/src/pymatgen/io/nwchem.py b/src/pymatgen/io/nwchem.py index 8b39dcb40e3..d3ddddbc440 100644 --- a/src/pymatgen/io/nwchem.py +++ b/src/pymatgen/io/nwchem.py @@ -29,6 +29,7 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.analysis.excitation import ExcitationSpectrum from pymatgen.core.structure import Molecule, Structure from pymatgen.core.units import Energy, FloatWithUnit @@ -309,7 +310,7 @@ def dft_task(cls, mol, xc="b3lyp", **kwargs): theory is always "dft" for a dft task. """ t = NwTask.from_molecule(mol, theory="dft", **kwargs) - t.theory_directives.update({"xc": xc, "mult": t.spin_multiplicity}) + t.theory_directives |= {"xc": xc, "mult": t.spin_multiplicity} return t @classmethod @@ -843,10 +844,10 @@ def isfloatstring(in_str): cosmo_scf_energy = energies[-1] energies[-1] = {} energies[-1]["cosmo scf"] = cosmo_scf_energy - energies[-1].update({"gas phase": Energy(match[1], "Ha").to("eV")}) + energies[-1] |= {"gas phase": Energy(match[1], "Ha").to("eV")} if match := energy_sol_patt.search(line): - energies[-1].update({"sol phase": Energy(match[1], "Ha").to("eV")}) + energies[-1] |= {"sol phase": Energy(match[1], "Ha").to("eV")} if match := preamble_patt.search(line): try: @@ -910,23 +911,21 @@ def isfloatstring(in_str): for jj in range(ii + 1, len_hess): projected_hessian[ii].append(projected_hessian[jj][ii]) - data.update( - { - "job_type": job_type, - "energies": energies, - "corrections": corrections, - "molecules": molecules, - "structures": structures, - "basis_set": basis_set, - "errors": errors, - "has_error": len(errors) > 0, - "frequencies": frequencies, - "normal_frequencies": normal_frequencies, - "hessian": hessian, - "projected_hessian": projected_hessian, - "forces": all_forces, - "task_time": time, - } - ) + data |= { + "job_type": job_type, + "energies": energies, + "corrections": corrections, + "molecules": molecules, + "structures": structures, + "basis_set": basis_set, + "errors": errors, + "has_error": len(errors) > 0, + "frequencies": frequencies, + "normal_frequencies": normal_frequencies, + "hessian": hessian, + "projected_hessian": projected_hessian, + "forces": all_forces, + "task_time": time, + } return data diff --git a/src/pymatgen/io/openff.py b/src/pymatgen/io/openff.py index bef308811d5..37e610409c0 100644 --- a/src/pymatgen/io/openff.py +++ b/src/pymatgen/io/openff.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np + import pymatgen from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.local_env import OpenBabelNN, metal_edge_extender diff --git a/src/pymatgen/io/optimade.py b/src/pymatgen/io/optimade.py new file mode 100644 index 00000000000..cbc1237a5d3 --- /dev/null +++ b/src/pymatgen/io/optimade.py @@ -0,0 +1,194 @@ +""" +This module provides conversion between structure entries following the +OPTIMADE (https://optimade.org) standard and pymatgen Structure objects. + +The code is adapted from the `optimade.adapters.structures.pymatgen` module in +optimade-python-tools (https://github.com/Materials-Consortia/optimade-python-tools), +and aims to work without requiring the explicit installation of the `optimade-python-tools`. + +""" + +from __future__ import annotations + +import itertools +import json +import math +import re +from functools import reduce +from typing import TYPE_CHECKING + +from pymatgen.core.structure import Lattice, Structure + +if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any + + +__author__ = "Matthew Evans" + + +def _pymatgen_species( + nsites: int, + species_at_sites: list[str], +) -> list[dict[str, float]]: + """Create list of {"symbol": "concentration"} per site for constructing pymatgen Species objects. + Removes vacancies, if they are present. + + This function is adapted from the `optimade.adapters.structures.pymatgen` module in `optimade-python-tools`, + with some of the generality removed (in terms of partial occupancy). + + """ + species = [{"name": _, "concentration": [1.0], "chemical_symbols": [_]} for _ in set(species_at_sites)] + species_dict = {_["name"]: _ for _ in species} + + pymatgen_species = [] + for site_number in range(nsites): + species_name = species_at_sites[site_number] + current_species = species_dict[species_name] + + chemical_symbols = [] + concentration = [] + for index, symbol in enumerate(current_species["chemical_symbols"]): + if symbol == "vacancy": + # Skip. This is how pymatgen handles vacancies; + # to not include them, while keeping the concentration in a site less than 1. + continue + chemical_symbols.append(symbol) + concentration.append(current_species["concentration"][index]) + + pymatgen_species.append(dict(zip(chemical_symbols, concentration))) + + return pymatgen_species + + +def _optimade_anonymous_element_generator() -> Generator[str, None, None]: + """Generator that yields the next symbol in the A, B, Aa, ... Az OPTIMADE anonymous + element naming scheme. + + """ + from string import ascii_lowercase + + for size in itertools.count(1): + for tuple_strings in itertools.product(ascii_lowercase, repeat=size): + list_strings = list(tuple_strings) + list_strings[0] = list_strings[0].upper() + yield "".join(list_strings) + + +def _optimade_reduce_or_anonymize_formula(formula: str, alphabetize: bool = True, anonymize: bool = False) -> str: + """Takes an input formula, reduces it and either alphabetizes or anonymizes it + following the OPTIMADE standard. + + """ + + numbers: list[int] = [int(n.strip() or 1) for n in re.split(r"[A-Z][a-z]*", formula)[1:]] + # Need to remove leading 1 from split and convert to ints + + species: list[str] = re.findall("[A-Z][a-z]*", formula) + + gcd = reduce(math.gcd, numbers) + + if not len(species) == len(numbers): + raise ValueError(f"Something is wrong with the input formula: {formula}") + + numbers = [n // gcd for n in numbers] + + if anonymize: + numbers = sorted(numbers, reverse=True) + species = [s for _, s in zip(numbers, _optimade_anonymous_element_generator())] + + elif alphabetize: + species, numbers = zip(*sorted(zip(species, numbers))) # type: ignore[assignment] + + return "".join(f"{s}{n if n != 1 else ''}" for n, s in zip(numbers, species)) + + +class OptimadeStructureAdapter: + """Adapter serves as a bridge between OPTIMADE structures and pymatgen objects.""" + + @staticmethod + def get_optimade_structure(structure: Structure, **kwargs) -> dict[str, str | dict[str, Any]]: + """Get a dictionary in the OPTIMADE Structure format from a pymatgen structure or molecule. + + Args: + structure (Structure): pymatgen Structure + **kwargs: passed to the ASE Atoms constructor + + Returns: + A dictionary serialization of the structure in the OPTIMADE format. + + """ + if not structure.is_ordered: + raise ValueError("OPTIMADE Adapter currently only supports ordered structures") + + attributes: dict[str, Any] = {} + attributes["cartesian_site_positions"] = structure.lattice.get_cartesian_coords(structure.frac_coords).tolist() + attributes["lattice_vectors"] = structure.lattice.matrix.tolist() + attributes["species_at_sites"] = [_.symbol for _ in structure.species] + attributes["species"] = [ + {"name": _.symbol, "chemical_symbols": [_.symbol], "concentration": [1]} + for _ in set(structure.composition.elements) + ] + attributes["dimension_types"] = [int(_) for _ in structure.lattice.pbc] + attributes["nperiodic_dimensions"] = sum(attributes["dimension_types"]) + attributes["nelements"] = len(structure.composition.elements) + attributes["chemical_formula_anonymous"] = _optimade_reduce_or_anonymize_formula( + structure.composition.formula, anonymize=True + ) + attributes["elements"] = sorted([_.symbol for _ in structure.composition.elements]) + attributes["chemical_formula_reduced"] = _optimade_reduce_or_anonymize_formula( + structure.composition.formula, anonymize=False + ) + attributes["chemical_formula_descriptive"] = structure.composition.formula + attributes["elements_ratios"] = [structure.composition.get_atomic_fraction(e) for e in attributes["elements"]] + attributes["nsites"] = len(attributes["species_at_sites"]) + + attributes["last_modified"] = None + attributes["immutable_id"] = None + attributes["structure_features"] = [] + + return {"attributes": attributes} + + @staticmethod + def get_structure(resource: dict) -> Structure: + """Get pymatgen structure from an OPTIMADE structure resource. + + Args: + resource: OPTIMADE structure resource as a dictionary, JSON string, or the + corresponding attributes dictionary (i.e., `resource["attributes"]`). + + Returns: + Structure: Equivalent pymatgen Structure + + """ + if isinstance(resource, str): + try: + resource = json.loads(resource) + except json.JSONDecodeError as exc: + raise ValueError(f"Could not decode the input OPTIMADE resource as JSON: {exc}") + + if "attributes" not in resource: + resource = {"attributes": resource} + + _id = resource.get("id", None) + attributes = resource["attributes"] + properties: dict[str, Any] = {"optimade_id": _id} + + # Take any prefixed attributes and save them as properties + custom_properties = {k: v for k, v in attributes.items() if k.startswith("_")} + if custom_properties: + properties["optimade_attributes"] = custom_properties + + return Structure( + lattice=Lattice( + attributes["lattice_vectors"], + [bool(d) for d in attributes["dimension_types"]], # type: ignore[arg-type] + ), + species=_pymatgen_species( + nsites=attributes["nsites"], + species_at_sites=attributes["species_at_sites"], + ), + coords=attributes["cartesian_site_positions"], + coords_are_cartesian=True, + properties=properties, + ) diff --git a/src/pymatgen/io/packmol.py b/src/pymatgen/io/packmol.py index d1f1f7f6c6a..ed159029864 100644 --- a/src/pymatgen/io/packmol.py +++ b/src/pymatgen/io/packmol.py @@ -26,6 +26,7 @@ class that provides a run() method for running packmol locally. from typing import TYPE_CHECKING import numpy as np + from pymatgen.core import Molecule from pymatgen.io.core import InputGenerator, InputSet diff --git a/src/pymatgen/io/phonopy.py b/src/pymatgen/io/phonopy.py index 6e156b2e2e5..d0d39f014e4 100644 --- a/src/pymatgen/io/phonopy.py +++ b/src/pymatgen/io/phonopy.py @@ -5,13 +5,14 @@ import numpy as np from monty.dev import requires from monty.serialization import loadfn +from scipy.interpolate import InterpolatedUnivariateSpline + from pymatgen.core import Lattice, Structure from pymatgen.phonon.bandstructure import PhononBandStructure, PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos, PhononDos from pymatgen.phonon.gruneisen import GruneisenParameter, GruneisenPhononBandStructureSymmLine from pymatgen.phonon.thermal_displacements import ThermalDisplacementMatrices from pymatgen.symmetry.bandstructure import HighSymmKpath -from scipy.interpolate import InterpolatedUnivariateSpline try: from phonopy import Phonopy diff --git a/src/pymatgen/io/pwmat/inputs.py b/src/pymatgen/io/pwmat/inputs.py index 1051405ffe9..bd143572216 100644 --- a/src/pymatgen/io/pwmat/inputs.py +++ b/src/pymatgen/io/pwmat/inputs.py @@ -8,13 +8,15 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.core import Lattice, Structure from pymatgen.symmetry.kpath import KPathSeek if TYPE_CHECKING: - from pymatgen.util.typing import PathLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + __author__ = "Hanyu Liu" __email__ = "domainofbuaa@gmail.com" __date__ = "2024-1-16" @@ -397,7 +399,7 @@ def from_str(cls, data: str, mag: bool = False) -> Self: if mag: magmoms = ac_extractor.get_magmoms() for idx, tmp_site in enumerate(structure): - tmp_site.properties.update({"magmom": magmoms[idx]}) + tmp_site.properties |= {"magmom": magmoms[idx]} return cls(structure) @@ -496,8 +498,8 @@ def __init__( """ self.reciprocal_lattice: np.ndarray = reciprocal_lattice self.kpath: dict = {} - self.kpath.update({"kpoints": kpoints}) - self.kpath.update({"path": path}) + self.kpath |= {"kpoints": kpoints} + self.kpath |= {"path": path} self.density = density @classmethod @@ -515,7 +517,7 @@ def from_structure(cls, structure: Structure, dim: int, density: float = 0.01) - kpts_2d: dict[str, np.ndarray] = {} for tmp_name, tmp_kpt in kpath_set.kpath["kpoints"].items(): if (tmp_kpt[2]) == 0: - kpts_2d.update({tmp_name: tmp_kpt}) + kpts_2d |= {tmp_name: tmp_kpt} path_2d: list[list[str]] = [] for tmp_path in kpath_set.kpath["path"]: @@ -598,8 +600,8 @@ def __init__(self, reciprocal_lattice: np.ndarray, kpts: dict[str, list], path: """ self.reciprocal_lattice: np.ndarray = reciprocal_lattice self.kpath: dict = {} - self.kpath.update({"kpoints": kpts}) - self.kpath.update({"path": path}) + self.kpath |= {"kpoints": kpts} + self.kpath |= {"path": path} self.density = density @classmethod diff --git a/src/pymatgen/io/pwmat/outputs.py b/src/pymatgen/io/pwmat/outputs.py index 0ebd966efd2..1d2bbef2256 100644 --- a/src/pymatgen/io/pwmat/outputs.py +++ b/src/pymatgen/io/pwmat/outputs.py @@ -8,6 +8,7 @@ import numpy as np from monty.io import zopen from monty.json import MSONable + from pymatgen.io.pwmat.inputs import ACstrExtractor, AtomConfig, LineLocator if TYPE_CHECKING: @@ -135,17 +136,17 @@ def _parse_sefv(self) -> list[dict]: tmp_chunk: str = "" for _ in range(self.chunk_sizes[ii]): tmp_chunk += mvt.readline() - tmp_step.update({"atom_config": AtomConfig.from_str(tmp_chunk)}) - tmp_step.update({"e_tot": ACstrExtractor(tmp_chunk).get_e_tot()[0]}) - tmp_step.update({"atom_forces": ACstrExtractor(tmp_chunk).get_atom_forces().reshape(-1, 3)}) + tmp_step["atom_config"] = AtomConfig.from_str(tmp_chunk) + tmp_step["e_tot"] = ACstrExtractor(tmp_chunk).get_e_tot()[0] + tmp_step["atom_forces"] = ACstrExtractor(tmp_chunk).get_atom_forces().reshape(-1, 3) e_atoms: np.ndarray | None = ACstrExtractor(tmp_chunk).get_atom_forces() if e_atoms is not None: - tmp_step.update({"atom_energies": ACstrExtractor(tmp_chunk).get_atom_energies()}) + tmp_step["atom_energies"] = ACstrExtractor(tmp_chunk).get_atom_energies() else: print(f"Ionic step #{ii} : Energy deposition is turn down.") virial: np.ndarray | None = ACstrExtractor(tmp_chunk).get_virial() if virial is not None: - tmp_step.update({"virial": virial.reshape(3, 3)}) + tmp_step["virial"] = virial.reshape(3, 3) else: print(f"Ionic step #{ii} : No virial information.") ionic_steps.append(tmp_step) @@ -261,9 +262,9 @@ def _parse_kpt(self) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]: kpts_weight[ii] = float(tmp_row_lst[3].strip()) if len(tmp_row_lst) == 5: - hsps.update( - {tmp_row_lst[4]: np.array([float(tmp_row_lst[0]), float(tmp_row_lst[1]), float(tmp_row_lst[2])])} - ) + hsps |= { + tmp_row_lst[4]: np.array([float(tmp_row_lst[0]), float(tmp_row_lst[1]), float(tmp_row_lst[2])]) + } return kpts, kpts_weight, hsps @property diff --git a/src/pymatgen/io/pwscf.py b/src/pymatgen/io/pwscf.py index 2af861b9941..abcbf2bb8c2 100644 --- a/src/pymatgen/io/pwscf.py +++ b/src/pymatgen/io/pwscf.py @@ -8,6 +8,7 @@ from monty.io import zopen from monty.re import regrep + from pymatgen.core import Element, Lattice, Structure from pymatgen.util.io_utils import clean_lines @@ -291,7 +292,7 @@ def input_mode(line): if match := re.match(r"(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line): key = match[1].strip() key_ = match[2].strip() - val = match[3].strip() + val = match[3].strip().rstrip(",") if key_ != "": if sections[section].get(key) is None: val_ = [0.0] * 20 # MAX NTYP DEFINITION @@ -305,7 +306,7 @@ def input_mode(line): sections[section][key] = PWInput.proc_val(key, val) elif mode[0] == "pseudo": - if match := re.match(r"(\w+)\s+(\d*.\d*)\s+(.*)", line): + if match := re.match(r"(\w+\d*[\+-]?)\s+(\d*.\d*)\s+(.*)", line): pseudo[match[1].strip()] = match[3].strip() elif mode[0] == "kpoints": @@ -317,7 +318,7 @@ def input_mode(line): elif mode[0] == "structure": m_l = re.match(r"(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) - m_p = re.match(r"(\w+)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) + m_p = re.match(r"(\w+\d*[\+-]?)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) if m_l: lattice += [ float(m_l[1]), diff --git a/src/pymatgen/io/qchem/inputs.py b/src/pymatgen/io/qchem/inputs.py index bb20603c8bd..7776940c27d 100644 --- a/src/pymatgen/io/qchem/inputs.py +++ b/src/pymatgen/io/qchem/inputs.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from monty.io import zopen + from pymatgen.core import Molecule from pymatgen.io.core import InputFile @@ -291,7 +292,7 @@ def get_str(self) -> str: def multi_job_string(job_list: list[QCInput]) -> str: """ Args: - job_list (): List of jobs. + job_list (list[QCInput]): List of QChem jobs. Returns: str: String representation of a multi-job input file. @@ -372,8 +373,8 @@ def write_multi_job_file(job_list: list[QCInput], filename: str): """Write a multijob file. Args: - job_list (): List of jobs. - filename (): Filename + job_list (list[QCInput]): List of QChem jobs. + filename (str): Name of the file to write. """ with zopen(filename, mode="wt") as file: file.write(QCInput.multi_job_string(job_list)) @@ -451,10 +452,10 @@ def molecule_template(molecule: Molecule | list[Molecule] | Literal["read"]) -> return "\n".join(mol_list) @staticmethod - def rem_template(rem: dict) -> str: + def rem_template(rem: dict[str, Any]) -> str: """ Args: - rem (): + rem (dict[str, Any]): REM section. Returns: str: REM template. @@ -472,7 +473,7 @@ def opt_template(opt: dict[str, list]) -> str: Optimization template. Args: - opt (): + opt (dict[str, list]): Optimization section. Returns: str: Optimization template. @@ -497,7 +498,7 @@ def pcm_template(pcm: dict) -> str: PCM run template. Args: - pcm (): + pcm (dict): PCM section. Returns: str: PCM template. @@ -514,7 +515,7 @@ def solvent_template(solvent: dict) -> str: """Solvent template. Args: - solvent (): + solvent (dict): Solvent section. Returns: str: Solvent section. @@ -530,7 +531,7 @@ def solvent_template(solvent: dict) -> str: def smx_template(smx: dict) -> str: """ Args: - smx (): + smx (dict): Solvation model with short-range corrections. Returns: str: Solvation model with short-range corrections. @@ -603,7 +604,7 @@ def van_der_waals_template(radii: dict[str, float], mode: str = "atomic") -> str def plots_template(plots: dict) -> str: """ Args: - plots (): + plots (dict): Plots section. Returns: str: Plots section. @@ -618,7 +619,7 @@ def plots_template(plots: dict) -> str: def nbo_template(nbo: dict) -> str: """ Args: - nbo (): + nbo (dict): NBO section. Returns: str: NBO section. @@ -654,7 +655,7 @@ def svp_template(svp: dict) -> str: def geom_opt_template(geom_opt: dict) -> str: """ Args: - geom_opt (): + geom_opt (dict): Geometry optimization section. Returns: str: Geometry optimization section. diff --git a/src/pymatgen/io/qchem/outputs.py b/src/pymatgen/io/qchem/outputs.py index 6db29b910ac..c2cccd2ddc9 100644 --- a/src/pymatgen/io/qchem/outputs.py +++ b/src/pymatgen/io/qchem/outputs.py @@ -16,6 +16,7 @@ import pandas as pd from monty.io import zopen from monty.json import MSONable, jsanitize + from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.local_env import OpenBabelNN from pymatgen.core import Molecule diff --git a/src/pymatgen/io/qchem/sets.py b/src/pymatgen/io/qchem/sets.py index 13e77d3118d..d85cfbad903 100644 --- a/src/pymatgen/io/qchem/sets.py +++ b/src/pymatgen/io/qchem/sets.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING from monty.io import zopen + from pymatgen.io.qchem.inputs import QCInput from pymatgen.io.qchem.utils import lower_and_check_unique diff --git a/src/pymatgen/io/res.py b/src/pymatgen/io/res.py index 6218392ed1e..e884ee13c98 100644 --- a/src/pymatgen/io/res.py +++ b/src/pymatgen/io/res.py @@ -10,26 +10,27 @@ from __future__ import annotations -import datetime import re from dataclasses import dataclass +from datetime import date, datetime, timezone from typing import TYPE_CHECKING from monty.io import zopen from monty.json import MSONable + from pymatgen.core import Element, Lattice, PeriodicSite, Structure from pymatgen.entries.computed_entries import ComputedStructureEntry from pymatgen.io.core import ParseError if TYPE_CHECKING: from collections.abc import Iterator - from datetime import date from pathlib import Path from typing import Any, Callable, Literal - from pymatgen.util.typing import Tuple3Ints, Vector3D from typing_extensions import Self + from pymatgen.util.typing import Tuple3Ints, Vector3D + @dataclass(frozen=True) class AirssTITL: @@ -419,9 +420,9 @@ def _parse_date(cls, string: str) -> date: raise ResParseError(f"Could not parse the date from {string=}.") day, month, year, *_ = match.groups() - month_num = datetime.datetime.strptime(month, "%b").replace(tzinfo=datetime.timezone.utc).month + month_num = datetime.strptime(month, "%b").replace(tzinfo=timezone.utc).month - return datetime.date(int(year), month_num, int(day)) + return date(int(year), month_num, int(day)) def _raise_or_none(self, err: ResParseError) -> None: if self.parse_rems != "strict": diff --git a/src/pymatgen/io/shengbte.py b/src/pymatgen/io/shengbte.py index 7ccdee00c23..bd8009d93cb 100644 --- a/src/pymatgen/io/shengbte.py +++ b/src/pymatgen/io/shengbte.py @@ -8,6 +8,7 @@ import numpy as np from monty.dev import requires from monty.json import MSONable + from pymatgen.core.structure import Structure from pymatgen.io.vasp import Kpoints diff --git a/src/pymatgen/io/template.py b/src/pymatgen/io/template.py index 6103f09d105..2ee08031ede 100644 --- a/src/pymatgen/io/template.py +++ b/src/pymatgen/io/template.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING from monty.io import zopen + from pymatgen.io.core import InputGenerator, InputSet if TYPE_CHECKING: diff --git a/src/pymatgen/io/vasp/incar_parameters.json b/src/pymatgen/io/vasp/incar_parameters.json index c662e725645..d9a056ba508 100644 --- a/src/pymatgen/io/vasp/incar_parameters.json +++ b/src/pymatgen/io/vasp/incar_parameters.json @@ -650,7 +650,7 @@ "type": "bool" }, "LREAL": { - "type": "Union[bool, str]", + "type": "(bool, str)", "values": [ false, true, diff --git a/src/pymatgen/io/vasp/inputs.py b/src/pymatgen/io/vasp/inputs.py index 8f5e601b58c..5ae331e11af 100644 --- a/src/pymatgen/io/vasp/inputs.py +++ b/src/pymatgen/io/vasp/inputs.py @@ -31,21 +31,23 @@ from monty.os import cd from monty.os.path import zpath from monty.serialization import dumpfn, loadfn +from tabulate import tabulate + from pymatgen.core import SETTINGS, Element, Lattice, Structure, get_el_sp from pymatgen.electronic_structure.core import Magmom from pymatgen.util.io_utils import clean_lines from pymatgen.util.string import str_delimited from pymatgen.util.typing import Kpoint, Tuple3Floats, Tuple3Ints, Vector3D -from tabulate import tabulate if TYPE_CHECKING: from collections.abc import Iterator from typing import Any, ClassVar, Literal from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.symmetry.bandstructure import HighSymmKpath from pymatgen.util.typing import PathLike - from typing_extensions import Self __author__ = "Shyue Ping Ong, Geoffroy Hautier, Rickard Armiento, Vincent L Chevrier, Stephen Dacek" @@ -1026,10 +1028,10 @@ def check_params(self) -> None: continue # Check value and its type - param_type = incar_params[tag].get("type") - allowed_values = incar_params[tag].get("values") + param_type: str = incar_params[tag].get("type") + allowed_values: list[Any] = incar_params[tag].get("values") - if param_type is not None and type(val).__name__ != param_type: + if param_type is not None and not isinstance(val, eval(param_type)): warnings.warn(f"{tag}: {val} is not a {param_type}", BadIncarWarning, stacklevel=2) # Only check value when it's not None, @@ -2751,7 +2753,7 @@ def __init__( """ super().__init__(**kwargs) self._potcar_filename = "POTCAR" + (".spec" if potcar_spec else "") - self.update({"INCAR": incar, "KPOINTS": kpoints, "POSCAR": poscar, self._potcar_filename: potcar}) + self |= {"INCAR": incar, "KPOINTS": kpoints, "POSCAR": poscar, self._potcar_filename: potcar} if optional_files is not None: self.update(optional_files) diff --git a/src/pymatgen/io/vasp/optics.py b/src/pymatgen/io/vasp/optics.py index bc52f8ec197..9e1e1b66cf7 100644 --- a/src/pymatgen/io/vasp/optics.py +++ b/src/pymatgen/io/vasp/optics.py @@ -10,15 +10,17 @@ import scipy.constants import scipy.special from monty.json import MSONable +from tqdm import tqdm + from pymatgen.electronic_structure.core import Spin from pymatgen.io.vasp.outputs import Vasprun, Waveder -from tqdm import tqdm if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from pymatgen.util.typing import PathLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + __author__ = "Jimmy-Xuan Shen" __copyright__ = "Copyright 2022, The Materials Project" __maintainer__ = "Jimmy-Xuan Shen" diff --git a/src/pymatgen/io/vasp/outputs.py b/src/pymatgen/io/vasp/outputs.py index 9466ca1b7d2..40634973a1f 100644 --- a/src/pymatgen/io/vasp/outputs.py +++ b/src/pymatgen/io/vasp/outputs.py @@ -2,7 +2,6 @@ from __future__ import annotations -import datetime import itertools import logging import math @@ -13,6 +12,7 @@ from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass +from datetime import datetime, timezone from glob import glob from io import StringIO from pathlib import Path @@ -24,6 +24,7 @@ from monty.os.path import zpath from monty.re import regrep from numpy.testing import assert_allclose + from pymatgen.core import Composition, Element, Lattice, Structure from pymatgen.core.trajectory import Trajectory from pymatgen.core.units import unitized @@ -50,9 +51,10 @@ from xml.etree.ElementTree import Element as XML_Element from numpy.typing import NDArray - from pymatgen.util.typing import PathLike from typing_extensions import Self + from pymatgen.util.typing import PathLike + logger = logging.getLogger(__name__) @@ -279,7 +281,7 @@ def __init__( eigenvalues and magnetization. Defaults to False. Set to True to obtain projected eigenvalues and magnetization. **Note that this can take an extreme amount of time and memory.** So use this wisely. - parse_potcar_file (PathLike/bool): Whether to parse the potcar file to read + parse_potcar_file (bool | PathLike): Whether to parse the potcar file to read the potcar hashes for the potcar_spec attribute. Defaults to True, where no hashes will be determined and the potcar_spec dictionaries will read {"symbol": ElSymbol, "hash": None}. By Default, looks in @@ -843,7 +845,7 @@ def get_computed_entry( ComputedStructureEntry/ComputedEntry """ if entry_id is None: - entry_id = f"vasprun-{datetime.datetime.now(tz=datetime.timezone.utc)}" + entry_id = f"vasprun-{datetime.now(tz=timezone.utc)}" param_names = { "is_hubbard", "hubbards", @@ -897,9 +899,8 @@ def get_band_structure( the band structure data. Set this flag to ignore it. (Default: False) Returns: - a BandStructure object (or more specifically a - BandStructureSymmLine object if the run is detected to be a run - along symmetry lines) + BandStructure (or more specifically a BandStructureSymmLine object if the run + is detected to be a run along symmetry lines) Two types of runs along symmetry lines are accepted: non-sc with Line-Mode in the KPOINT file or hybrid, self-consistent with a @@ -1031,7 +1032,7 @@ def get_band_structure( eigenvals, lattice_new, e_fermi, - labels_dict, + labels_dict, # type: ignore[arg-type] structure=self.final_structure, projections=p_eig_vals, ) @@ -1557,7 +1558,7 @@ def _parse_ionic_step(self, elem: XML_Element) -> dict[str, float]: def _parse_dos(elem: XML_Element) -> tuple[Dos, Dos, list[dict]]: """Parse density of states (DOS).""" efermi = float(elem.find("i").text) # type: ignore[union-attr, arg-type] - energies = None + energies: NDArray | None = None tdensities = {} idensities = {} @@ -1586,6 +1587,8 @@ def _parse_dos(elem: XML_Element) -> tuple[Dos, Dos, list[dict]]: pdos[orb][spin] = data[:, col_idx] # type: ignore[index] pdoss.append(pdos) elem.clear() + + assert energies is not None return Dos(efermi, energies, tdensities), Dos(efermi, energies, idensities), pdoss @staticmethod @@ -1889,7 +1892,7 @@ class Outcar: final_energy (float): Final energy after extrapolation of sigma back to 0, i.e. energy(sigma->0). final_energy_wo_entrp (float): Final energy before extrapolation of sigma, i.e. energy without entropy. final_fr_energy (float): Final "free energy", i.e. free energy TOTEN. - has_onsite_density_matrices (bool): Boolean for if onsite density matrices have been set. + has_onsite_density_matrices (bool): Whether onsite density matrices have been set. lcalcpol (bool): If LCALCPOL has been set. lepsilon (bool): If LEPSILON has been set. nelect (float): Returns the number of electrons in the calculation. diff --git a/src/pymatgen/io/vasp/sets.py b/src/pymatgen/io/vasp/sets.py index ab97bf85037..8defbfddcb5 100644 --- a/src/pymatgen/io/vasp/sets.py +++ b/src/pymatgen/io/vasp/sets.py @@ -45,6 +45,7 @@ from monty.dev import deprecated from monty.json import MSONable from monty.serialization import loadfn + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Element, PeriodicSite, SiteCollection, Species, Structure from pymatgen.io.core import InputGenerator @@ -58,9 +59,10 @@ if TYPE_CHECKING: from typing import Callable, Literal, Union - from pymatgen.util.typing import PathLike, Tuple3Ints, Vector3D from typing_extensions import Self + from pymatgen.util.typing import PathLike, Tuple3Ints, Vector3D + UserPotcarFunctional = Union[ Literal["PBE", "PBE_52", "PBE_54", "LDA", "LDA_52", "LDA_54", "PW91", "LDA_US", "PW91_US"], None ] @@ -146,7 +148,7 @@ class VaspInputSet(InputGenerator, abc.ABC): force_gamma (bool): Force gamma centered kpoint generation. Default (False) is to use the Automatic Density kpoint scheme, which will use the Gamma centered generation scheme for hexagonal cells, and Monkhorst-Pack otherwise. - reduce_structure (None/str): Before generating the input files, generate the + reduce_structure (str | None): Before generating the input files, generate the reduced structure. Default (None), does not alter the structure. Valid values: None, "niggli", "LLL". vdw: Adds default parameters for van-der-Waals functionals supported by VASP to @@ -1488,7 +1490,7 @@ def __post_init__(self) -> None: ) if self.xc_functional.upper() == "R2SCAN": - self._config_dict["INCAR"].update({"METAGGA": "R2SCAN", "ALGO": "ALL", "GGA": None}) + self._config_dict["INCAR"] |= {"METAGGA": "R2SCAN", "ALGO": "ALL", "GGA": None} if self.xc_functional.upper().endswith("+U"): self._config_dict["INCAR"]["LDAU"] = True @@ -2885,9 +2887,8 @@ def batch_write_input( Defaults to True. subfolder (callable): Function to create subdirectory name from structure. Defaults to simply "formula_count". - sanitize (bool): Boolean indicating whether to sanitize the - structure before writing the VASP input files. Sanitized output - are generally easier for viewing and certain forms of analysis. + sanitize (bool): Whether to sanitize the structure before writing the VASP input files. + Sanitized output are generally easier for viewing and certain forms of analysis. Defaults to False. include_cif (bool): Whether to output a CIF as well. CIF files are generally better supported in visualization programs. diff --git a/src/pymatgen/io/wannier90.py b/src/pymatgen/io/wannier90.py index 566c14a2679..fe417f586de 100644 --- a/src/pymatgen/io/wannier90.py +++ b/src/pymatgen/io/wannier90.py @@ -30,7 +30,7 @@ class Unk: data (numpy.ndarray): Numpy array that contains the wavefunction data in the UNK file. The shape should be (nbnd, ngx, ngy, ngz) for regular calculations and (nbnd, 2, ngx, ngy, ngz) for noncollinear calculations. - is_noncollinear (bool): Boolean that specifies if data is from a noncollinear calculation. + is_noncollinear (bool): True if data is from a noncollinear calculation. nbnd (int): Number of bands in data. ng (tuple): Sequence of three integers that correspond to the grid size of the given data. The definition is ng = (ngx, ngy, ngz). diff --git a/src/pymatgen/io/xr.py b/src/pymatgen/io/xr.py index 3fb644c11bd..4ad17556412 100644 --- a/src/pymatgen/io/xr.py +++ b/src/pymatgen/io/xr.py @@ -15,6 +15,7 @@ import numpy as np from monty.io import zopen + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure @@ -37,7 +38,7 @@ class Xr: def __init__(self, structure: Structure): """ Args: - structure (Structure/IStructure): Structure object to create the Xr object. + structure (Structure | IStructure): Structure object to create the Xr object. """ if not structure.is_ordered: raise ValueError("Xr file can only be constructed from ordered structure") diff --git a/src/pymatgen/io/xtb/inputs.py b/src/pymatgen/io/xtb/inputs.py index cbc6a413dbe..b5bcca3b4b6 100644 --- a/src/pymatgen/io/xtb/inputs.py +++ b/src/pymatgen/io/xtb/inputs.py @@ -23,7 +23,7 @@ class CRESTInput(MSONable): """ - An object representing CREST input files. + An object representing CREST input files. Because CREST is controlled through command line flags and external files, the CRESTInput class mainly consists of methods for containing and writing external files. diff --git a/src/pymatgen/io/xtb/outputs.py b/src/pymatgen/io/xtb/outputs.py index 60ca8c52834..4751c251a8f 100644 --- a/src/pymatgen/io/xtb/outputs.py +++ b/src/pymatgen/io/xtb/outputs.py @@ -7,6 +7,7 @@ import re from monty.json import MSONable + from pymatgen.core import Molecule from pymatgen.io.xyz import XYZ diff --git a/src/pymatgen/io/xyz.py b/src/pymatgen/io/xyz.py index 28914a47010..81430321a28 100644 --- a/src/pymatgen/io/xyz.py +++ b/src/pymatgen/io/xyz.py @@ -8,6 +8,7 @@ import pandas as pd from monty.io import zopen + from pymatgen.core import Molecule, Structure from pymatgen.core.structure import SiteCollection diff --git a/src/pymatgen/io/zeopp.py b/src/pymatgen/io/zeopp.py index 5ebee530013..fa8299a2925 100644 --- a/src/pymatgen/io/zeopp.py +++ b/src/pymatgen/io/zeopp.py @@ -31,6 +31,7 @@ from monty.dev import requires from monty.io import zopen from monty.tempfile import ScratchDir + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Molecule, Structure from pymatgen.io.cssr import Cssr diff --git a/src/pymatgen/phonon/bandstructure.py b/src/pymatgen/phonon/bandstructure.py index 17cb899b948..f8354fdb743 100644 --- a/src/pymatgen/phonon/bandstructure.py +++ b/src/pymatgen/phonon/bandstructure.py @@ -7,6 +7,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.electronic_structure.bandstructure import Kpoint @@ -17,9 +18,10 @@ from typing import Any from numpy.typing import ArrayLike - from pymatgen.util.typing import Tuple3Ints from typing_extensions import Self + from pymatgen.util.typing import Tuple3Ints + def get_reasonable_repetitions(n_atoms: int) -> Tuple3Ints: """Choose the number of repetitions in a supercell @@ -252,7 +254,7 @@ def asr_breaking(self, tol_eigendisplacements: float = 1e-5) -> np.ndarray | Non """Get the breaking of the acoustic sum rule for the three acoustic modes, if Gamma is present. None otherwise. If eigendisplacements are available they are used to determine the acoustic - modes: selects the bands corresponding to the eigendisplacements that + modes: selects the bands corresponding to the eigendisplacements that represent to a translation within tol_eigendisplacements. If these are not identified or eigendisplacements are missing the first 3 modes will be used (indices [:3]). diff --git a/src/pymatgen/phonon/dos.py b/src/pymatgen/phonon/dos.py index 0e3e0ea2296..e4e26c192d4 100644 --- a/src/pymatgen/phonon/dos.py +++ b/src/pymatgen/phonon/dos.py @@ -8,9 +8,10 @@ import scipy.constants as const from monty.functools import lazy_property from monty.json import MSONable +from scipy.ndimage import gaussian_filter1d + from pymatgen.core.structure import Structure from pymatgen.util.coord import get_linear_interpolated_value -from scipy.ndimage import gaussian_filter1d if TYPE_CHECKING: from collections.abc import Sequence diff --git a/src/pymatgen/phonon/gruneisen.py b/src/pymatgen/phonon/gruneisen.py index 67517b4faa7..a4aee2cc443 100644 --- a/src/pymatgen/phonon/gruneisen.py +++ b/src/pymatgen/phonon/gruneisen.py @@ -8,12 +8,13 @@ import scipy.constants as const from monty.dev import requires from monty.json import MSONable +from scipy.interpolate import UnivariateSpline + from pymatgen.core import Structure from pymatgen.core.lattice import Lattice from pymatgen.core.units import amu_to_kg from pymatgen.phonon.bandstructure import PhononBandStructure, PhononBandStructureSymmLine from pymatgen.phonon.dos import PhononDos -from scipy.interpolate import UnivariateSpline try: import phonopy diff --git a/src/pymatgen/phonon/ir_spectra.py b/src/pymatgen/phonon/ir_spectra.py index 9626e6c5726..462f4fa94b9 100644 --- a/src/pymatgen/phonon/ir_spectra.py +++ b/src/pymatgen/phonon/ir_spectra.py @@ -11,6 +11,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.core.spectrum import Spectrum from pymatgen.core.structure import Structure from pymatgen.util.plotting import add_fig_kwargs diff --git a/src/pymatgen/phonon/plotter.py b/src/pymatgen/phonon/plotter.py index d5dd63bc2ec..64f58499212 100644 --- a/src/pymatgen/phonon/plotter.py +++ b/src/pymatgen/phonon/plotter.py @@ -12,6 +12,7 @@ from matplotlib.collections import LineCollection from matplotlib.colors import LinearSegmentedColormap from monty.json import jsanitize + from pymatgen.electronic_structure.plotter import BSDOSPlotter, plot_brillouin_zone from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.gruneisen import GruneisenPhononBandStructureSymmLine @@ -24,6 +25,7 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure + from pymatgen.core import Structure from pymatgen.phonon.dos import PhononDos from pymatgen.phonon.gruneisen import GruneisenParameter @@ -653,7 +655,7 @@ def plot_compare( **kwargs: passed to ax.plot(). Returns: - a matplotlib object with both band structures + plt.Axes: with two band structures. """ unit = freq_units(units) legend_kwargs = legend_kwargs or {} @@ -1184,7 +1186,7 @@ def plot_compare_gs(self, other_plotter: GruneisenPhononBSPlotter) -> Axes: ValueError: if the two plotters are incompatible (due to different data lengths) Returns: - a matplotlib object with both band structures + plt.Axes: with both band structures """ data_orig = self.bs_plot_data() data = other_plotter.bs_plot_data() diff --git a/src/pymatgen/phonon/thermal_displacements.py b/src/pymatgen/phonon/thermal_displacements.py index 818b5f6f911..b021cbc57b3 100644 --- a/src/pymatgen/phonon/thermal_displacements.py +++ b/src/pymatgen/phonon/thermal_displacements.py @@ -8,6 +8,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core.structure import Structure from pymatgen.io.cif import CifFile, CifParser, CifWriter, str2float diff --git a/src/pymatgen/symmetry/analyzer.py b/src/pymatgen/symmetry/analyzer.py index f404aa2170b..5a06a9c4b48 100644 --- a/src/pymatgen/symmetry/analyzer.py +++ b/src/pymatgen/symmetry/analyzer.py @@ -25,6 +25,7 @@ import numpy as np import scipy.cluster import spglib + from pymatgen.core.lattice import Lattice from pymatgen.core.operations import SymmOp from pymatgen.core.structure import Molecule, PeriodicSite, Structure @@ -36,11 +37,12 @@ from typing import Any, Literal from numpy.typing import NDArray + from spglib import SpglibDataset + from pymatgen.core import Element, Species from pymatgen.core.sites import Site from pymatgen.symmetry.groups import CrystalSystem from pymatgen.util.typing import Kpoint - from spglib import SpglibDataset LatticeType = Literal["cubic", "hexagonal", "monoclinic", "orthorhombic", "rhombohedral", "tetragonal", "triclinic"] @@ -85,7 +87,7 @@ def __init__( ) -> None: """ Args: - structure (Structure/IStructure): Structure to find symmetry + structure (Structure | IStructure): Structure to find symmetry symprec (float): Tolerance for symmetry finding. Defaults to 0.01, which is fairly strict and works well for properly refined structures with atoms in the proper symmetry coordinates. For @@ -275,10 +277,9 @@ def _get_symmetry(self) -> tuple[NDArray, NDArray]: # [1e-4, 2e-4, 1e-4] # (these are in fractional coordinates, so should be small denominator # fractions) - _translations: list = [] - for trans in dct["translations"]: - _translations.append([float(Fraction(c).limit_denominator(1000)) for c in trans]) - translations: NDArray = np.array(_translations) + translations: NDArray = np.array( + [[float(Fraction(c).limit_denominator(1000)) for c in trans] for trans in dct["translations"]] + ) # Fractional translations of 1 are more simply 0 translations[np.abs(translations) == 1] = 0 @@ -1345,7 +1346,7 @@ def is_valid_op(self, symm_op: SymmOp) -> bool: symm_op (SymmOp): Symmetry operation to test. Returns: - bool: Whether SymmOp is valid for Molecule. + bool: True if SymmOp is valid for Molecule. """ coords = self.centered_mol.cart_coords for site in self.centered_mol: @@ -1673,7 +1674,7 @@ def are_symmetrically_equivalent( are symmetrically similar. Returns: - bool: Whether the two sets of sites are symmetrically equivalent. + bool: True if the two sets of sites are symmetrically equivalent. """ def in_sites(site): diff --git a/src/pymatgen/symmetry/bandstructure.py b/src/pymatgen/symmetry/bandstructure.py index c37009792e1..05d416a0fbe 100644 --- a/src/pymatgen/symmetry/bandstructure.py +++ b/src/pymatgen/symmetry/bandstructure.py @@ -9,6 +9,7 @@ import networkx as nx import numpy as np + from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.core import Spin from pymatgen.symmetry.analyzer import cite_conventional_cell_algo diff --git a/src/pymatgen/symmetry/groups.py b/src/pymatgen/symmetry/groups.py index fab21815d5e..375fd17d5f6 100644 --- a/src/pymatgen/symmetry/groups.py +++ b/src/pymatgen/symmetry/groups.py @@ -17,17 +17,19 @@ import numpy as np from monty.design_patterns import cached_class from monty.serialization import loadfn + from pymatgen.util.string import Stringify if TYPE_CHECKING: from typing import ClassVar, Literal from numpy.typing import ArrayLike + from typing_extensions import Self + from pymatgen.core.lattice import Lattice # Don't import at runtime to avoid circular import from pymatgen.core.operations import SymmOp # noqa: TCH004 - from typing_extensions import Self CrystalSystem = Literal["cubic", "hexagonal", "monoclinic", "orthorhombic", "tetragonal", "triclinic", "trigonal"] @@ -253,7 +255,7 @@ def __init__(self, int_symbol: str, hexagonal: bool = True) -> None: notation is a LaTeX-like string, with screw axes being represented by an underscore. For example, "P6_3/mmc". Alternative settings can be accessed by adding a ":identifier". - For example, the hexagonal setting for rhombohedral cells can be + For example, the hexagonal setting for rhombohedral cells can be accessed by adding a ":H", e.g. "R-3m:H". To find out all possible settings for a spacegroup, use the get_settings() classmethod. Alternative origin choices can be indicated by a diff --git a/src/pymatgen/symmetry/kpath.py b/src/pymatgen/symmetry/kpath.py index a1983f5047d..72a877d16ed 100644 --- a/src/pymatgen/symmetry/kpath.py +++ b/src/pymatgen/symmetry/kpath.py @@ -12,6 +12,7 @@ import numpy as np import spglib from monty.dev import requires + from pymatgen.core.lattice import Lattice from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, cite_conventional_cell_algo @@ -1342,11 +1343,7 @@ def _choose_path( # Choose remaining unconnected key points for k-path. The ones that remain are # those with inversion symmetry. Connect them to gamma. - unconnected = [] - - for idx in range(len(key_points_inds_orbits)): - if idx not in point_orbits_in_path: - unconnected.append(idx) + unconnected = [idx for idx in range(len(key_points_inds_orbits)) if idx not in point_orbits_in_path] for ind in unconnected: connect = False diff --git a/src/pymatgen/symmetry/maggroups.py b/src/pymatgen/symmetry/maggroups.py index 1b7b72ee33a..15713dfe31d 100644 --- a/src/pymatgen/symmetry/maggroups.py +++ b/src/pymatgen/symmetry/maggroups.py @@ -11,6 +11,7 @@ import numpy as np from monty.design_patterns import cached_class + from pymatgen.core.operations import MagSymmOp from pymatgen.electronic_structure.core import Magmom from pymatgen.symmetry.groups import SymmetryGroup, in_array_list @@ -20,9 +21,10 @@ if TYPE_CHECKING: from collections.abc import Sequence - from pymatgen.core.lattice import Lattice from typing_extensions import Self + from pymatgen.core.lattice import Lattice + __author__ = "Matthew Horton, Shyue Ping Ong" MAGSYMM_DATA = os.path.join(os.path.dirname(__file__), "symm_data_magnetic.sqlite") @@ -235,19 +237,15 @@ def _parse_lattice(b): return None raw_lattice = [b[i : i + 4] for i in range(0, len(b), 4)] - lattice = [] - - for r in raw_lattice: - lattice.append( - { - "vector": [r[0] / r[3], r[1] / r[3], r[2] / r[3]], - "str": f"({Fraction(r[0] / r[3]).limit_denominator()}," - f"{Fraction(r[1] / r[3]).limit_denominator()}," - f"{Fraction(r[2] / r[3]).limit_denominator()})+", - } - ) - - return lattice + return [ + { + "vector": [r[0] / r[3], r[1] / r[3], r[2] / r[3]], + "str": f"({Fraction(r[0] / r[3]).limit_denominator()}," + f"{Fraction(r[1] / r[3]).limit_denominator()}," + f"{Fraction(r[2] / r[3]).limit_denominator()})+", + } + for r in raw_lattice + ] def _parse_transformation(b): """Parse compact binary representation into transformation between OG and BNS settings.""" @@ -557,9 +555,7 @@ def _write_all_magnetic_space_groups_to_file(filename): "http://stokes.byu.edu/iso/magnetic_data.txt\n" "Used with kind permission from Professor Branton Campbell, BYU\n\n" ) - all_msgs = [] - for i in range(1, 1652): - all_msgs.append(MagneticSpaceGroup(i)) + all_msgs = list(map(MagneticSpaceGroup, range(1, 1652))) for msg in all_msgs: out += f"\n{msg.data_str()}\n\n--------\n" with open(filename, mode="w") as file: diff --git a/src/pymatgen/symmetry/settings.py b/src/pymatgen/symmetry/settings.py index 99597c1cbe3..3abd7f547f3 100644 --- a/src/pymatgen/symmetry/settings.py +++ b/src/pymatgen/symmetry/settings.py @@ -7,11 +7,12 @@ from typing import TYPE_CHECKING import numpy as np +from sympy import Matrix +from sympy.parsing.sympy_parser import parse_expr + from pymatgen.core.lattice import Lattice from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.util.string import transformation_to_string -from sympy import Matrix -from sympy.parsing.sympy_parser import parse_expr if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/pymatgen/symmetry/site_symmetries.py b/src/pymatgen/symmetry/site_symmetries.py index ed889cb5f72..c73a321270a 100644 --- a/src/pymatgen/symmetry/site_symmetries.py +++ b/src/pymatgen/symmetry/site_symmetries.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import numpy as np + from pymatgen.core.operations import SymmOp from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/symmetry/structure.py b/src/pymatgen/symmetry/structure.py index 093b5f8af39..a4761ea74bd 100644 --- a/src/pymatgen/symmetry/structure.py +++ b/src/pymatgen/symmetry/structure.py @@ -5,15 +5,17 @@ from typing import TYPE_CHECKING import numpy as np -from pymatgen.core.structure import PeriodicSite, Structure from tabulate import tabulate +from pymatgen.core.structure import PeriodicSite, Structure + if TYPE_CHECKING: from collections.abc import Sequence - from pymatgen.symmetry.analyzer import SpacegroupOperations from typing_extensions import Self + from pymatgen.symmetry.analyzer import SpacegroupOperations + class SymmetrizedStructure(Structure): """This class represents a symmetrized structure, i.e. a structure @@ -116,8 +118,7 @@ def __str__(self) -> str: row = [str(idx), site.species_string] row.extend([f"{j:>10.6f}" for j in site.frac_coords]) row.append(self.wyckoff_symbols[idx]) - for key in keys: - row.append(props[key][idx]) + row += [props[key][idx] for key in keys] data.append(row) outs.append(tabulate(data, headers=["#", "SP", "a", "b", "c", "Wyckoff", *keys])) return "\n".join(outs) diff --git a/src/pymatgen/transformations/advanced_transformations.py b/src/pymatgen/transformations/advanced_transformations.py index c6107edc562..d68a10425f7 100644 --- a/src/pymatgen/transformations/advanced_transformations.py +++ b/src/pymatgen/transformations/advanced_transformations.py @@ -16,6 +16,7 @@ from monty.dev import requires from monty.fractions import lcm from monty.json import MSONable + from pymatgen.analysis.adsorption import AdsorbateSiteFinder from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.energy_models import SymmetryModel @@ -819,7 +820,7 @@ def apply_transformation( Structure | list[Structure]: Structure(s) after MagOrderTransformation. """ if not structure.is_ordered: - raise ValueError("Create an ordered approximation of your input structure first.") + raise ValueError("Create an ordered approximation of your input structure first.") # retrieve order parameters order_parameters = [MagOrderParameterConstraint.from_dict(op_dict) for op_dict in self.order_parameter] diff --git a/src/pymatgen/transformations/site_transformations.py b/src/pymatgen/transformations/site_transformations.py index 7ba3151fc90..220769fd0f4 100644 --- a/src/pymatgen/transformations/site_transformations.py +++ b/src/pymatgen/transformations/site_transformations.py @@ -14,6 +14,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.analysis.ewald import EwaldMinimizer, EwaldSummation from pymatgen.analysis.local_env import MinimumDistanceNN from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/transformations/standard_transformations.py b/src/pymatgen/transformations/standard_transformations.py index 616c7de84a7..e27406ec574 100644 --- a/src/pymatgen/transformations/standard_transformations.py +++ b/src/pymatgen/transformations/standard_transformations.py @@ -12,6 +12,7 @@ import numpy as np from numpy import around + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.analysis.elasticity.strain import Deformation from pymatgen.analysis.ewald import EwaldMinimizer, EwaldSummation @@ -24,9 +25,10 @@ from pymatgen.transformations.transformation_abc import AbstractTransformation if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core.sites import PeriodicSite from pymatgen.util.typing import SpeciesLike - from typing_extensions import Self logger = logging.getLogger(__name__) @@ -424,7 +426,7 @@ class OrderDisorderedStructureTransformation(AbstractTransformation): these will be treated separately if the difference is above a threshold tolerance. currently this is .1 - For example, if a fraction of .25 Li is on sites 0, 1, 2, 3 and .5 on sites + For example, if a fraction of .25 Li is on sites 0, 1, 2, 3 and .5 on sites 4, 5, 6, 7 then 1 site from [0, 1, 2, 3] will be filled and 2 sites from [4, 5, 6, 7] will be filled, even though a lower energy combination might be found by putting all lithium in sites [4, 5, 6, 7]. @@ -455,7 +457,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | """For this transformation, the apply_transformation method will return only the ordered structure with the lowest Ewald energy, to be consistent with the method signature of the other transformations. - However, all structures are stored in the all_structures attribute in + However, all structures are stored in the all_structures attribute in the transformation object for easy access. Args: @@ -722,7 +724,7 @@ def inverse(self): class DiscretizeOccupanciesTransformation(AbstractTransformation): - """Discretizes the site occupancies in a disordered structure; useful for + """Discretize the site occupancies in a disordered structure; useful for grouping similar structures or as a pre-processing step for order-disorder transformations. """ @@ -747,14 +749,14 @@ def __init__(self, max_denominator=5, tol: float | None = None, fix_denominator= self.tol = tol if tol is not None else 1 / (4 * max_denominator) self.fix_denominator = fix_denominator - def apply_transformation(self, structure): - """Discretizes the site occupancies in the structure. + def apply_transformation(self, structure) -> Structure: + """Discretize the site occupancies in the structure. Args: structure: disordered Structure to discretize occupancies Returns: - A new disordered Structure with occupancies discretized + Structure: new disordered Structure instance with occupancies discretized """ if structure.is_ordered: return structure diff --git a/src/pymatgen/util/coord.py b/src/pymatgen/util/coord.py index e00dceb5655..eabbc2dadfa 100644 --- a/src/pymatgen/util/coord.py +++ b/src/pymatgen/util/coord.py @@ -11,6 +11,7 @@ import numpy as np from monty.json import MSONable + from pymatgen.util import coord_cython if TYPE_CHECKING: @@ -18,6 +19,7 @@ from typing import Literal from numpy.typing import ArrayLike + from pymatgen.util.typing import PbcLike diff --git a/src/pymatgen/util/due.py b/src/pymatgen/util/due.py index 2f77758a45c..44c77b94611 100644 --- a/src/pymatgen/util/due.py +++ b/src/pymatgen/util/due.py @@ -5,10 +5,10 @@ from .due import due, Doi, BibTeX, Text -See https://github.com/duecredit/duecredit/blob/master/README.md for examples. +See https://github.com/duecredit/duecredit/blob/master/README.md for examples. Origin: Originally a part of the duecredit -Copyright: 2015-2021 DueCredit developers +Copyright: 2015-2021 DueCredit developers License: BSD-2 """ diff --git a/src/pymatgen/util/joblib.py b/src/pymatgen/util/joblib.py new file mode 100644 index 00000000000..043d8c6be5d --- /dev/null +++ b/src/pymatgen/util/joblib.py @@ -0,0 +1,51 @@ +"""This module provides utility functions for getting progress bar with joblib.""" + +from __future__ import annotations + +import contextlib +import os +from typing import TYPE_CHECKING, Any + +import joblib + +if TYPE_CHECKING: + from collections.abc import Iterator + + from tqdm import tqdm + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object: tqdm) -> Iterator[None]: + """Context manager to patch joblib to report into tqdm progress bar given + as argument. + """ + + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __call__(self, *args: tuple, **kwargs: dict[str, Any]) -> None: + """This will be called after each batch, to update the progress bar.""" + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() + + +@contextlib.contextmanager +def set_python_warnings(warnings): + """Context manager to set the PYTHONWARNINGS environment variable to the + given value. This is useful for preventing spam when using parallel processing. + """ + original_warnings = os.environ.get("PYTHONWARNINGS") + os.environ["PYTHONWARNINGS"] = warnings + try: + yield + finally: + if original_warnings is None: + del os.environ["PYTHONWARNINGS"] + else: + os.environ["PYTHONWARNINGS"] = original_warnings diff --git a/src/pymatgen/util/plotting.py b/src/pymatgen/util/plotting.py index b94cba844e0..068d1c2a064 100644 --- a/src/pymatgen/util/plotting.py +++ b/src/pymatgen/util/plotting.py @@ -12,6 +12,7 @@ import numpy as np import palettable.colorbrewer.diverging from matplotlib import cm, colors + from pymatgen.core import Element if TYPE_CHECKING: @@ -92,10 +93,10 @@ def pretty_plot_two_axis( examples. Makes it easier to create plots with different axes. Args: - x (np.ndarray/list): Data for x-axis. - y1 (dict/np.ndarray/list): Data for y1 axis (left). If a dict, it will + x (Sequence[float]): Data for x-axis. + y1 (Sequence[float] | dict[str, Sequence[float]]): Data for y1 axis (left). If a dict, it will be interpreted as a {label: sequence}. - y2 (dict/np.ndarray/list): Data for y2 axis (right). If a dict, it will + y2 (Sequence[float] | dict[str, Sequence[float]]): Data for y2 axis (right). If a dict, it will be interpreted as a {label: sequence}. xlabel (str): If not None, this will be the label for the x-axis. y1label (str): If not None, this will be the label for the y1-axis. @@ -711,7 +712,7 @@ def wrapper(*args, **kwargs): tight_layout True to call fig.tight_layout (default: False) ax_grid True (False) to add (remove) grid from all axes in fig. Default: None i.e. fig is left unchanged. - ax_annotate Add labels to subplots e.g. (a), (b). + ax_annotate Add labels to subplots e.g. (a), (b). Default: False fig_close Close figure. Default: False. ================ ==================================================== diff --git a/src/pymatgen/util/provenance.py b/src/pymatgen/util/provenance.py index 0e72704b1cc..23ff3ce2a5e 100644 --- a/src/pymatgen/util/provenance.py +++ b/src/pymatgen/util/provenance.py @@ -2,14 +2,15 @@ from __future__ import annotations -import datetime import json import re import sys +from datetime import datetime, timezone from io import StringIO from typing import TYPE_CHECKING, NamedTuple from monty.json import MontyDecoder, MontyEncoder + from pymatgen.core.structure import Molecule, Structure try: @@ -41,7 +42,7 @@ def is_valid_bibtex(reference: str) -> bool: reference: A String reference in BibTeX format. Returns: - Boolean indicating if reference is valid bibtex. + bool: True if reference is valid bibtex. """ # str is necessary since pybtex seems to have an issue with unicode. The # filter expression removes all non-ASCII characters. @@ -255,7 +256,7 @@ def __init__( if not all(sys.getsizeof(h) < MAX_HNODE_SIZE for h in history): raise ValueError(f"One or more history nodes exceeds the maximum size limit of {MAX_HNODE_SIZE} bytes") - self.created_at = created_at or datetime.datetime.utcnow() + self.created_at = created_at or f"{datetime.now(tz=timezone.utc):%Y-%m-%d %H:%M:%S.%f%z}" def as_dict(self): """Get MSONable dict.""" diff --git a/src/pymatgen/util/testing/__init__.py b/src/pymatgen/util/testing/__init__.py index 04aea7094fa..75c3a6be9ef 100644 --- a/src/pymatgen/util/testing/__init__.py +++ b/src/pymatgen/util/testing/__init__.py @@ -17,6 +17,7 @@ import pytest from monty.json import MontyDecoder, MontyEncoder, MSONable from monty.serialization import loadfn + from pymatgen.core import ROOT, SETTINGS, Structure if TYPE_CHECKING: diff --git a/src/pymatgen/util/testing/aims.py b/src/pymatgen/util/testing/aims.py index ce5b85525a7..1220974a8b5 100644 --- a/src/pymatgen/util/testing/aims.py +++ b/src/pymatgen/util/testing/aims.py @@ -10,6 +10,7 @@ import numpy as np from monty.io import zopen + from pymatgen.core import Molecule, Structure @@ -83,6 +84,7 @@ def comp_system( generator_cls: type, properties: list[str] | None = None, prev_dir: str | None | Path = None, + user_kpt_settings: dict[str, Any] | None = None, ) -> None: """Compare files generated by tests with ones in reference directories. @@ -95,16 +97,24 @@ def comp_system( generator_cls (type): The class of the generator properties (list[str] | None): The list of properties to calculate prev_dir (str | Path | None): The previous directory to pull outputs from + user_kpt_settings (dict[str, Any] | None): settings for k-point density in FHI-aims Raises: AssertionError: If the input files are not the same """ + if user_kpt_settings is None: + user_kpt_settings = {} + k_point_density = user_params.pop("k_point_density", 20) try: - generator = generator_cls(user_params=user_params, k_point_density=k_point_density) + generator = generator_cls( + user_params=user_params, + k_point_density=k_point_density, + user_kpoints_settings=user_kpt_settings, + ) except TypeError: - generator = generator_cls(user_params=user_params) + generator = generator_cls(user_params=user_params, user_kpoints_settings=user_kpt_settings) input_set = generator.get_input_set(structure, prev_dir, properties) input_set.write_input(work_dir / test_name) @@ -116,7 +126,7 @@ def compare_single_files(ref_file: str | Path, test_file: str | Path) -> None: """Compare single files generated by tests with ones in reference directories. Args: - ref_file (str | Path): The reference file to cmpare against + ref_file (str | Path): The reference file to compare against test_file (str | Path): The file to compare against the reference Raises: diff --git a/src/pymatgen/util/typing.py b/src/pymatgen/util/typing.py index 7f498c5e93f..9455460ba31 100644 --- a/src/pymatgen/util/typing.py +++ b/src/pymatgen/util/typing.py @@ -7,9 +7,12 @@ from collections.abc import Sequence from os import PathLike as OsPathLike -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union + +from numpy.typing import NDArray from pymatgen.core import Composition, DummySpecies, Element, Species +from pymatgen.electronic_structure.core import Magmom, Spin if TYPE_CHECKING: # needed to avoid circular imports from pymatgen.analysis.cost import CostEntry # type: ignore[attr-defined] @@ -25,6 +28,12 @@ PathLike = Union[str, OsPathLike] PbcLike = tuple[bool, bool, bool] +# Things that can be cast to a Spin +SpinLike = Union[Spin, Literal[-1, 1, "up", "down"]] + +# Things that can be cast to a magnetic moment +MagMomentLike = Union[float, Sequence[float], NDArray, Magmom] + # Things that can be cast to a Species-like object using get_el_sp SpeciesLike = Union[str, Element, Species, DummySpecies] diff --git a/src/pymatgen/vis/plotters.py b/src/pymatgen/vis/plotters.py index e77d4c92ccb..5670e24f8ae 100644 --- a/src/pymatgen/vis/plotters.py +++ b/src/pymatgen/vis/plotters.py @@ -5,6 +5,7 @@ import importlib import matplotlib.pyplot as plt + from pymatgen.util.plotting import pretty_plot diff --git a/src/pymatgen/vis/structure_chemview.py b/src/pymatgen/vis/structure_chemview.py index 73069535a91..9007bb9446f 100644 --- a/src/pymatgen/vis/structure_chemview.py +++ b/src/pymatgen/vis/structure_chemview.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np + from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/src/pymatgen/vis/structure_vtk.py b/src/pymatgen/vis/structure_vtk.py index cdb15313a80..2bc4293e2e1 100644 --- a/src/pymatgen/vis/structure_vtk.py +++ b/src/pymatgen/vis/structure_vtk.py @@ -12,6 +12,7 @@ import numpy as np from monty.dev import requires from monty.serialization import loadfn + from pymatgen.core import PeriodicSite, Species, Structure from pymatgen.util.coord import in_coord_list @@ -360,7 +361,7 @@ def add_partial_sphere(self, coords, radius, color, start=0, end=360, opacity=1. Args: coords (nd.array): Coordinates radius (float): Radius of sphere - color (): Color of sphere. + color (tuple): RGB color of sphere start (float): Starting angle. end (float): Ending angle. opacity (float): Opacity. @@ -516,7 +517,7 @@ def add_triangle( color: Color for triangle as RGB. center: The "central atom" of the triangle opacity: opacity of the triangle - draw_edges: If set to True, the a line will be drawn at each edge + draw_edges: If set to True, the a line will be drawn at each edge edges_color: Color of the line for the edges edges_linewidth: Width of the line drawn for the edges """ @@ -568,8 +569,8 @@ def add_faces(self, faces, color, opacity=0.35): Adding face of polygon. Args: - faces (): Coordinates of the faces. - color (): Color. + faces (list): Coordinates of the faces. + color (tuple): RGB color. opacity (float): Opacity """ for face in faces: @@ -628,13 +629,19 @@ def add_faces(self, faces, color, opacity=0.35): else: raise ValueError("Number of points for a face should be >= 3") - def add_edges(self, edges, type="line", linewidth=2, color=(0.0, 0.0, 0.0)): # noqa: A002 + def add_edges( + self, + edges: Sequence[Sequence[Sequence[float]]], + type: str = "line", # noqa: A002 + linewidth: float = 2, + color: tuple[float, float, float] = (0.0, 0.0, 0.0), + ) -> None: """ Args: - edges (): List of edges - type (): placeholder - linewidth (): Width of line - color (nd.array/tuple): RGB color. + edges (Sequence): List of edges. Each edge is a list of two points. + type (str): Type of the edge. Defaults to "line". Unused. + linewidth (float): Width of the line. + color (tuple[float, float, float]): RGB color. """ points = vtk.vtkPoints() lines = vtk.vtkCellArray() @@ -921,7 +928,7 @@ def __init__( bonding determination. Defaults to an empty list. Useful when trying to visualize a certain atom type in the framework (e.g., Li in a Li-ion battery cathode material). - animated_movie_options (): Used for moving. + animated_movie_options (dict): Options for animated movie. """ super().__init__( element_color_mapping=element_color_mapping, @@ -941,12 +948,11 @@ def __init__( self.set_animated_movie_options(animated_movie_options=animated_movie_options) def set_structures(self, structures: Sequence[Structure], tags=None): - """ - Add list of structures to the visualizer. + """Add list of structures to the visualizer. Args: structures (list[Structures]): structures to be visualized. - tags (): List of tags. + tags (list[dict]): List of tags to be applied to the structures. """ self.structures = structures self.istruct = 0 @@ -1053,7 +1059,7 @@ def apply_tags(self): def set_animated_movie_options(self, animated_movie_options=None): """ Args: - animated_movie_options (): animated movie options. + animated_movie_options (dict): Options for animated movie. """ if animated_movie_options is None: self.animated_movie_options = self.DEFAULT_ANIMATED_MOVIE_OPTIONS.copy() diff --git a/tasks.py b/tasks.py index d8b87a9f60d..607c5276143 100644 --- a/tasks.py +++ b/tasks.py @@ -9,17 +9,18 @@ from __future__ import annotations -import datetime import json import os import re import subprocess import webbrowser +from datetime import datetime, timezone from typing import TYPE_CHECKING import requests from invoke import task from monty.os import cd + from pymatgen.core import __version__ if TYPE_CHECKING: @@ -149,7 +150,7 @@ def update_changelog(ctx: Context, version: str | None = None, dry_run: bool = F dry_run (bool, optional): If True, the function will only print the changes without updating the actual change log file. Defaults to False. """ - version = version or f"{datetime.datetime.now(tz=datetime.timezone.utc):%Y.%-m.%-d}" + version = version or f"{datetime.now(tz=timezone.utc):%Y.%-m.%-d}" output = subprocess.check_output(["git", "log", "--pretty=format:%s", f"v{__version__}..HEAD"]) lines = [] ignored_commits = [] @@ -196,7 +197,7 @@ def release(ctx: Context, version: str | None = None, nodoc: bool = False) -> No version (str, optional): The version to release. nodoc (bool, optional): Whether to skip documentation generation. """ - version = version or f"{datetime.datetime.now(tz=datetime.timezone.utc):%Y.%-m.%-d}" + version = version or f"{datetime.now(tz=timezone.utc):%Y.%-m.%-d}" ctx.run("rm -r dist build pymatgen.egg-info", warn=True) set_ver(ctx, version) if not nodoc: diff --git a/tests/alchemy/test_filters.py b/tests/alchemy/test_filters.py index 418213e1dd3..16737ae67fa 100644 --- a/tests/alchemy/test_filters.py +++ b/tests/alchemy/test_filters.py @@ -4,6 +4,7 @@ from unittest import TestCase from monty.json import MontyDecoder + from pymatgen.alchemy.filters import ( ContainsSpecieFilter, RemoveDuplicatesFilter, diff --git a/tests/alchemy/test_materials.py b/tests/alchemy/test_materials.py index 0ec5c800f01..c667b27277e 100644 --- a/tests/alchemy/test_materials.py +++ b/tests/alchemy/test_materials.py @@ -4,6 +4,7 @@ from copy import deepcopy import pytest + from pymatgen.alchemy.filters import ContainsSpecieFilter from pymatgen.alchemy.materials import TransformedStructure from pymatgen.core import SETTINGS diff --git a/tests/analysis/chemenv/connectivity/test_connected_components.py b/tests/analysis/chemenv/connectivity/test_connected_components.py index dc9b01e62ea..2768fc5b815 100644 --- a/tests/analysis/chemenv/connectivity/test_connected_components.py +++ b/tests/analysis/chemenv/connectivity/test_connected_components.py @@ -7,6 +7,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.analysis.chemenv.connectivity.connected_components import ConnectedComponent from pymatgen.analysis.chemenv.connectivity.connectivity_finder import ConnectivityFinder from pymatgen.analysis.chemenv.connectivity.environment_nodes import EnvironmentNode diff --git a/tests/analysis/chemenv/coordination_environments/test_chemenv_strategies.py b/tests/analysis/chemenv/coordination_environments/test_chemenv_strategies.py index ef2109769e5..631aed31b01 100644 --- a/tests/analysis/chemenv/coordination_environments/test_chemenv_strategies.py +++ b/tests/analysis/chemenv/coordination_environments/test_chemenv_strategies.py @@ -1,6 +1,8 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( AdditionalConditionInt, AngleCutoffFloat, @@ -9,7 +11,6 @@ SimplestChemenvStrategy, ) from pymatgen.util.testing import PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/coordination_environments/test_coordination_geometries.py b/tests/analysis/chemenv/coordination_environments/test_coordination_geometries.py index 03a323c65e9..3cbda25a5fe 100644 --- a/tests/analysis/chemenv/coordination_environments/test_coordination_geometries.py +++ b/tests/analysis/chemenv/coordination_environments/test_coordination_geometries.py @@ -3,6 +3,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import ( AllCoordinationGeometries, CoordinationGeometry, @@ -10,7 +12,6 @@ SeparationPlane, ) from pymatgen.util.testing import PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/coordination_environments/test_coordination_geometry_finder.py b/tests/analysis/chemenv/coordination_environments/test_coordination_geometry_finder.py index 2428253b876..fc26a3c33af 100644 --- a/tests/analysis/chemenv/coordination_environments/test_coordination_geometry_finder.py +++ b/tests/analysis/chemenv/coordination_environments/test_coordination_geometry_finder.py @@ -3,6 +3,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import AllCoordinationGeometries from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import ( AbstractGeometry, @@ -11,7 +13,6 @@ ) from pymatgen.core.structure import Lattice, Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/coordination_environments/test_read_write.py b/tests/analysis/chemenv/coordination_environments/test_read_write.py index 0bb61b6163a..42ee8b8cb91 100644 --- a/tests/analysis/chemenv/coordination_environments/test_read_write.py +++ b/tests/analysis/chemenv/coordination_environments/test_read_write.py @@ -3,6 +3,8 @@ import json from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( AngleNbSetWeight, CNBiasNbSetWeight, @@ -21,7 +23,6 @@ from pymatgen.analysis.chemenv.coordination_environments.voronoi import DetailedVoronoiContainer from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/coordination_environments/test_structure_environments.py b/tests/analysis/chemenv/coordination_environments/test_structure_environments.py index 8ce3f92ec4b..b9a8ec160d7 100644 --- a/tests/analysis/chemenv/coordination_environments/test_structure_environments.py +++ b/tests/analysis/chemenv/coordination_environments/test_structure_environments.py @@ -5,6 +5,8 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( MultiWeightsChemenvStrategy, SimplestChemenvStrategy, @@ -16,7 +18,6 @@ ) from pymatgen.core import Species, Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/coordination_environments/test_voronoi.py b/tests/analysis/chemenv/coordination_environments/test_voronoi.py index 2adaa1f343c..cfe7b6af420 100644 --- a/tests/analysis/chemenv/coordination_environments/test_voronoi.py +++ b/tests/analysis/chemenv/coordination_environments/test_voronoi.py @@ -3,6 +3,7 @@ import random import numpy as np + from pymatgen.analysis.chemenv.coordination_environments.voronoi import DetailedVoronoiContainer from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure diff --git a/tests/analysis/chemenv/coordination_environments/test_weights.py b/tests/analysis/chemenv/coordination_environments/test_weights.py index 5295d47c635..f4f117dbe38 100644 --- a/tests/analysis/chemenv/coordination_environments/test_weights.py +++ b/tests/analysis/chemenv/coordination_environments/test_weights.py @@ -3,6 +3,8 @@ import json import pytest +from pytest import approx + from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import ( AngleNbSetWeight, CNBiasNbSetWeight, @@ -15,7 +17,6 @@ ) from pymatgen.analysis.chemenv.coordination_environments.structure_environments import StructureEnvironments from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/utils/test_coordination_geometry_utils.py b/tests/analysis/chemenv/utils/test_coordination_geometry_utils.py index 7124904f7cd..10e8ff8eb2e 100644 --- a/tests/analysis/chemenv/utils/test_coordination_geometry_utils.py +++ b/tests/analysis/chemenv/utils/test_coordination_geometry_utils.py @@ -5,9 +5,10 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.chemenv.utils.coordination_geometry_utils import Plane from pymatgen.util.testing import PymatgenTest -from pytest import approx __author__ = "David Waroquiers" diff --git a/tests/analysis/chemenv/utils/test_func_utils.py b/tests/analysis/chemenv/utils/test_func_utils.py index e01c3d66cc8..265c26d39f4 100644 --- a/tests/analysis/chemenv/utils/test_func_utils.py +++ b/tests/analysis/chemenv/utils/test_func_utils.py @@ -2,12 +2,13 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.chemenv.utils.func_utils import ( CSMFiniteRatioFunction, CSMInfiniteRatioFunction, DeltaCSMRatioFunction, ) -from pytest import approx __author__ = "waroquiers" diff --git a/tests/analysis/chemenv/utils/test_graph_utils.py b/tests/analysis/chemenv/utils/test_graph_utils.py index cf7ef01fa30..54ca8e9ef08 100644 --- a/tests/analysis/chemenv/utils/test_graph_utils.py +++ b/tests/analysis/chemenv/utils/test_graph_utils.py @@ -2,6 +2,7 @@ import pytest from numpy.testing import assert_allclose + from pymatgen.analysis.chemenv.connectivity.environment_nodes import EnvironmentNode from pymatgen.analysis.chemenv.utils.graph_utils import MultiGraphCycle, SimpleGraphCycle, get_delta from pymatgen.util.testing import PymatgenTest diff --git a/tests/analysis/chemenv/utils/test_math_utils.py b/tests/analysis/chemenv/utils/test_math_utils.py index 3b0cab7d229..87420c9c61e 100644 --- a/tests/analysis/chemenv/utils/test_math_utils.py +++ b/tests/analysis/chemenv/utils/test_math_utils.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.analysis.chemenv.utils.math_utils import ( _cartesian_product, cosinus_step, diff --git a/tests/analysis/diffraction/test_neutron.py b/tests/analysis/diffraction/test_neutron.py index 5256710bae1..f9f54d9d4fa 100644 --- a/tests/analysis/diffraction/test_neutron.py +++ b/tests/analysis/diffraction/test_neutron.py @@ -1,11 +1,12 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.analysis.diffraction.neutron import NDCalculator from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.util.testing import PymatgenTest -from pytest import approx """ These calculated values were verified with VESTA and FullProf. diff --git a/tests/analysis/diffraction/test_tem.py b/tests/analysis/diffraction/test_tem.py index edbe77256ca..7b5a79cea67 100644 --- a/tests/analysis/diffraction/test_tem.py +++ b/tests/analysis/diffraction/test_tem.py @@ -6,11 +6,12 @@ import pandas as pd import plotly.graph_objects as go from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.diffraction.tem import TEMCalculator from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.util.testing import PymatgenTest -from pytest import approx __author__ = "Frank Wan, Jason Liang" __copyright__ = "Copyright 2019, The Materials Project" diff --git a/tests/analysis/diffraction/test_xrd.py b/tests/analysis/diffraction/test_xrd.py index ed4f9bcb2d2..9663d6fbf7d 100644 --- a/tests/analysis/diffraction/test_xrd.py +++ b/tests/analysis/diffraction/test_xrd.py @@ -1,11 +1,12 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.analysis.diffraction.xrd import XRDCalculator from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.util.testing import PymatgenTest -from pytest import approx """ TODO: Modify unittest doc. diff --git a/tests/analysis/elasticity/test_elastic.py b/tests/analysis/elasticity/test_elastic.py index 85de3ac4728..4fe8ecdc5da 100644 --- a/tests/analysis/elasticity/test_elastic.py +++ b/tests/analysis/elasticity/test_elastic.py @@ -8,6 +8,9 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx +from scipy.misc import central_diff_weights + from pymatgen.analysis.elasticity.elastic import ( ComplianceTensor, ElasticTensor, @@ -26,8 +29,6 @@ from pymatgen.core.tensors import Tensor from pymatgen.core.units import FloatWithUnit from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx -from scipy.misc import central_diff_weights TEST_DIR = f"{TEST_FILES_DIR}/analysis/elasticity" diff --git a/tests/analysis/elasticity/test_strain.py b/tests/analysis/elasticity/test_strain.py index 2ffadf975e5..e64de89f18d 100644 --- a/tests/analysis/elasticity/test_strain.py +++ b/tests/analysis/elasticity/test_strain.py @@ -3,6 +3,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.analysis.elasticity.strain import Deformation, DeformedStructureSet, Strain, convert_strain_to_deformation from pymatgen.core.structure import Structure from pymatgen.core.tensors import Tensor diff --git a/tests/analysis/elasticity/test_stress.py b/tests/analysis/elasticity/test_stress.py index 66a656243ed..cc79322de83 100644 --- a/tests/analysis/elasticity/test_stress.py +++ b/tests/analysis/elasticity/test_stress.py @@ -3,10 +3,11 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.elasticity.strain import Deformation from pymatgen.analysis.elasticity.stress import Stress from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestStress(PymatgenTest): diff --git a/tests/analysis/ferroelectricity/test_polarization.py b/tests/analysis/ferroelectricity/test_polarization.py index 67f90933990..8c5de6645fc 100644 --- a/tests/analysis/ferroelectricity/test_polarization.py +++ b/tests/analysis/ferroelectricity/test_polarization.py @@ -2,6 +2,8 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.ferroelectricity.polarization import ( EnergyTrend, Polarization, @@ -12,7 +14,6 @@ from pymatgen.io.vasp.inputs import Potcar from pymatgen.io.vasp.outputs import Outcar from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/vasp/fixtures/BTO_221_99_polarization" bto_folders = ["nonpolar_polarization"] diff --git a/tests/analysis/interfaces/test_coherent_interface.py b/tests/analysis/interfaces/test_coherent_interface.py index 1284bc8c731..081cc3c4b62 100644 --- a/tests/analysis/interfaces/test_coherent_interface.py +++ b/tests/analysis/interfaces/test_coherent_interface.py @@ -1,6 +1,7 @@ from __future__ import annotations from numpy.testing import assert_allclose + from pymatgen.analysis.interfaces.coherent_interfaces import ( CoherentInterfaceBuilder, from_2d_to_3d, diff --git a/tests/analysis/interfaces/test_substrate_analyzer.py b/tests/analysis/interfaces/test_substrate_analyzer.py index 39e75ea53f5..a842cd4dad7 100644 --- a/tests/analysis/interfaces/test_substrate_analyzer.py +++ b/tests/analysis/interfaces/test_substrate_analyzer.py @@ -1,6 +1,7 @@ from __future__ import annotations from numpy.testing import assert_allclose + from pymatgen.analysis.elasticity.elastic import ElasticTensor from pymatgen.analysis.interfaces.substrate_analyzer import SubstrateAnalyzer from pymatgen.symmetry.analyzer import SpacegroupAnalyzer diff --git a/tests/analysis/interfaces/test_zsl.py b/tests/analysis/interfaces/test_zsl.py index d4fc0bdef2a..3cc7732eb73 100644 --- a/tests/analysis/interfaces/test_zsl.py +++ b/tests/analysis/interfaces/test_zsl.py @@ -2,6 +2,8 @@ import numpy as np from numpy.testing import assert_array_equal +from pytest import approx + from pymatgen.analysis.interfaces.zsl import ( ZSLGenerator, fast_norm, @@ -12,7 +14,6 @@ ) from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import PymatgenTest -from pytest import approx __author__ = "Shyam Dwaraknath" __copyright__ = "Copyright 2016, The Materials Project" diff --git a/tests/analysis/magnetism/test_analyzer.py b/tests/analysis/magnetism/test_analyzer.py index 12e0f3029a6..bcdd02f1970 100644 --- a/tests/analysis/magnetism/test_analyzer.py +++ b/tests/analysis/magnetism/test_analyzer.py @@ -6,6 +6,8 @@ import pytest from monty.serialization import loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.magnetism import ( CollinearMagneticStructureAnalyzer, MagneticStructureEnumerator, @@ -14,7 +16,6 @@ ) from pymatgen.core import Element, Lattice, Species, Structure from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/magnetic_orderings" diff --git a/tests/analysis/magnetism/test_heisenberg.py b/tests/analysis/magnetism/test_heisenberg.py index 968e28524ce..3eea3bd6782 100644 --- a/tests/analysis/magnetism/test_heisenberg.py +++ b/tests/analysis/magnetism/test_heisenberg.py @@ -3,6 +3,7 @@ from unittest import TestCase import pandas as pd + from pymatgen.analysis.magnetism.heisenberg import HeisenbergMapper from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR diff --git a/tests/analysis/magnetism/test_jahnteller.py b/tests/analysis/magnetism/test_jahnteller.py index 3a5da6a5ef4..6cc3356b4aa 100644 --- a/tests/analysis/magnetism/test_jahnteller.py +++ b/tests/analysis/magnetism/test_jahnteller.py @@ -3,10 +3,11 @@ from unittest import TestCase import numpy as np +from pytest import approx + from pymatgen.analysis.magnetism.jahnteller import JahnTellerAnalyzer, Species from pymatgen.core import Structure from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx class TestJahnTeller(TestCase): diff --git a/tests/analysis/solar/test_slme.py b/tests/analysis/solar/test_slme.py index 906dba63bc7..4d2e12c1ef0 100644 --- a/tests/analysis/solar/test_slme.py +++ b/tests/analysis/solar/test_slme.py @@ -1,8 +1,9 @@ from __future__ import annotations +from pytest import approx + from pymatgen.analysis.solar.slme import optics, slme from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/solar" diff --git a/tests/analysis/structure_prediction/test_dopant_predictor.py b/tests/analysis/structure_prediction/test_dopant_predictor.py index a763421b0bb..f19d032624d 100644 --- a/tests/analysis/structure_prediction/test_dopant_predictor.py +++ b/tests/analysis/structure_prediction/test_dopant_predictor.py @@ -2,13 +2,14 @@ from unittest import TestCase +from pytest import approx + from pymatgen.analysis.local_env import CrystalNN from pymatgen.analysis.structure_prediction.dopant_predictor import ( get_dopants_from_shannon_radii, get_dopants_from_substitution_probabilities, ) from pymatgen.core import Species, Structure -from pytest import approx class TestDopantPrediction(TestCase): diff --git a/tests/analysis/structure_prediction/test_substitution_probability.py b/tests/analysis/structure_prediction/test_substitution_probability.py index eb9115995a3..e6c5f9786d0 100644 --- a/tests/analysis/structure_prediction/test_substitution_probability.py +++ b/tests/analysis/structure_prediction/test_substitution_probability.py @@ -3,13 +3,14 @@ import json from unittest import TestCase +from pytest import approx + from pymatgen.analysis.structure_prediction.substitution_probability import ( SubstitutionPredictor, SubstitutionProbability, ) from pymatgen.core import Composition, Species from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/struct_predictor" diff --git a/tests/analysis/structure_prediction/test_volume_predictor.py b/tests/analysis/structure_prediction/test_volume_predictor.py index 069a9dbc86f..9f6a150aeb7 100644 --- a/tests/analysis/structure_prediction/test_volume_predictor.py +++ b/tests/analysis/structure_prediction/test_volume_predictor.py @@ -1,10 +1,11 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.analysis.structure_prediction.volume_predictor import DLSVolumePredictor, RLSVolumePredictor from pymatgen.core import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/structure_prediction" diff --git a/tests/analysis/test_adsorption.py b/tests/analysis/test_adsorption.py index 76fdc09a0cd..9f611a2feaa 100644 --- a/tests/analysis/test_adsorption.py +++ b/tests/analysis/test_adsorption.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.analysis.adsorption import AdsorbateSiteFinder, generate_all_slabs, get_rot, reorient_z from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Molecule, Structure diff --git a/tests/analysis/test_bond_dissociation.py b/tests/analysis/test_bond_dissociation.py index 6dd0d7f9af6..b29081413c9 100644 --- a/tests/analysis/test_bond_dissociation.py +++ b/tests/analysis/test_bond_dissociation.py @@ -4,6 +4,7 @@ import pytest from monty.serialization import loadfn + from pymatgen.analysis.bond_dissociation import BondDissociationEnergies from pymatgen.util.testing import TEST_FILES_DIR diff --git a/tests/analysis/test_bond_valence.py b/tests/analysis/test_bond_valence.py index 5a0202e075b..8bfc158a574 100644 --- a/tests/analysis/test_bond_valence.py +++ b/tests/analysis/test_bond_valence.py @@ -1,10 +1,11 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.analysis.bond_valence import BVAnalyzer, calculate_bv_sum, calculate_bv_sum_unordered from pymatgen.core import Composition, Species, Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/bond_valence" diff --git a/tests/analysis/test_chempot_diagram.py b/tests/analysis/test_chempot_diagram.py index 5465e1d050d..648d6a5a164 100644 --- a/tests/analysis/test_chempot_diagram.py +++ b/tests/analysis/test_chempot_diagram.py @@ -2,6 +2,8 @@ import numpy as np from plotly.graph_objects import Figure +from pytest import approx + from pymatgen.analysis.chempot_diagram import ( ChemicalPotentialDiagram, get_2d_orthonormal_vector, @@ -11,7 +13,6 @@ from pymatgen.core.composition import Element from pymatgen.entries.entry_tools import EntrySet from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis" diff --git a/tests/analysis/test_cost.py b/tests/analysis/test_cost.py index 4aa598bc5fa..c20bed2fd54 100644 --- a/tests/analysis/test_cost.py +++ b/tests/analysis/test_cost.py @@ -2,9 +2,10 @@ from unittest import TestCase +from pytest import approx + from pymatgen.analysis.cost import CostAnalyzer, CostDBCSV, CostDBElements from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/cost" diff --git a/tests/analysis/test_dimensionality.py b/tests/analysis/test_dimensionality.py index b83c2cb5748..8b99c6fd5e2 100644 --- a/tests/analysis/test_dimensionality.py +++ b/tests/analysis/test_dimensionality.py @@ -3,6 +3,7 @@ import networkx as nx import pytest from monty.serialization import loadfn + from pymatgen.analysis.dimensionality import ( calculate_dimensionality_of_site, get_dimensionality_cheon, diff --git a/tests/analysis/test_disorder.py b/tests/analysis/test_disorder.py index 2384bc4e296..a4b930096ac 100644 --- a/tests/analysis/test_disorder.py +++ b/tests/analysis/test_disorder.py @@ -1,9 +1,10 @@ from __future__ import annotations +from pytest import approx + from pymatgen.analysis.disorder import get_warren_cowley_parameters from pymatgen.core import Element, Structure from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestOrderParameter(PymatgenTest): diff --git a/tests/analysis/test_energy_models.py b/tests/analysis/test_energy_models.py index a2c4aa40612..36d3024ac75 100644 --- a/tests/analysis/test_energy_models.py +++ b/tests/analysis/test_energy_models.py @@ -1,11 +1,12 @@ from __future__ import annotations +from pytest import approx + from pymatgen.analysis.energy_models import EwaldElectrostaticModel, IsingModel, SymmetryModel from pymatgen.core import Species from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx class TestEwaldElectrostaticModel: diff --git a/tests/analysis/test_eos.py b/tests/analysis/test_eos.py index f31c3723eb7..4c02d040fd4 100644 --- a/tests/analysis/test_eos.py +++ b/tests/analysis/test_eos.py @@ -2,9 +2,10 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.eos import EOS, NumericalEOS from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestEOS(PymatgenTest): diff --git a/tests/analysis/test_ewald.py b/tests/analysis/test_ewald.py index 7ada2e539dc..285152bb8fd 100644 --- a/tests/analysis/test_ewald.py +++ b/tests/analysis/test_ewald.py @@ -4,10 +4,11 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.ewald import EwaldMinimizer, EwaldSummation from pymatgen.core.structure import Structure from pymatgen.util.testing import VASP_IN_DIR -from pytest import approx class TestEwaldSummation(TestCase): diff --git a/tests/analysis/test_fragmenter.py b/tests/analysis/test_fragmenter.py index 26b2840b3ae..611edcf58dd 100644 --- a/tests/analysis/test_fragmenter.py +++ b/tests/analysis/test_fragmenter.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from pymatgen.analysis.fragmenter import Fragmenter from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.local_env import OpenBabelNN diff --git a/tests/analysis/test_functional_groups.py b/tests/analysis/test_functional_groups.py index bc59a7da57d..e53a62960a0 100644 --- a/tests/analysis/test_functional_groups.py +++ b/tests/analysis/test_functional_groups.py @@ -3,6 +3,7 @@ from unittest import TestCase import pytest + from pymatgen.analysis.functional_groups import FunctionalGroupExtractor from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.local_env import OpenBabelNN diff --git a/tests/analysis/test_graphs.py b/tests/analysis/test_graphs.py index 0f89320f02b..a0935af1c9b 100644 --- a/tests/analysis/test_graphs.py +++ b/tests/analysis/test_graphs.py @@ -10,6 +10,8 @@ import networkx.algorithms.isomorphism as iso import pytest from monty.serialization import loadfn +from pytest import approx + from pymatgen.analysis.graphs import MoleculeGraph, MolGraphSplitError, PeriodicSite, StructureGraph from pymatgen.analysis.local_env import ( CovalentBondNN, @@ -23,7 +25,6 @@ from pymatgen.core import Lattice, Molecule, Site, Structure from pymatgen.core.structure import FunctionalGroups from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx try: from openbabel import openbabel @@ -372,7 +373,7 @@ def test_draw(self): bc_square_sg_r.draw_graph_to_file(f"{self.tmp_path}/bc_square_r.pdf", algo="neato", image_labels=False) # ensure PDF files were created - pdfs = {path.split("/") for path in glob(f"{self.tmp_path}/*.pdf")} + pdfs = {path.split("/")[-1] for path in glob(f"{self.tmp_path}/*.pdf")} expected_pdfs = { "bc_square_r_single.pdf", "bc_square_r.pdf", diff --git a/tests/analysis/test_hhi.py b/tests/analysis/test_hhi.py index d1c95587124..105072426fb 100644 --- a/tests/analysis/test_hhi.py +++ b/tests/analysis/test_hhi.py @@ -1,8 +1,9 @@ from __future__ import annotations -from pymatgen.analysis.hhi import HHIModel from pytest import approx +from pymatgen.analysis.hhi import HHIModel + class TestHHIModel: def test_hhi(self): diff --git a/tests/analysis/test_interface_reactions.py b/tests/analysis/test_interface_reactions.py index 0b9665fef4f..3c47a6f9f3c 100644 --- a/tests/analysis/test_interface_reactions.py +++ b/tests/analysis/test_interface_reactions.py @@ -8,12 +8,13 @@ from numpy.testing import assert_allclose from pandas import DataFrame from plotly.graph_objects import Figure +from scipy.spatial import ConvexHull + from pymatgen.analysis.interface_reactions import GrandPotentialInterfacialReactivity, InterfacialReactivity from pymatgen.analysis.phase_diagram import GrandPotentialPhaseDiagram, PhaseDiagram from pymatgen.analysis.reaction_calculator import Reaction from pymatgen.core.composition import Composition, Element from pymatgen.entries.computed_entries import ComputedEntry -from scipy.spatial import ConvexHull class TestInterfaceReaction(TestCase): diff --git a/tests/analysis/test_local_env.py b/tests/analysis/test_local_env.py index 1fbed74da82..62adc327b57 100644 --- a/tests/analysis/test_local_env.py +++ b/tests/analysis/test_local_env.py @@ -7,6 +7,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph from pymatgen.analysis.local_env import ( BrunnerNNReal, @@ -36,7 +38,6 @@ ) from pymatgen.core import Element, Lattice, Molecule, Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/local_env/fragmenter_files" @@ -1201,7 +1202,6 @@ def test_weighted_cn(self): def test_weighted_cn_no_oxid(self): cnn = CrystalNN(weighted_cn=True) - cn_array = [] # fmt: off expected_array = [ 5.8962, 5.8996, 5.8962, 5.8996, 5.7195, 5.7195, 5.7202, 5.7194, 4.0012, 4.0012, @@ -1210,8 +1210,7 @@ def test_weighted_cn_no_oxid(self): ] # fmt: on struct = self.lifepo4.copy().remove_oxidation_states() - for idx in range(len(struct)): - cn_array.append(cnn.get_cn(struct, idx, use_weights=True)) + cn_array = [cnn.get_cn(struct, idx, use_weights=True) for idx in range(len(struct))] assert_allclose(expected_array, cn_array, 2) diff --git a/tests/analysis/test_molecule_matcher.py b/tests/analysis/test_molecule_matcher.py index cca70ae1c2c..485e3c75f2f 100644 --- a/tests/analysis/test_molecule_matcher.py +++ b/tests/analysis/test_molecule_matcher.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.molecule_matcher import ( BruteForceOrderMatcher, GeneticOrderMatcher, @@ -17,7 +19,6 @@ from pymatgen.core.structure import Lattice, Molecule, Structure from pymatgen.io.xyz import XYZ from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx try: from openbabel import openbabel diff --git a/tests/analysis/test_nmr.py b/tests/analysis/test_nmr.py index 86961996a71..14490ec7a0c 100644 --- a/tests/analysis/test_nmr.py +++ b/tests/analysis/test_nmr.py @@ -2,9 +2,10 @@ import numpy as np from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.analysis.nmr import ChemicalShielding, ElectricFieldGradient from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestChemicalShieldingNotation(PymatgenTest): diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 997a7cd0bac..4e1e2dfb3f1 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -13,6 +13,8 @@ import pytest from monty.serialization import dumpfn, loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.phase_diagram import ( CompoundPhaseDiagram, GrandPotentialPhaseDiagram, @@ -31,7 +33,6 @@ from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.entries.entry_tools import EntrySet from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis" diff --git a/tests/analysis/test_piezo.py b/tests/analysis/test_piezo.py index 0bd01459e6f..bf3eb7534c2 100644 --- a/tests/analysis/test_piezo.py +++ b/tests/analysis/test_piezo.py @@ -5,6 +5,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal + from pymatgen.analysis.piezo import PiezoTensor from pymatgen.util.testing import PymatgenTest diff --git a/tests/analysis/test_piezo_sensitivity.py b/tests/analysis/test_piezo_sensitivity.py index 19f2444f028..eafaaddc6b4 100644 --- a/tests/analysis/test_piezo_sensitivity.py +++ b/tests/analysis/test_piezo_sensitivity.py @@ -5,9 +5,10 @@ import pickle import numpy as np -import pymatgen import pytest from numpy.testing import assert_allclose + +import pymatgen from pymatgen.analysis.piezo_sensitivity import ( BornEffectiveCharge, ForceConstantMatrix, diff --git a/tests/analysis/test_pourbaix_diagram.py b/tests/analysis/test_pourbaix_diagram.py index 6ea90a801b3..3119877f33b 100644 --- a/tests/analysis/test_pourbaix_diagram.py +++ b/tests/analysis/test_pourbaix_diagram.py @@ -7,12 +7,13 @@ import matplotlib.pyplot as plt import numpy as np from monty.serialization import dumpfn, loadfn +from pytest import approx + from pymatgen.analysis.pourbaix_diagram import IonEntry, MultiEntry, PourbaixDiagram, PourbaixEntry, PourbaixPlotter from pymatgen.core.composition import Composition from pymatgen.core.ion import Ion from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/pourbaix_diagram" diff --git a/tests/analysis/test_quasi_harmonic_debye_approx.py b/tests/analysis/test_quasi_harmonic_debye_approx.py index fce94e37d44..6c0c5816d2d 100644 --- a/tests/analysis/test_quasi_harmonic_debye_approx.py +++ b/tests/analysis/test_quasi_harmonic_debye_approx.py @@ -4,6 +4,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.analysis.eos import EOS from pymatgen.analysis.quasiharmonic import QuasiHarmonicDebyeApprox from pymatgen.core.structure import Structure @@ -192,7 +193,7 @@ def test_debye_temperature(self): def test_gruneisen_parameter(self): gamma = self.qhda.gruneisen_parameter(0, self.qhda.ev_eos_fit.v0) - assert_allclose(gamma, 2.188302, atol=1e-3) + assert_allclose(gamma, 2.188302, atol=1e-2) def test_thermal_conductivity(self): kappa = self.qhda.thermal_conductivity(self.T, self.opt_vol) diff --git a/tests/analysis/test_quasirrho.py b/tests/analysis/test_quasirrho.py index 37829e37a75..63cd0015ad4 100644 --- a/tests/analysis/test_quasirrho.py +++ b/tests/analysis/test_quasirrho.py @@ -3,6 +3,7 @@ from unittest import TestCase import pytest + from pymatgen.analysis.quasirrho import QuasiRRHO, get_avg_mom_inertia from pymatgen.io.gaussian import GaussianOutput from pymatgen.io.qchem.outputs import QCOutput diff --git a/tests/analysis/test_reaction_calculator.py b/tests/analysis/test_reaction_calculator.py index f8584fd5549..dcb06819574 100644 --- a/tests/analysis/test_reaction_calculator.py +++ b/tests/analysis/test_reaction_calculator.py @@ -6,10 +6,11 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.reaction_calculator import BalancedReaction, ComputedReaction, Reaction, ReactionError from pymatgen.core.composition import Composition from pymatgen.entries.computed_entries import ComputedEntry -from pytest import approx class TestReaction: diff --git a/tests/analysis/test_structure_analyzer.py b/tests/analysis/test_structure_analyzer.py index 1639729a56a..12a0c02f991 100644 --- a/tests/analysis/test_structure_analyzer.py +++ b/tests/analysis/test_structure_analyzer.py @@ -4,6 +4,8 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.structure_analyzer import ( RelaxationAnalyzer, VoronoiAnalyzer, @@ -17,7 +19,6 @@ from pymatgen.core import Element, Lattice, Structure from pymatgen.io.vasp.outputs import Xdatcar from pymatgen.util.testing import VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx class TestVoronoiAnalyzer(PymatgenTest): diff --git a/tests/analysis/test_structure_matcher.py b/tests/analysis/test_structure_matcher.py index 28707ef44de..c658585cf75 100644 --- a/tests/analysis/test_structure_matcher.py +++ b/tests/analysis/test_structure_matcher.py @@ -6,6 +6,8 @@ import pytest from monty.json import MontyDecoder from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.analysis.structure_matcher import ( ElementComparator, FrameworkComparator, @@ -16,7 +18,6 @@ from pymatgen.core import Element, Lattice, Structure, SymmOp from pymatgen.util.coord import find_in_coord_list_pbc from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/structure_matcher" diff --git a/tests/analysis/test_surface_analysis.py b/tests/analysis/test_surface_analysis.py index 07195437d80..0e58a2099d5 100644 --- a/tests/analysis/test_surface_analysis.py +++ b/tests/analysis/test_surface_analysis.py @@ -3,11 +3,12 @@ import json from numpy.testing import assert_allclose +from pytest import approx +from sympy import Number, Symbol + from pymatgen.analysis.surface_analysis import NanoscaleStability, SlabEntry, SurfaceEnergyPlotter, WorkFunctionAnalyzer from pymatgen.entries.computed_entries import ComputedStructureEntry from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx -from sympy import Number, Symbol __author__ = "Richard Tran" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/tests/analysis/test_transition_state.py b/tests/analysis/test_transition_state.py index 0e404e02fc8..96841bc482a 100644 --- a/tests/analysis/test_transition_state.py +++ b/tests/analysis/test_transition_state.py @@ -4,6 +4,7 @@ from matplotlib import pyplot as plt from numpy.testing import assert_allclose + from pymatgen.analysis.transition_state import NEBAnalysis, combine_neb_plots from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/analysis/test_wulff.py b/tests/analysis/test_wulff.py index c4656a79097..11109c1e746 100644 --- a/tests/analysis/test_wulff.py +++ b/tests/analysis/test_wulff.py @@ -2,13 +2,14 @@ import json +from pytest import approx + from pymatgen.analysis.wulff import WulffShape from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.coord import in_coord_list from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "Zihan Xu, Richard Tran, Balachandran Radhakrishnan" __copyright__ = "Copyright 2013, The Materials Virtual Lab" diff --git a/tests/analysis/topological/test_spillage.py b/tests/analysis/topological/test_spillage.py index 61652c97696..39224cdfc52 100644 --- a/tests/analysis/topological/test_spillage.py +++ b/tests/analysis/topological/test_spillage.py @@ -1,8 +1,9 @@ from __future__ import annotations +from pytest import approx + from pymatgen.analysis.topological.spillage import SOCSpillage from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/topological" diff --git a/tests/analysis/xas/test_spectrum.py b/tests/analysis/xas/test_spectrum.py index 25772d15fb7..ccc1c3812a5 100644 --- a/tests/analysis/xas/test_spectrum.py +++ b/tests/analysis/xas/test_spectrum.py @@ -6,10 +6,11 @@ import pytest from monty.json import MontyDecoder from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum from pymatgen.core import Element from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/spectrum_test" diff --git a/tests/apps/battery/test_analyzer.py b/tests/apps/battery/test_analyzer.py index a7a22ccf015..61426c65510 100644 --- a/tests/apps/battery/test_analyzer.py +++ b/tests/apps/battery/test_analyzer.py @@ -1,10 +1,11 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.apps.battery.analyzer import BatteryAnalyzer from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx class TestBatteryAnalyzer(PymatgenTest): diff --git a/tests/apps/battery/test_conversion_battery.py b/tests/apps/battery/test_conversion_battery.py index 041702625a4..ae17c969353 100644 --- a/tests/apps/battery/test_conversion_battery.py +++ b/tests/apps/battery/test_conversion_battery.py @@ -4,10 +4,11 @@ from unittest import TestCase from monty.json import MontyDecoder +from pytest import approx + from pymatgen.apps.battery.conversion_battery import ConversionElectrode, ConversionVoltagePair from pymatgen.core.composition import Composition from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/apps/battery" diff --git a/tests/apps/battery/test_insertion_battery.py b/tests/apps/battery/test_insertion_battery.py index 06c6ef0802f..4c6b9caa951 100644 --- a/tests/apps/battery/test_insertion_battery.py +++ b/tests/apps/battery/test_insertion_battery.py @@ -4,10 +4,11 @@ from unittest import TestCase from monty.json import MontyDecoder, MontyEncoder +from pytest import approx + from pymatgen.apps.battery.insertion_battery import InsertionElectrode, InsertionVoltagePair from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/apps/battery" diff --git a/tests/apps/battery/test_plotter.py b/tests/apps/battery/test_plotter.py index c9e14875e25..8b93b20f8bf 100644 --- a/tests/apps/battery/test_plotter.py +++ b/tests/apps/battery/test_plotter.py @@ -4,6 +4,7 @@ from unittest import TestCase from monty.json import MontyDecoder + from pymatgen.apps.battery.conversion_battery import ConversionElectrode from pymatgen.apps.battery.insertion_battery import InsertionElectrode from pymatgen.apps.battery.plotter import VoltageProfilePlotter diff --git a/tests/apps/borg/test_hive.py b/tests/apps/borg/test_hive.py index 64b64814c70..766086d429b 100644 --- a/tests/apps/borg/test_hive.py +++ b/tests/apps/borg/test_hive.py @@ -3,6 +3,8 @@ import os from unittest import TestCase +from pytest import approx + from pymatgen.apps.borg.hive import ( GaussianToComputedEntryDrone, SimpleVaspToComputedEntryDrone, @@ -10,7 +12,6 @@ ) from pymatgen.entries.computed_entries import ComputedStructureEntry from pymatgen.util.testing import TEST_FILES_DIR, VASP_OUT_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/apps/borg" diff --git a/tests/apps/borg/test_queen.py b/tests/apps/borg/test_queen.py index 14cc20679a9..b95e660f7f8 100644 --- a/tests/apps/borg/test_queen.py +++ b/tests/apps/borg/test_queen.py @@ -1,9 +1,10 @@ from __future__ import annotations +from pytest import approx + from pymatgen.apps.borg.hive import VaspToComputedEntryDrone from pymatgen.apps.borg.queen import BorgQueen from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/tests/command_line/test_bader_caller.py b/tests/command_line/test_bader_caller.py index 5bff1599c62..9b412303eb3 100644 --- a/tests/command_line/test_bader_caller.py +++ b/tests/command_line/test_bader_caller.py @@ -2,15 +2,15 @@ import warnings from shutil import which -from unittest.mock import patch import numpy as np import pytest from monty.shutil import copy_r from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.command_line.bader_caller import BaderAnalysis, bader_analysis_from_path from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/command_line/bader" @@ -59,7 +59,6 @@ def test_init(self): assert len(analysis.data) == 14 # Test Cube file format parsing - copy_r(TEST_DIR, self.tmp_path) analysis = BaderAnalysis(cube_filename=f"{TEST_DIR}/elec.cube.gz") assert len(analysis.data) == 9 @@ -75,17 +74,17 @@ def test_from_path(self): analysis = BaderAnalysis(chgcar_filename=chgcar_path, chgref_filename=chgref_path) analysis_from_path = BaderAnalysis.from_path(from_path_dir) - for key in analysis_from_path.summary: - val, val_from_path = analysis.summary[key], analysis_from_path.summary[key] - if isinstance(analysis_from_path.summary[key], (bool, str)): + for key, val_from_path in analysis_from_path.summary.items(): + val = analysis.summary[key] + if isinstance(val_from_path, (bool, str)): assert val == val_from_path, f"{key=}" elif key == "charge": assert_allclose(val, val_from_path, atol=1e-5) def test_bader_analysis_from_path(self): - summary = bader_analysis_from_path(TEST_DIR) """ Reference summary dict (with bader 1.0) + summary_ref = { "magmom": [4.298761, 4.221997, 4.221997, 3.816685, 4.221997, 4.298763, 0.36292, 0.370516, 0.36292, 0.36292, 0.36292, 0.36292, 0.36292, 0.370516], @@ -101,6 +100,9 @@ def test_bader_analysis_from_path(self): "reference_used": True, } """ + + summary = bader_analysis_from_path(TEST_DIR) + assert set(summary) == { "magmom", "min_dist", @@ -130,12 +132,11 @@ def test_atom_parsing(self): ) def test_missing_file_bader_exe_path(self): - pytest.skip("doesn't reliably raise RuntimeError") - # mock which("bader") to return None so we always fall back to use bader_exe_path - with ( - patch("shutil.which", return_value=None), - pytest.raises( - RuntimeError, match="BaderAnalysis requires the executable bader be in the PATH or the full path " - ), - ): - BaderAnalysis(chgcar_filename=f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz", bader_exe_path="") + # Mock which("bader") to return None so we always fall back to use bader_exe_path + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setenv("PATH", "") + + with pytest.raises( + RuntimeError, match="Requires bader or bader.exe to be in the PATH or the absolute path" + ): + BaderAnalysis(chgcar_filename=f"{VASP_OUT_DIR}/CHGCAR.Fe3O4.gz") diff --git a/tests/command_line/test_critic2_caller.py b/tests/command_line/test_critic2_caller.py index 4f1a7cf648f..46218f94e8c 100644 --- a/tests/command_line/test_critic2_caller.py +++ b/tests/command_line/test_critic2_caller.py @@ -4,10 +4,11 @@ from unittest import TestCase import pytest +from pytest import approx + from pymatgen.command_line.critic2_caller import Critic2Analysis, Critic2Caller from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx __author__ = "Matthew Horton" __version__ = "0.1" diff --git a/tests/command_line/test_enumlib_caller.py b/tests/command_line/test_enumlib_caller.py index d21b5b67986..010bb0bd1b8 100644 --- a/tests/command_line/test_enumlib_caller.py +++ b/tests/command_line/test_enumlib_caller.py @@ -5,13 +5,14 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.command_line.enumlib_caller import EnumError, EnumlibAdaptor from pymatgen.core import Element, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.transformations.site_transformations import RemoveSitesTransformation from pymatgen.transformations.standard_transformations import SubstitutionTransformation from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx enum_cmd = which("enum.x") or which("multienum.x") makestr_cmd = which("makestr.x") or which("makeStr.x") or which("makeStr.py") diff --git a/tests/command_line/test_gulp_caller.py b/tests/command_line/test_gulp_caller.py index aa0faf1d16e..054d0e7c023 100644 --- a/tests/command_line/test_gulp_caller.py +++ b/tests/command_line/test_gulp_caller.py @@ -14,6 +14,7 @@ import numpy as np import pytest + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.command_line.gulp_caller import ( BuckinghamPotential, diff --git a/tests/command_line/test_mcsqs_caller.py b/tests/command_line/test_mcsqs_caller.py index 9dd6303a1dc..39714b1f4f1 100644 --- a/tests/command_line/test_mcsqs_caller.py +++ b/tests/command_line/test_mcsqs_caller.py @@ -4,6 +4,7 @@ import pytest from monty.serialization import loadfn + from pymatgen.command_line.mcsqs_caller import run_mcsqs from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/command_line/test_vampire_caller.py b/tests/command_line/test_vampire_caller.py index e8cc2330f02..19efe2eef1f 100644 --- a/tests/command_line/test_vampire_caller.py +++ b/tests/command_line/test_vampire_caller.py @@ -4,10 +4,11 @@ import pandas as pd import pytest +from pytest import approx + from pymatgen.command_line.vampire_caller import VampireCaller from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/analysis/magnetic_orderings" diff --git a/tests/core/test_bonds.py b/tests/core/test_bonds.py index 9154207aba6..777d938619a 100644 --- a/tests/core/test_bonds.py +++ b/tests/core/test_bonds.py @@ -1,9 +1,10 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.core import Element, Site from pymatgen.core.bonds import CovalentBond, get_bond_length, get_bond_order, obtain_all_bond_lengths -from pytest import approx __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" diff --git a/tests/core/test_composition.py b/tests/core/test_composition.py index 7c9b4dcdd33..e1f799f162c 100644 --- a/tests/core/test_composition.py +++ b/tests/core/test_composition.py @@ -10,10 +10,11 @@ import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core import Composition, DummySpecies, Element, Species from pymatgen.core.composition import ChemicalPotential from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestComposition(PymatgenTest): diff --git a/tests/core/test_interface.py b/tests/core/test_interface.py index 1757265fd27..22cd178dd5b 100644 --- a/tests/core/test_interface.py +++ b/tests/core/test_interface.py @@ -2,12 +2,13 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.interface import GrainBoundary, GrainBoundaryGenerator, Interface from pymatgen.core.structure import Structure from pymatgen.core.surface import SlabGenerator from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/core/grain_boundary" diff --git a/tests/core/test_ion.py b/tests/core/test_ion.py index 5e5f5482dfb..a3b4003a7fb 100644 --- a/tests/core/test_ion.py +++ b/tests/core/test_ion.py @@ -4,6 +4,7 @@ from unittest import TestCase import pytest + from pymatgen.core import Composition, Element from pymatgen.core.ion import Ion diff --git a/tests/core/test_lattice.py b/tests/core/test_lattice.py index 91a84688039..ce5cff053aa 100644 --- a/tests/core/test_lattice.py +++ b/tests/core/test_lattice.py @@ -5,10 +5,11 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core.lattice import Lattice, get_points_in_spheres from pymatgen.core.operations import SymmOp from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestLattice(PymatgenTest): diff --git a/tests/core/test_molecular_orbitals.py b/tests/core/test_molecular_orbitals.py index a1699967585..4476c856d17 100644 --- a/tests/core/test_molecular_orbitals.py +++ b/tests/core/test_molecular_orbitals.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from pymatgen.core.molecular_orbitals import MolecularOrbitals from pymatgen.util.testing import PymatgenTest diff --git a/tests/core/test_operations.py b/tests/core/test_operations.py index abd9d4357a6..d2dcc26de42 100644 --- a/tests/core/test_operations.py +++ b/tests/core/test_operations.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.electronic_structure.core import Magmom from pymatgen.util.testing import PymatgenTest diff --git a/tests/core/test_periodic_table.py b/tests/core/test_periodic_table.py index 586a6a513bf..e157edb3ad0 100644 --- a/tests/core/test_periodic_table.py +++ b/tests/core/test_periodic_table.py @@ -8,12 +8,13 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.core import DummySpecies, Element, Species, get_el_sp from pymatgen.core.periodic_table import ElementBase, ElementType from pymatgen.core.units import Ha_to_eV from pymatgen.io.core import ParseError from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestElement(PymatgenTest): diff --git a/tests/core/test_sites.py b/tests/core/test_sites.py index ee3911c29eb..a80689d9dfd 100644 --- a/tests/core/test_sites.py +++ b/tests/core/test_sites.py @@ -5,10 +5,11 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core import Composition, Element, Lattice, PeriodicSite, Site, Species from pymatgen.electronic_structure.core import Magmom from pymatgen.util.testing import PymatgenTest -from pytest import approx class TestSite(PymatgenTest): diff --git a/tests/core/test_spectrum.py b/tests/core/test_spectrum.py index 9f6fe473eba..86fb7e7266b 100644 --- a/tests/core/test_spectrum.py +++ b/tests/core/test_spectrum.py @@ -2,11 +2,12 @@ import numpy as np from numpy.testing import assert_allclose -from pymatgen.core.spectrum import Spectrum -from pymatgen.util.testing import PymatgenTest from pytest import approx from scipy import stats +from pymatgen.core.spectrum import Spectrum +from pymatgen.util.testing import PymatgenTest + class TestSpectrum(PymatgenTest): def setUp(self): diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 20769d9cde0..2e50eb6d948 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -12,6 +12,8 @@ import pytest from monty.json import MontyDecoder, MontyEncoder from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core import SETTINGS, Composition, Element, Lattice, Species from pymatgen.core.operations import SymmOp from pymatgen.core.structure import ( @@ -28,7 +30,6 @@ from pymatgen.io.cif import CifParser from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest -from pytest import approx try: from ase.atoms import Atoms @@ -1768,6 +1769,7 @@ def test_relax_ase_opt_kwargs(self): assert traj[0] != traj[-1] assert os.path.isfile(traj_file) + @pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency") def test_calculate_m3gnet(self): pytest.importorskip("matgl") calculator = self.get_structure("Si").calculate() @@ -1779,6 +1781,7 @@ def test_calculate_m3gnet(self): assert np.linalg.norm(calculator.results["forces"]) == approx(7.8123485e-06, abs=0.2) assert np.linalg.norm(calculator.results["stress"]) == approx(1.7861567, abs=2) + @pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency") def test_relax_m3gnet(self): matgl = pytest.importorskip("matgl") struct = self.get_structure("Si") @@ -1789,6 +1792,7 @@ def test_relax_m3gnet(self): actual = relaxed.dynamics[key] assert actual == val, f"expected {key} to be {val}, {actual=}" + @pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency") def test_relax_m3gnet_fixed_lattice(self): matgl = pytest.importorskip("matgl") struct = self.get_structure("Si") @@ -1797,6 +1801,7 @@ def test_relax_m3gnet_fixed_lattice(self): assert isinstance(relaxed.calc, matgl.ext.ase.M3GNetCalculator) assert relaxed.dynamics["optimizer"] == "BFGS" + @pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency") def test_relax_m3gnet_with_traj(self): pytest.importorskip("matgl") struct = self.get_structure("Si") diff --git a/tests/core/test_surface.py b/tests/core/test_surface.py index f910efdfa03..d27f46a77df 100644 --- a/tests/core/test_surface.py +++ b/tests/core/test_surface.py @@ -6,8 +6,10 @@ import unittest import numpy as np -import pymatgen from numpy.testing import assert_allclose +from pytest import approx + +import pymatgen from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Lattice, Structure from pymatgen.core.surface import ( @@ -24,7 +26,6 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.symmetry.groups import SpaceGroup from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx class TestSlab(PymatgenTest): diff --git a/tests/core/test_tensors.py b/tests/core/test_tensors.py index 0be6d488207..28548c3c3ba 100644 --- a/tests/core/test_tensors.py +++ b/tests/core/test_tensors.py @@ -6,11 +6,12 @@ import pytest from monty.serialization import MontyDecoder, loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.operations import SymmOp from pymatgen.core.tensors import SquareTensor, Tensor, TensorCollection, TensorMapping, itertools, symmetry_reduce from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx class TestTensor(PymatgenTest): diff --git a/tests/core/test_trajectory.py b/tests/core/test_trajectory.py index f767412e7c8..444b476c66b 100644 --- a/tests/core/test_trajectory.py +++ b/tests/core/test_trajectory.py @@ -6,6 +6,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Molecule, Structure from pymatgen.core.trajectory import Trajectory diff --git a/tests/core/test_units.py b/tests/core/test_units.py index f2c2881b189..2abfd5df8f0 100644 --- a/tests/core/test_units.py +++ b/tests/core/test_units.py @@ -4,6 +4,8 @@ import pytest from numpy.testing import assert_array_equal +from pytest import approx + from pymatgen.core.units import ( ArrayWithUnit, Energy, @@ -25,7 +27,6 @@ unitized, ) from pymatgen.util.testing import PymatgenTest -from pytest import approx def test_unit_conversions(): diff --git a/tests/core/test_xcfunc.py b/tests/core/test_xcfunc.py index 8eba2c2dcc8..1cc64bbebc8 100644 --- a/tests/core/test_xcfunc.py +++ b/tests/core/test_xcfunc.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from pymatgen.core.xcfunc import XcFunc from pymatgen.util.testing import PymatgenTest diff --git a/tests/electronic_structure/test_bandstructure.py b/tests/electronic_structure/test_bandstructure.py index 0daa57fcab2..a6d9d9dfc41 100644 --- a/tests/electronic_structure/test_bandstructure.py +++ b/tests/electronic_structure/test_bandstructure.py @@ -8,6 +8,8 @@ import pytest from monty.serialization import loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.lattice import Lattice from pymatgen.electronic_structure.bandstructure import ( BandStructureSymmLine, @@ -19,7 +21,6 @@ from pymatgen.electronic_structure.plotter import BSPlotterProjected from pymatgen.io.vasp import BSVasprun from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/bandstructure" diff --git a/tests/electronic_structure/test_boltztrap.py b/tests/electronic_structure/test_boltztrap.py index 8bdb69a5e19..4d1637dd56d 100644 --- a/tests/electronic_structure/test_boltztrap.py +++ b/tests/electronic_structure/test_boltztrap.py @@ -6,11 +6,12 @@ import pytest from monty.serialization import loadfn +from pytest import approx + from pymatgen.electronic_structure.bandstructure import BandStructure from pymatgen.electronic_structure.boltztrap import BoltztrapAnalyzer, BoltztrapRunner from pymatgen.electronic_structure.core import OrbitalType, Spin from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx try: from ase.io.cube import read_cube diff --git a/tests/electronic_structure/test_boltztrap2.py b/tests/electronic_structure/test_boltztrap2.py index b8660fbe61d..1746cd64b87 100644 --- a/tests/electronic_structure/test_boltztrap2.py +++ b/tests/electronic_structure/test_boltztrap2.py @@ -5,10 +5,11 @@ import numpy as np import pytest from monty.serialization import loadfn +from pytest import approx + from pymatgen.electronic_structure.core import OrbitalType, Spin from pymatgen.io.vasp import Vasprun from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx try: from pymatgen.electronic_structure.boltztrap2 import ( @@ -315,7 +316,9 @@ def test_plot(self): assert self.bztPlotter is not None fig = self.bztPlotter.plot_props("S", "mu", "temp", temps=[300, 500]) assert fig is not None + fig = self.bztPlotter.plot_bands() assert fig is not None + fig = self.bztPlotter.plot_dos() assert fig is not None diff --git a/tests/electronic_structure/test_cohp.py b/tests/electronic_structure/test_cohp.py index 2b44b6eca1b..2f9f827ba8b 100644 --- a/tests/electronic_structure/test_cohp.py +++ b/tests/electronic_structure/test_cohp.py @@ -5,6 +5,8 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.electronic_structure.cohp import ( Cohp, CompleteCohp, @@ -14,7 +16,6 @@ ) from pymatgen.electronic_structure.core import Orbital, Spin from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/cohp" diff --git a/tests/electronic_structure/test_core.py b/tests/electronic_structure/test_core.py index 243d1bef73b..2abad986cc4 100644 --- a/tests/electronic_structure/test_core.py +++ b/tests/electronic_structure/test_core.py @@ -3,6 +3,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.core import Lattice from pymatgen.electronic_structure.core import Magmom, Orbital, Spin diff --git a/tests/electronic_structure/test_dos.py b/tests/electronic_structure/test_dos.py index e515f0411c4..c69cbcc54ba 100644 --- a/tests/electronic_structure/test_dos.py +++ b/tests/electronic_structure/test_dos.py @@ -8,11 +8,12 @@ from monty.io import zopen from monty.serialization import loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core import Element, Structure from pymatgen.electronic_structure.core import Orbital, OrbitalType, Spin from pymatgen.electronic_structure.dos import DOS, CompleteDos, FermiDos, LobsterCompleteDos from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/dos" diff --git a/tests/electronic_structure/test_plotter.py b/tests/electronic_structure/test_plotter.py index 304c7216083..9728bc8d5d4 100644 --- a/tests/electronic_structure/test_plotter.py +++ b/tests/electronic_structure/test_plotter.py @@ -10,6 +10,8 @@ import pytest from matplotlib import rc from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.structure import Structure from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.boltztrap import BoltztrapAnalyzer @@ -29,7 +31,6 @@ ) from pymatgen.io.vasp import Vasprun from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx BAND_TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/bandstructure" diff --git a/tests/entries/test_compatibility.py b/tests/entries/test_compatibility.py index 7eeea95e918..03578f2f135 100644 --- a/tests/entries/test_compatibility.py +++ b/tests/entries/test_compatibility.py @@ -9,9 +9,11 @@ from typing import TYPE_CHECKING from unittest import TestCase -import pymatgen import pytest from monty.json import MontyDecoder +from pytest import approx + +import pymatgen from pymatgen.core import Element, Species from pymatgen.core.composition import Composition from pymatgen.core.lattice import Lattice @@ -32,7 +34,6 @@ ) from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry, ConstantEnergyAdjustment from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx if TYPE_CHECKING: from pymatgen.util.typing import CompositionLike @@ -480,6 +481,17 @@ def test_process_entries(self): entries = self.compat.process_entries([self.entry1, self.entry2, self.entry3, self.entry4]) assert len(entries) == 2 + def test_parallel_process_entries(self): + with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"): + entries = self.compat.process_entries( + [self.entry1, self.entry2, self.entry3, self.entry4], inplace=True, n_workers=2 + ) + + entries = self.compat.process_entries( + [self.entry1, self.entry2, self.entry3, self.entry4], inplace=False, n_workers=2 + ) + assert len(entries) == 2 + def test_msonable(self): compat_dict = self.compat.as_dict() decoder = MontyDecoder() @@ -1878,6 +1890,22 @@ def test_processing_entries_inplace(self): MaterialsProjectAqueousCompatibility().process_entries(entries, inplace=False) assert all(e.correction == e_copy.correction for e, e_copy in zip(entries, entries_copy)) + def test_parallel_process_entries(self): + hydrate_entry = ComputedEntry(Composition("FeH4O2"), -10) # nH2O = 2 + hydrate_entry2 = ComputedEntry(Composition("Li2O2H2"), -10) # nH2O = 0 + + entry_list = [hydrate_entry, hydrate_entry2] + + compat = MaterialsProjectAqueousCompatibility( + o2_energy=-10, h2o_energy=-20, h2o_adjustments=-0.5, solid_compat=None + ) + + with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"): + entries = compat.process_entries(entry_list, inplace=True, n_workers=2) + + entries = compat.process_entries(entry_list, inplace=False, n_workers=2, on_error="raise") + assert len(entries) == 2 + class TestAqueousCorrection(TestCase): def setUp(self): diff --git a/tests/entries/test_computed_entries.py b/tests/entries/test_computed_entries.py index a7a00c7f88c..aee3aade715 100644 --- a/tests/entries/test_computed_entries.py +++ b/tests/entries/test_computed_entries.py @@ -7,6 +7,8 @@ import pytest from monty.json import MontyDecoder +from pytest import approx + from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from pymatgen.entries.computed_entries import ( @@ -21,7 +23,6 @@ ) from pymatgen.io.vasp.outputs import Vasprun from pymatgen.util.testing import TEST_FILES_DIR, VASP_OUT_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/entries" diff --git a/tests/entries/test_correction_calculator.py b/tests/entries/test_correction_calculator.py index c5b8f90a657..02d0fb517a8 100644 --- a/tests/entries/test_correction_calculator.py +++ b/tests/entries/test_correction_calculator.py @@ -3,6 +3,7 @@ from unittest import TestCase import pytest + from pymatgen.entries.correction_calculator import CorrectionCalculator from pymatgen.util.testing import TEST_FILES_DIR diff --git a/tests/entries/test_entry_tools.py b/tests/entries/test_entry_tools.py index a939dd4106e..01f4255d8ff 100644 --- a/tests/entries/test_entry_tools.py +++ b/tests/entries/test_entry_tools.py @@ -4,6 +4,7 @@ import pytest from monty.serialization import dumpfn, loadfn + from pymatgen.core import Element from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.entries.entry_tools import EntrySet, group_entries_by_composition, group_entries_by_structure diff --git a/tests/entries/test_exp_entries.py b/tests/entries/test_exp_entries.py index 4ebb14d9d70..b73be4ad8a4 100644 --- a/tests/entries/test_exp_entries.py +++ b/tests/entries/test_exp_entries.py @@ -4,9 +4,10 @@ from unittest import TestCase from monty.json import MontyDecoder +from pytest import approx + from pymatgen.entries.exp_entries import ExpEntry from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx class TestExpEntry(TestCase): diff --git a/tests/entries/test_mixing_scheme.py b/tests/entries/test_mixing_scheme.py index 33d71c5a5b8..860ab42a0e7 100644 --- a/tests/entries/test_mixing_scheme.py +++ b/tests/entries/test_mixing_scheme.py @@ -108,6 +108,7 @@ import pytest from monty.json import MontyDecoder from numpy.testing import assert_allclose + from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core.lattice import Lattice diff --git a/tests/ext/test_cod.py b/tests/ext/test_cod.py index e2bc6ba9143..56dd47568ba 100644 --- a/tests/ext/test_cod.py +++ b/tests/ext/test_cod.py @@ -6,11 +6,12 @@ import pytest import requests + from pymatgen.ext.cod import COD if "CI" in os.environ: # test is slow and flaky, skip in CI. see # https://github.com/materialsproject/pymatgen/pull/3777#issuecomment-2071217785 - pytest.skip(allow_module_level=True) + pytest.skip(allow_module_level=True, reason="Skip COD test in CI") try: website_down = requests.get("https://www.crystallography.net", timeout=600).status_code != 200 diff --git a/tests/ext/test_matproj.py b/tests/ext/test_matproj.py index 99e13761f82..9e327272efe 100644 --- a/tests/ext/test_matproj.py +++ b/tests/ext/test_matproj.py @@ -7,6 +7,9 @@ import pytest import requests from numpy.testing import assert_allclose +from pytest import approx +from ruamel.yaml import YAML + from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixEntry from pymatgen.analysis.reaction_calculator import Reaction @@ -21,8 +24,6 @@ from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx -from ruamel.yaml import YAML PMG_MAPI_KEY = SETTINGS.get("PMG_MAPI_KEY", "") if (10 < len(PMG_MAPI_KEY) <= 20) and "PMG_MAPI_KEY" in SETTINGS: diff --git a/tests/ext/test_optimade.py b/tests/ext/test_optimade.py index 2af457c38af..ea351acb8c7 100644 --- a/tests/ext/test_optimade.py +++ b/tests/ext/test_optimade.py @@ -2,6 +2,7 @@ import pytest import requests + from pymatgen.ext.optimade import OptimadeRester from pymatgen.util.testing import PymatgenTest diff --git a/tests/io/abinit/test_abiobjects.py b/tests/io/abinit/test_abiobjects.py index 5dcbff0b26d..8bbfa3a7441 100644 --- a/tests/io/abinit/test_abiobjects.py +++ b/tests/io/abinit/test_abiobjects.py @@ -3,6 +3,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core.structure import Structure from pymatgen.core.units import Ha_to_eV, bohr_to_ang from pymatgen.io.abinit.abiobjects import ( @@ -18,7 +20,6 @@ structure_to_abivars, ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx class TestLatticeFromAbivars(PymatgenTest): diff --git a/tests/io/abinit/test_inputs.py b/tests/io/abinit/test_inputs.py index 287bfa68401..6becf8953f8 100644 --- a/tests/io/abinit/test_inputs.py +++ b/tests/io/abinit/test_inputs.py @@ -6,6 +6,7 @@ import numpy as np import pytest from numpy.testing import assert_array_equal + from pymatgen.core.structure import Structure from pymatgen.io.abinit.inputs import ( BasicAbinitInput, diff --git a/tests/io/abinit/test_netcdf.py b/tests/io/abinit/test_netcdf.py index acd17dcae7d..7f9fd9661f7 100644 --- a/tests/io/abinit/test_netcdf.py +++ b/tests/io/abinit/test_netcdf.py @@ -7,6 +7,7 @@ import pytest from monty.tempfile import ScratchDir from numpy.testing import assert_allclose, assert_array_equal + from pymatgen.core.structure import Structure from pymatgen.io.abinit import EtsfReader from pymatgen.io.abinit.netcdf import AbinitHeader diff --git a/tests/io/abinit/test_pseudos.py b/tests/io/abinit/test_pseudos.py index 1e1b5a1fe64..a49c71fdf75 100644 --- a/tests/io/abinit/test_pseudos.py +++ b/tests/io/abinit/test_pseudos.py @@ -6,9 +6,10 @@ import pytest from monty.tempfile import ScratchDir +from pytest import approx + from pymatgen.io.abinit.pseudos import Pseudo, PseudoTable from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/abinit" diff --git a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/control.in.gz b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/control.in.gz index 9b912b2ffe2..f151b8f5b88 100644 Binary files a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/control.in.gz and b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/control.in.gz differ diff --git a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/geometry.in.gz b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/geometry.in.gz index c2e720a0366..413cb3eb22a 100644 Binary files a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/geometry.in.gz and b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/geometry.in.gz differ diff --git a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/parameters.json b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/parameters.json index 28242f36fa7..a6c81a5f1e6 100644 --- a/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/parameters.json +++ b/tests/io/aims/aims_input_generator_ref/static-no-kgrid-si/parameters.json @@ -4,7 +4,7 @@ "species_dir": "/home/tpurcell/git/atomate2/tests/aims/species_dir/light", "k_grid": [ 12, - 12, - 12 + 6, + 4 ] } diff --git a/tests/io/aims/conftest.py b/tests/io/aims/conftest.py index 177aeb38001..5060ba4b259 100644 --- a/tests/io/aims/conftest.py +++ b/tests/io/aims/conftest.py @@ -3,6 +3,7 @@ import os import pytest + from pymatgen.core import SETTINGS module_dir = os.path.dirname(__file__) diff --git a/tests/io/aims/test_aims_inputs.py b/tests/io/aims/test_aims_inputs.py index 5ded1722eaf..d39baafd277 100644 --- a/tests/io/aims/test_aims_inputs.py +++ b/tests/io/aims/test_aims_inputs.py @@ -8,6 +8,7 @@ import pytest from monty.json import MontyDecoder, MontyEncoder from numpy.testing import assert_allclose + from pymatgen.core import SETTINGS from pymatgen.io.aims.inputs import ( ALLOWED_AIMS_CUBE_TYPES, diff --git a/tests/io/aims/test_aims_outputs.py b/tests/io/aims/test_aims_outputs.py index 1d675dad2a3..deb0554f658 100644 --- a/tests/io/aims/test_aims_outputs.py +++ b/tests/io/aims/test_aims_outputs.py @@ -6,6 +6,7 @@ from monty.json import MontyDecoder, MontyEncoder from numpy.testing import assert_allclose + from pymatgen.core import Structure from pymatgen.io.aims.outputs import AimsOutput diff --git a/tests/io/aims/test_aims_parsers.py b/tests/io/aims/test_aims_parsers.py index e6cb7d6eb7a..1d30799122d 100644 --- a/tests/io/aims/test_aims_parsers.py +++ b/tests/io/aims/test_aims_parsers.py @@ -6,6 +6,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.core.tensors import Tensor from pymatgen.io.aims.parsers import ( EV_PER_A3_TO_KBAR, diff --git a/tests/io/aims/test_sets/test_input_set.py b/tests/io/aims/test_sets/test_input_set.py index f5d87fe3881..fb7acf5011d 100644 --- a/tests/io/aims/test_sets/test_input_set.py +++ b/tests/io/aims/test_sets/test_input_set.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest + from pymatgen.core import Structure from pymatgen.io.aims.sets import AimsInputSet diff --git a/tests/io/aims/test_sets/test_md_generator.py b/tests/io/aims/test_sets/test_md_generator.py index 13c6f8f6859..f8b9f71cbef 100644 --- a/tests/io/aims/test_sets/test_md_generator.py +++ b/tests/io/aims/test_sets/test_md_generator.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest + from pymatgen.io.aims.sets.core import MDSetGenerator from pymatgen.util.testing.aims import Si, compare_files diff --git a/tests/io/aims/test_sets/test_static_generator.py b/tests/io/aims/test_sets/test_static_generator.py index 39415758b1e..a38c002897d 100644 --- a/tests/io/aims/test_sets/test_static_generator.py +++ b/tests/io/aims/test_sets/test_static_generator.py @@ -19,7 +19,11 @@ def test_static_si(tmp_path): def test_static_si_no_kgrid(tmp_path): parameters = {"species_dir": "light"} - comp_system(Si, parameters, "static-no-kgrid-si", tmp_path, ref_path, StaticSetGenerator) + Si_supercell = Si.make_supercell([1, 2, 3], in_place=False) + for site in Si_supercell: + # round site.coords to ignore floating point errors + site.coords = [round(x, 15) for x in site.coords] + comp_system(Si_supercell, parameters, "static-no-kgrid-si", tmp_path, ref_path, StaticSetGenerator) def test_static_o2(tmp_path): diff --git a/tests/io/cp2k/test_inputs.py b/tests/io/cp2k/test_inputs.py index 9120a789f6d..a4b344cda19 100644 --- a/tests/io/cp2k/test_inputs.py +++ b/tests/io/cp2k/test_inputs.py @@ -3,6 +3,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core.structure import Molecule, Structure from pymatgen.io.cp2k.inputs import ( BasisFile, @@ -21,7 +23,6 @@ SectionList, ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/cp2k" diff --git a/tests/io/cp2k/test_outputs.py b/tests/io/cp2k/test_outputs.py index fafe2b8be91..cd3214af175 100644 --- a/tests/io/cp2k/test_outputs.py +++ b/tests/io/cp2k/test_outputs.py @@ -4,9 +4,10 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.io.cp2k.outputs import Cp2kOutput from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/cp2k" diff --git a/tests/io/cp2k/test_sets.py b/tests/io/cp2k/test_sets.py index c986f02db6a..161faf95117 100644 --- a/tests/io/cp2k/test_sets.py +++ b/tests/io/cp2k/test_sets.py @@ -1,10 +1,11 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.core.structure import Molecule, Structure from pymatgen.io.cp2k.sets import SETTINGS, Cp2kValidationError, DftSet, GaussianTypeOrbitalBasisSet, GthPotential from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/cp2k" @@ -102,7 +103,7 @@ def test_dft_set(self): assert dft_set.check("force_eval/dft/auxiliary_density_matrix_method") # Validator will trip for kpoints + hfx - dft_set.update({"force_eval": {"dft": {"kpoints": {}}}}) + dft_set |= {"force_eval": {"dft": {"kpoints": {}}}} with pytest.raises(Cp2kValidationError, match="CP2K v2022.1: Does not support hartree fock with kpoints"): dft_set.validate() diff --git a/tests/io/exciting/test_inputs.py b/tests/io/exciting/test_inputs.py index 02528023893..a5a221221b3 100644 --- a/tests/io/exciting/test_inputs.py +++ b/tests/io/exciting/test_inputs.py @@ -3,6 +3,7 @@ from xml.etree import ElementTree from numpy.testing import assert_allclose + from pymatgen.core import Lattice, Structure from pymatgen.io.exciting import ExcitingInput from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/io/feff/test_inputs.py b/tests/io/feff/test_inputs.py index cf06dd92bf9..b95bcc536e1 100644 --- a/tests/io/feff/test_inputs.py +++ b/tests/io/feff/test_inputs.py @@ -4,10 +4,11 @@ from unittest import TestCase from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core import Molecule, Structure from pymatgen.io.feff.inputs import Atoms, Header, Paths, Potential, Tags from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx FEFF_TEST_DIR = f"{TEST_FILES_DIR}/io/feff" diff --git a/tests/io/feff/test_sets.py b/tests/io/feff/test_sets.py index 059d2422550..e57e560996c 100644 --- a/tests/io/feff/test_sets.py +++ b/tests/io/feff/test_sets.py @@ -5,6 +5,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.core.structure import Lattice, Molecule, Structure from pymatgen.io.feff.inputs import Atoms, Header, Potential, Tags from pymatgen.io.feff.sets import FEFFDictSet, MPELNESSet, MPEXAFSSet, MPXANESSet diff --git a/tests/io/lammps/test_data.py b/tests/io/lammps/test_data.py index e88732fb033..d2fb6541926 100644 --- a/tests/io/lammps/test_data.py +++ b/tests/io/lammps/test_data.py @@ -10,11 +10,12 @@ import pytest from monty.json import MontyDecoder, MontyEncoder from numpy.testing import assert_allclose +from pytest import approx +from ruamel.yaml import YAML + from pymatgen.core import Element, Lattice, Molecule, Structure from pymatgen.io.lammps.data import CombinedData, ForceField, LammpsBox, LammpsData, Topology, lattice_2_lmpbox from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx -from ruamel.yaml import YAML TEST_DIR = f"{TEST_FILES_DIR}/io/lammps" diff --git a/tests/io/lammps/test_inputs.py b/tests/io/lammps/test_inputs.py index a3abdcceb16..d56765fdad3 100644 --- a/tests/io/lammps/test_inputs.py +++ b/tests/io/lammps/test_inputs.py @@ -6,6 +6,7 @@ import pandas as pd import pytest + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.io.lammps.data import LammpsData diff --git a/tests/io/lammps/test_outputs.py b/tests/io/lammps/test_outputs.py index e982cf21f18..6d5e83c3089 100644 --- a/tests/io/lammps/test_outputs.py +++ b/tests/io/lammps/test_outputs.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from numpy.testing import assert_allclose + from pymatgen.io.lammps.outputs import LammpsDump, parse_lammps_dumps, parse_lammps_log from pymatgen.util.testing import TEST_FILES_DIR diff --git a/tests/io/lobster/test_inputs.py b/tests/io/lobster/test_inputs.py index 62f413a1947..6675eaf70c8 100644 --- a/tests/io/lobster/test_inputs.py +++ b/tests/io/lobster/test_inputs.py @@ -7,6 +7,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core.structure import Structure from pymatgen.electronic_structure.cohp import IcohpCollection from pymatgen.electronic_structure.core import Orbital, Spin @@ -30,7 +32,6 @@ from pymatgen.io.vasp import Vasprun from pymatgen.io.vasp.inputs import Incar, Kpoints, Potcar from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/cohp" @@ -443,77 +444,77 @@ def test_values(self): "length": 2.88231, "number_of_bonds": 3, "icohp": {Spin.up: -2.18042}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "2": { "length": 3.10144, "number_of_bonds": 3, "icohp": {Spin.up: -1.14347}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "3": { "length": 2.88231, "number_of_bonds": 3, "icohp": {Spin.up: -2.18042}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "4": { "length": 3.10144, "number_of_bonds": 3, "icohp": {Spin.up: -1.14348}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "5": { "length": 3.05001, "number_of_bonds": 3, "icohp": {Spin.up: -1.30006}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "6": { "length": 2.91676, "number_of_bonds": 3, "icohp": {Spin.up: -1.96843}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "7": { "length": 3.05001, "number_of_bonds": 3, "icohp": {Spin.up: -1.30006}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "8": { "length": 2.91676, "number_of_bonds": 3, "icohp": {Spin.up: -1.96843}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "9": { "length": 3.37522, "number_of_bonds": 3, "icohp": {Spin.up: -0.47531}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "10": { "length": 3.07294, "number_of_bonds": 3, "icohp": {Spin.up: -2.38796}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "11": { "length": 3.37522, "number_of_bonds": 3, "icohp": {Spin.up: -0.47531}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, } @@ -522,77 +523,77 @@ def test_values(self): "length": 2.88231, "number_of_bonds": 3, "icohp": {Spin.up: 0.14245}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "2": { "length": 3.10144, "number_of_bonds": 3, "icohp": {Spin.up: -0.04118}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "3": { "length": 2.88231, "number_of_bonds": 3, "icohp": {Spin.up: 0.14245}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "4": { "length": 3.10144, "number_of_bonds": 3, "icohp": {Spin.up: -0.04118}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "5": { "length": 3.05001, "number_of_bonds": 3, "icohp": {Spin.up: -0.03516}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "6": { "length": 2.91676, "number_of_bonds": 3, "icohp": {Spin.up: 0.10745}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "7": { "length": 3.05001, "number_of_bonds": 3, "icohp": {Spin.up: -0.03516}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "8": { "length": 2.91676, "number_of_bonds": 3, "icohp": {Spin.up: 0.10745}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "9": { "length": 3.37522, "number_of_bonds": 3, "icohp": {Spin.up: -0.12395}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "10": { "length": 3.07294, "number_of_bonds": 3, "icohp": {Spin.up: 0.24714}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "11": { "length": 3.37522, "number_of_bonds": 3, "icohp": {Spin.up: -0.12395}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, } @@ -601,14 +602,14 @@ def test_values(self): "length": 2.83189, "number_of_bonds": 2, "icohp": {Spin.up: -0.10218, Spin.down: -0.19701}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, "2": { "length": 2.45249, "number_of_bonds": 1, "icohp": {Spin.up: -0.28485, Spin.down: -0.58279}, - "translation": [0, 0, 0], + "translation": (0, 0, 0), "orbitals": None, }, } diff --git a/tests/io/lobster/test_lobsterenv.py b/tests/io/lobster/test_lobsterenv.py index 9a4230e232e..6ee94d82902 100644 --- a/tests/io/lobster/test_lobsterenv.py +++ b/tests/io/lobster/test_lobsterenv.py @@ -5,6 +5,8 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.graphs import StructureGraph from pymatgen.core import Element from pymatgen.core.structure import Structure @@ -13,7 +15,6 @@ from pymatgen.io.lobster import Charge, Icohplist from pymatgen.io.lobster.lobsterenv import LobsterNeighbors from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx __author__ = "Janine George" __copyright__ = "Copyright 2021, The Materials Project" diff --git a/tests/io/pwmat/test_inputs.py b/tests/io/pwmat/test_inputs.py index 8ee41de8552..1343ef6832a 100644 --- a/tests/io/pwmat/test_inputs.py +++ b/tests/io/pwmat/test_inputs.py @@ -3,6 +3,7 @@ import pytest from monty.io import zopen from numpy.testing import assert_allclose + from pymatgen.core import Composition, Structure from pymatgen.io.pwmat.inputs import ( ACExtractor, diff --git a/tests/io/qchem/test_inputs.py b/tests/io/qchem/test_inputs.py index c5c0515fa57..6a49daba428 100644 --- a/tests/io/qchem/test_inputs.py +++ b/tests/io/qchem/test_inputs.py @@ -5,6 +5,7 @@ import pytest from monty.serialization import loadfn + from pymatgen.core.structure import Molecule from pymatgen.io.qchem.inputs import QCInput from pymatgen.io.qchem.sets import OptSet diff --git a/tests/io/qchem/test_outputs.py b/tests/io/qchem/test_outputs.py index 31e54f67b3f..bd900253583 100644 --- a/tests/io/qchem/test_outputs.py +++ b/tests/io/qchem/test_outputs.py @@ -7,6 +7,8 @@ import pytest from monty.serialization import dumpfn, loadfn from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.structure import Molecule from pymatgen.io.qchem.outputs import ( QCOutput, @@ -16,7 +18,6 @@ orbital_coeffs_parser, ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx try: from openbabel import openbabel diff --git a/tests/io/qchem/test_sets.py b/tests/io/qchem/test_sets.py index 380ce742ddc..a1f43721dea 100644 --- a/tests/io/qchem/test_sets.py +++ b/tests/io/qchem/test_sets.py @@ -3,6 +3,7 @@ import os import pytest + from pymatgen.io.qchem.sets import ( ForceSet, FreqSet, diff --git a/tests/io/qchem/test_utils.py b/tests/io/qchem/test_utils.py index 260dd803462..394ab8c142b 100644 --- a/tests/io/qchem/test_utils.py +++ b/tests/io/qchem/test_utils.py @@ -5,6 +5,7 @@ import pytest from monty.io import zopen + from pymatgen.io.qchem.utils import lower_and_check_unique, process_parsed_hess from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/io/test_adf.py b/tests/io/test_adf.py index 702fbc87d72..d3b6785320f 100644 --- a/tests/io/test_adf.py +++ b/tests/io/test_adf.py @@ -1,9 +1,10 @@ from __future__ import annotations +from pytest import approx + from pymatgen.core.structure import Molecule from pymatgen.io.adf import AdfInput, AdfKey, AdfOutput, AdfTask from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx __author__ = "Xin Chen, chenxin13@mails.tsinghua.edu.cn" diff --git a/tests/io/test_ase.py b/tests/io/test_ase.py index 5ed59b6afff..daf7b0c77bd 100644 --- a/tests/io/test_ase.py +++ b/tests/io/test_ase.py @@ -3,6 +3,7 @@ import numpy as np import pytest from monty.json import MontyDecoder, jsanitize + from pymatgen.core import Composition, Lattice, Molecule, Structure from pymatgen.core.structure import StructureError from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms diff --git a/tests/io/test_atat.py b/tests/io/test_atat.py index 6cf13574495..a9562f92511 100644 --- a/tests/io/test_atat.py +++ b/tests/io/test_atat.py @@ -1,10 +1,11 @@ from __future__ import annotations from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.structure import Structure from pymatgen.io.atat import Mcsqs from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/atat/mcsqs" diff --git a/tests/io/test_babel.py b/tests/io/test_babel.py index 589330d94c5..e1724176ec4 100644 --- a/tests/io/test_babel.py +++ b/tests/io/test_babel.py @@ -4,13 +4,14 @@ from unittest import TestCase import pytest +from pytest import approx + from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.molecule_matcher import MoleculeMatcher from pymatgen.core.structure import Molecule from pymatgen.io.babel import BabelMolAdaptor from pymatgen.io.xyz import XYZ from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx pybel = pytest.importorskip("openbabel.pybel") diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index 8abc5127813..f536a4f4021 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -2,13 +2,14 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import Composition, DummySpecies, Element, Lattice, Species, Structure from pymatgen.electronic_structure.core import Magmom from pymatgen.io.cif import CifBlock, CifParser, CifWriter from pymatgen.symmetry.structure import SymmetrizedStructure from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest -from pytest import approx try: import pybtex diff --git a/tests/io/test_core.py b/tests/io/test_core.py index 7c457339249..f3f08b46185 100644 --- a/tests/io/test_core.py +++ b/tests/io/test_core.py @@ -6,6 +6,7 @@ import pytest from monty.serialization import MontyDecoder + from pymatgen.core.structure import Structure from pymatgen.io.cif import CifParser, CifWriter from pymatgen.io.core import InputFile, InputSet diff --git a/tests/io/test_gaussian.py b/tests/io/test_gaussian.py index fc694fa7f97..7dc76b78b1b 100644 --- a/tests/io/test_gaussian.py +++ b/tests/io/test_gaussian.py @@ -3,11 +3,12 @@ from unittest import TestCase import pytest +from pytest import approx + from pymatgen.core.structure import Molecule from pymatgen.electronic_structure.core import Spin from pymatgen.io.gaussian import GaussianInput, GaussianOutput from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/gaussian" diff --git a/tests/io/test_jarvis.py b/tests/io/test_jarvis.py index 00a6b4f88f0..75c3c78eb82 100644 --- a/tests/io/test_jarvis.py +++ b/tests/io/test_jarvis.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from pymatgen.core import Structure from pymatgen.io.jarvis import Atoms, JarvisAtomsAdaptor from pymatgen.util.testing import VASP_IN_DIR diff --git a/tests/io/test_lmto.py b/tests/io/test_lmto.py index d398bd431b3..393a0b06bae 100644 --- a/tests/io/test_lmto.py +++ b/tests/io/test_lmto.py @@ -4,6 +4,7 @@ import numpy as np from numpy.testing import assert_array_equal + from pymatgen.core.structure import Structure from pymatgen.core.units import Ry_to_eV from pymatgen.electronic_structure.core import Spin diff --git a/tests/io/test_multiwfn.py b/tests/io/test_multiwfn.py index 8d9df8e705a..ad83830a577 100644 --- a/tests/io/test_multiwfn.py +++ b/tests/io/test_multiwfn.py @@ -3,6 +3,7 @@ import copy import pytest + from pymatgen.core.structure import Molecule from pymatgen.io.multiwfn import ( QTAIM_CONDITIONALS, diff --git a/tests/io/test_nwchem.py b/tests/io/test_nwchem.py index 5f4a13c8b5a..a81dc751ec0 100644 --- a/tests/io/test_nwchem.py +++ b/tests/io/test_nwchem.py @@ -4,10 +4,11 @@ from unittest import TestCase import pytest +from pytest import approx + from pymatgen.core.structure import Molecule from pymatgen.io.nwchem import NwInput, NwInputError, NwOutput, NwTask from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/nwchem" diff --git a/tests/io/test_openff.py b/tests/io/test_openff.py index 41057daeb0b..2152b3e955e 100644 --- a/tests/io/test_openff.py +++ b/tests/io/test_openff.py @@ -5,6 +5,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.analysis.graphs import MoleculeGraph from pymatgen.analysis.local_env import OpenBabelNN from pymatgen.core import Molecule diff --git a/tests/io/test_optimade.py b/tests/io/test_optimade.py new file mode 100644 index 00000000000..63befcbca7c --- /dev/null +++ b/tests/io/test_optimade.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import numpy as np + +from pymatgen.core import Structure +from pymatgen.io.optimade import OptimadeStructureAdapter +from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR + +STRUCTURE = Structure.from_file(f"{VASP_IN_DIR}/POSCAR") +XYZ_STRUCTURE = f"{TEST_FILES_DIR}/io/xyz/acetylene.xyz" + + +def test_get_optimade_structure_roundtrip(): + optimade_structure = OptimadeStructureAdapter.get_optimade_structure(STRUCTURE) + + assert optimade_structure["attributes"]["nsites"] == len(STRUCTURE) + assert optimade_structure["attributes"]["elements"] == ["Fe", "O", "P"] + assert optimade_structure["attributes"]["nelements"] == 3 + assert optimade_structure["attributes"]["chemical_formula_reduced"] == "FeO4P" + assert optimade_structure["attributes"]["species_at_sites"] == 4 * ["Fe"] + 4 * ["P"] + 16 * ["O"] + np.testing.assert_array_almost_equal( + np.abs(optimade_structure["attributes"]["lattice_vectors"]), np.abs(STRUCTURE.lattice.matrix) + ) + + # Set an OPTIMADE ID and some custom properties and ensure they are preserved in the properties + test_id = "test_id" + optimade_structure["id"] = test_id + custom_properties = {"_custom_field": "test_custom_field", "_custom_band_gap": 2.2} + optimade_structure["attributes"].update(custom_properties) + + roundtrip_structure = OptimadeStructureAdapter.get_structure(optimade_structure) + assert roundtrip_structure.properties["optimade_id"] == test_id + assert roundtrip_structure.properties["optimade_attributes"] == custom_properties + + # Delete the properties before the check for equality + roundtrip_structure.properties = {} + assert roundtrip_structure == STRUCTURE diff --git a/tests/io/test_packmol.py b/tests/io/test_packmol.py index e82f627e310..15d3c86b547 100644 --- a/tests/io/test_packmol.py +++ b/tests/io/test_packmol.py @@ -6,6 +6,7 @@ from subprocess import TimeoutExpired import pytest + from pymatgen.analysis.molecule_matcher import MoleculeMatcher from pymatgen.core import Molecule from pymatgen.io.packmol import PackmolBoxGen diff --git a/tests/io/test_phonopy.py b/tests/io/test_phonopy.py index c3bc2e18187..f12cce1f0a0 100644 --- a/tests/io/test_phonopy.py +++ b/tests/io/test_phonopy.py @@ -7,6 +7,8 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core import Element from pymatgen.io.phonopy import ( CompletePhononDos, @@ -27,7 +29,6 @@ get_thermal_displacement_matrices, ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx try: from phonopy import Phonopy diff --git a/tests/io/test_pwscf.py b/tests/io/test_pwscf.py index 124e4784a42..76870670178 100644 --- a/tests/io/test_pwscf.py +++ b/tests/io/test_pwscf.py @@ -3,9 +3,10 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.io.pwscf import PWInput, PWInputError, PWOutput from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/pwscf" @@ -375,6 +376,32 @@ def test_read_str(self): assert_allclose(lattice, pw_in.structure.lattice.matrix) assert pw_in.sections["system"]["smearing"] == "cold" + def test_write_and_read_str(self): + struct = self.get_structure("Graphite") + struct.remove_oxidation_states() + pw = PWInput( + struct, + pseudo={"C": "C.pbe-n-kjpaw_psl.1.0.0.UPF"}, + control={"calculation": "scf", "pseudo_dir": "./"}, + system={"ecutwfc": 45}, + ) + pw_str = str(pw) + assert pw_str.strip() == str(PWInput.from_str(pw_str)).strip() + + def test_write_and_read_str_with_oxidation(self): + struct = self.get_structure("Li2O") + pw = PWInput( + struct, + control={"calculation": "scf", "pseudo_dir": "./"}, + pseudo={ + "Li+": "Li.pbe-n-kjpaw_psl.0.1.UPF", + "O2-": "O.pbe-n-kjpaw_psl.0.1.UPF", + }, + system={"ecutwfc": 50}, + ) + pw_str = str(pw) + assert pw_str.strip() == str(PWInput.from_str(pw_str)).strip() + class TestPWOutput(PymatgenTest): def setUp(self): diff --git a/tests/io/test_res.py b/tests/io/test_res.py index 2e05ee5769a..5617ad86492 100644 --- a/tests/io/test_res.py +++ b/tests/io/test_res.py @@ -1,10 +1,11 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.core import Structure from pymatgen.io.res import AirssProvider, ResParseError, ResWriter from pymatgen.util.testing import TEST_FILES_DIR -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/res" diff --git a/tests/io/test_shengbte.py b/tests/io/test_shengbte.py index 60bc58b2a4a..67779c458a1 100644 --- a/tests/io/test_shengbte.py +++ b/tests/io/test_shengbte.py @@ -4,6 +4,7 @@ import pytest from numpy.testing import assert_array_equal + from pymatgen.io.shengbte import Control from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/io/test_template_input.py b/tests/io/test_template_input.py index 313fa8fc540..3be8ae45374 100644 --- a/tests/io/test_template_input.py +++ b/tests/io/test_template_input.py @@ -3,6 +3,7 @@ import os import pytest + from pymatgen.io.template import TemplateInputGen from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/io/test_wannier90.py b/tests/io/test_wannier90.py index 32757970eeb..ca4704dfe72 100644 --- a/tests/io/test_wannier90.py +++ b/tests/io/test_wannier90.py @@ -5,9 +5,10 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.io.wannier90 import Unk from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/io/wannier90" diff --git a/tests/io/test_xcrysden.py b/tests/io/test_xcrysden.py index 88db940bcf8..3a13a2cf7b5 100644 --- a/tests/io/test_xcrysden.py +++ b/tests/io/test_xcrysden.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np + from pymatgen.core.structure import Structure from pymatgen.io.xcrysden import XSF from pymatgen.util.testing import PymatgenTest diff --git a/tests/io/test_xyz.py b/tests/io/test_xyz.py index 77e3c582efa..2314eefdc8e 100644 --- a/tests/io/test_xyz.py +++ b/tests/io/test_xyz.py @@ -4,11 +4,12 @@ import pandas as pd import pytest +from pytest import approx + from pymatgen.core import Structure from pymatgen.core.structure import Molecule from pymatgen.io.xyz import XYZ from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR -from pytest import approx class TestXYZ(TestCase): diff --git a/tests/io/test_zeopp.py b/tests/io/test_zeopp.py index 3656530cb4e..2f38d24086c 100644 --- a/tests/io/test_zeopp.py +++ b/tests/io/test_zeopp.py @@ -4,6 +4,8 @@ from unittest import TestCase import pytest +from pytest import approx + from pymatgen.analysis.bond_valence import BVAnalyzer from pymatgen.core import Molecule, Species, Structure from pymatgen.io.zeopp import ( @@ -14,7 +16,6 @@ get_voronoi_nodes, ) from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR -from pytest import approx try: import zeo diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index 4ce0c20e95f..47a7b8fcc5a 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -15,6 +15,8 @@ from monty.io import zopen from monty.serialization import loadfn from numpy.testing import assert_allclose +from pytest import MonkeyPatch, approx + from pymatgen.core import SETTINGS from pymatgen.core.composition import Composition from pymatgen.core.structure import Structure @@ -34,7 +36,6 @@ _gen_potcar_summary_stats, ) from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import MonkeyPatch, approx # make sure _gen_potcar_summary_stats runs and works with all tests in this file _summ_stats = _gen_potcar_summary_stats(append=False, vasp_psp_dir=str(FAKE_POTCAR_DIR), summary_stats_filename=None) @@ -773,6 +774,7 @@ def test_check_params(self): "AMIN": 0.01, "ICHARG": 1, "MAGMOM": [1, 2, 4, 5], + "LREAL": True, # special case: Union type "NBAND": 250, # typo in tag "METAGGA": "SCAM", # typo in value "EDIFF": 5 + 1j, # value should be a float diff --git a/tests/io/vasp/test_optics.py b/tests/io/vasp/test_optics.py index 3e660a58797..3a2caf7a94a 100644 --- a/tests/io/vasp/test_optics.py +++ b/tests/io/vasp/test_optics.py @@ -4,6 +4,7 @@ import pytest import scipy.special from numpy.testing import assert_allclose + from pymatgen.io.vasp.optics import DielectricFunctionCalculator, delta_func, delta_methfessel_paxton, step_func from pymatgen.io.vasp.outputs import Vasprun from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/io/vasp/test_outputs.py b/tests/io/vasp/test_outputs.py index fb2a9697842..a5faadac3ca 100644 --- a/tests/io/vasp/test_outputs.py +++ b/tests/io/vasp/test_outputs.py @@ -13,6 +13,8 @@ import pytest from monty.io import zopen from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core import Element from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure @@ -40,7 +42,6 @@ ) from pymatgen.io.wannier90 import Unk from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx try: import h5py diff --git a/tests/io/vasp/test_sets.py b/tests/io/vasp/test_sets.py index a798e4032cc..beaa552686b 100644 --- a/tests/io/vasp/test_sets.py +++ b/tests/io/vasp/test_sets.py @@ -10,6 +10,8 @@ from monty.json import MontyDecoder from monty.serialization import loadfn from numpy.testing import assert_allclose +from pytest import MonkeyPatch, approx, mark + from pymatgen.analysis.structure_matcher import StructureMatcher from pymatgen.core import SETTINGS, Lattice, Species, Structure from pymatgen.core.composition import Composition @@ -53,7 +55,6 @@ ) from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import MonkeyPatch, approx, mark TEST_DIR = f"{TEST_FILES_DIR}/io/vasp" diff --git a/tests/io/xtb/test_outputs.py b/tests/io/xtb/test_outputs.py index dd1afc090af..c4d3fe69777 100644 --- a/tests/io/xtb/test_outputs.py +++ b/tests/io/xtb/test_outputs.py @@ -2,11 +2,12 @@ import os +from pytest import approx + from pymatgen.core.structure import Molecule from pymatgen.io.qchem.outputs import check_for_structure_changes from pymatgen.io.xtb.outputs import CRESTOutput from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx try: from openbabel import openbabel diff --git a/tests/optimization/test_linear_assignment.py b/tests/optimization/test_linear_assignment.py index 163a1ac3132..961a5032b6d 100644 --- a/tests/optimization/test_linear_assignment.py +++ b/tests/optimization/test_linear_assignment.py @@ -4,9 +4,10 @@ import numpy as np import pytest -from pymatgen.optimization.linear_assignment import LinearAssignment from pytest import approx +from pymatgen.optimization.linear_assignment import LinearAssignment + class TestLinearAssignment(TestCase): def test(self): diff --git a/tests/optimization/test_neighbors.py b/tests/optimization/test_neighbors.py index 8d66dcd9b91..a6374c1f929 100644 --- a/tests/optimization/test_neighbors.py +++ b/tests/optimization/test_neighbors.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np + from pymatgen.core.lattice import Lattice from pymatgen.optimization.neighbors import find_points_in_spheres from pymatgen.util.testing import PymatgenTest diff --git a/tests/phonon/test_bandstructure.py b/tests/phonon/test_bandstructure.py index 6f46c3dca38..7f449792544 100644 --- a/tests/phonon/test_bandstructure.py +++ b/tests/phonon/test_bandstructure.py @@ -4,10 +4,11 @@ import json from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.electronic_structure.bandstructure import Kpoint from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/electronic_structure/bandstructure" diff --git a/tests/phonon/test_dos.py b/tests/phonon/test_dos.py index 013cfa7e8b3..0fecb44d6fd 100644 --- a/tests/phonon/test_dos.py +++ b/tests/phonon/test_dos.py @@ -5,10 +5,11 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.core import Element from pymatgen.phonon.dos import CompletePhononDos, PhononDos from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/phonon/dos" diff --git a/tests/phonon/test_gruneisen.py b/tests/phonon/test_gruneisen.py index 3040db72cf3..36f1d66fa22 100644 --- a/tests/phonon/test_gruneisen.py +++ b/tests/phonon/test_gruneisen.py @@ -4,11 +4,12 @@ import numpy as np import pytest from matplotlib import colors +from pytest import approx + from pymatgen.io.phonopy import get_gruneisen_ph_bs_symm_line, get_gruneisenparameter from pymatgen.phonon.gruneisen import GruneisenParameter from pymatgen.phonon.plotter import GruneisenPhononBandStructureSymmLine, GruneisenPhononBSPlotter, GruneisenPlotter from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx try: import phonopy diff --git a/tests/phonon/test_ir_spectra.py b/tests/phonon/test_ir_spectra.py index 11bae39d9f7..7a7eb40f4ce 100644 --- a/tests/phonon/test_ir_spectra.py +++ b/tests/phonon/test_ir_spectra.py @@ -1,6 +1,7 @@ from __future__ import annotations from monty.serialization import loadfn + from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/phonon/test_plotter.py b/tests/phonon/test_plotter.py index 2e0c762804b..01e657e7eaf 100644 --- a/tests/phonon/test_plotter.py +++ b/tests/phonon/test_plotter.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import pytest from numpy.testing import assert_allclose + from pymatgen.phonon import CompletePhononDos, PhononBandStructureSymmLine from pymatgen.phonon.plotter import PhononBSPlotter, PhononDosPlotter, ThermoPlotter from pymatgen.util.testing import TEST_FILES_DIR diff --git a/tests/phonon/test_thermal_displacements.py b/tests/phonon/test_thermal_displacements.py index 315680e28d8..725c197263f 100644 --- a/tests/phonon/test_thermal_displacements.py +++ b/tests/phonon/test_thermal_displacements.py @@ -2,10 +2,11 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx + from pymatgen.core.structure import Structure from pymatgen.phonon.thermal_displacements import ThermalDisplacementMatrices from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/phonon/thermal_displacement_matrices" diff --git a/tests/symmetry/test_analyzer.py b/tests/symmetry/test_analyzer.py index 41c7b72e6c1..14de4656202 100644 --- a/tests/symmetry/test_analyzer.py +++ b/tests/symmetry/test_analyzer.py @@ -6,6 +6,9 @@ import numpy as np import pytest from numpy.testing import assert_allclose +from pytest import approx, raises +from spglib import SpglibDataset + from pymatgen.core import Lattice, Molecule, PeriodicSite, Site, Species, Structure from pymatgen.io.vasp.outputs import Vasprun from pymatgen.symmetry.analyzer import ( @@ -17,8 +20,6 @@ ) from pymatgen.symmetry.structure import SymmetrizedStructure from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -from pytest import approx, raises -from spglib import SpglibDataset TEST_DIR = f"{TEST_FILES_DIR}/symmetry/analyzer" diff --git a/tests/symmetry/test_groups.py b/tests/symmetry/test_groups.py index ae9b3a0a9f7..5a2b6fd66a4 100644 --- a/tests/symmetry/test_groups.py +++ b/tests/symmetry/test_groups.py @@ -2,10 +2,11 @@ import numpy as np import pytest +from pytest import approx + from pymatgen.core.lattice import Lattice from pymatgen.core.operations import SymmOp from pymatgen.symmetry.groups import SYMM_DATA, PointGroup, SpaceGroup -from pytest import approx __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Virtual Lab" diff --git a/tests/symmetry/test_kpath_hin.py b/tests/symmetry/test_kpath_hin.py index 63dab783ac2..61d71fd9dd3 100644 --- a/tests/symmetry/test_kpath_hin.py +++ b/tests/symmetry/test_kpath_hin.py @@ -1,11 +1,12 @@ from __future__ import annotations import pytest +from pytest import approx + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.symmetry.kpath import KPathSeek from pymatgen.util.testing import PymatgenTest -from pytest import approx try: from seekpath import get_path diff --git a/tests/symmetry/test_kpath_lm.py b/tests/symmetry/test_kpath_lm.py index e72ff7ef9d7..dbc8187e6a9 100644 --- a/tests/symmetry/test_kpath_lm.py +++ b/tests/symmetry/test_kpath_lm.py @@ -1,13 +1,14 @@ from __future__ import annotations import numpy as np +from pytest import approx + from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.symmetry.kpath import KPathLatimerMunro from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx class TestKPathLatimerMunro(PymatgenTest): diff --git a/tests/symmetry/test_kpath_sc.py b/tests/symmetry/test_kpath_sc.py index 2dd37d0d77e..d88c97f79dd 100644 --- a/tests/symmetry/test_kpath_sc.py +++ b/tests/symmetry/test_kpath_sc.py @@ -1,11 +1,12 @@ from __future__ import annotations import numpy as np +from pytest import approx + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.symmetry.kpath import KPathSetyawanCurtarolo from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -from pytest import approx TEST_DIR = f"{TEST_FILES_DIR}/symmetry/space_group_structs" diff --git a/tests/symmetry/test_kpaths.py b/tests/symmetry/test_kpaths.py index f13ae4b7a9b..2c907e87135 100644 --- a/tests/symmetry/test_kpaths.py +++ b/tests/symmetry/test_kpaths.py @@ -4,6 +4,7 @@ import pytest from monty.serialization import loadfn + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.symmetry.bandstructure import HighSymmKpath diff --git a/tests/symmetry/test_maggroups.py b/tests/symmetry/test_maggroups.py index cc4d529b87c..72f184d553f 100644 --- a/tests/symmetry/test_maggroups.py +++ b/tests/symmetry/test_maggroups.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.core.lattice import Lattice from pymatgen.symmetry.groups import SpaceGroup from pymatgen.symmetry.maggroups import MagneticSpaceGroup diff --git a/tests/symmetry/test_settings.py b/tests/symmetry/test_settings.py index a6a8ff3e80a..24db5b2b6dc 100644 --- a/tests/symmetry/test_settings.py +++ b/tests/symmetry/test_settings.py @@ -4,6 +4,7 @@ import numpy as np from numpy.testing import assert_allclose + from pymatgen.symmetry.settings import JonesFaithfulTransformation, Lattice, SymmOp __author__ = "Matthew Horton" diff --git a/tests/symmetry/test_site_symmetries.py b/tests/symmetry/test_site_symmetries.py index c5696b31f46..21972990e8b 100644 --- a/tests/symmetry/test_site_symmetries.py +++ b/tests/symmetry/test_site_symmetries.py @@ -4,6 +4,7 @@ import json from monty.json import MontyDecoder + from pymatgen.symmetry import site_symmetries as ss from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest diff --git a/tests/test_cli.py b/tests/test_cli.py index 1893d0a5738..225664313dc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import pytest + from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR if TYPE_CHECKING: diff --git a/tests/transformations/test_advanced_transformations.py b/tests/transformations/test_advanced_transformations.py index 653c098bb19..f0bf1d0a7f8 100644 --- a/tests/transformations/test_advanced_transformations.py +++ b/tests/transformations/test_advanced_transformations.py @@ -7,6 +7,8 @@ import pytest from monty.serialization import loadfn from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.analysis.energy_models import IsingModel, SymmetryModel from pymatgen.analysis.gb.grain import GrainBoundaryGenerator from pymatgen.core import Lattice, Molecule, Species, Structure @@ -39,7 +41,6 @@ SubstitutionTransformation, ) from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, PymatgenTest -from pytest import approx try: import hiphive diff --git a/tests/transformations/test_site_transformations.py b/tests/transformations/test_site_transformations.py index cddc9bfb53f..3b30670fad3 100644 --- a/tests/transformations/test_site_transformations.py +++ b/tests/transformations/test_site_transformations.py @@ -6,6 +6,7 @@ import numpy as np import pytest from numpy.testing import assert_allclose + from pymatgen.core.structure import Molecule, Structure from pymatgen.transformations.site_transformations import ( AddSitePropertyTransformation, diff --git a/tests/transformations/test_standard_transformations.py b/tests/transformations/test_standard_transformations.py index 029fad0f8e7..331ed94b9c9 100644 --- a/tests/transformations/test_standard_transformations.py +++ b/tests/transformations/test_standard_transformations.py @@ -9,6 +9,8 @@ import numpy as np import pytest from monty.json import MontyDecoder +from pytest import approx + from pymatgen.core import Element, PeriodicSite from pymatgen.core.lattice import Lattice from pymatgen.symmetry.structure import SymmetrizedStructure @@ -34,7 +36,6 @@ SupercellTransformation, ) from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR -from pytest import approx enumlib_present = which("enum.x") and which("makestr.x") diff --git a/tests/util/test_coord.py b/tests/util/test_coord.py index 433831f3c5d..15cfab9d7dc 100644 --- a/tests/util/test_coord.py +++ b/tests/util/test_coord.py @@ -6,9 +6,10 @@ import numpy as np import pytest from numpy.testing import assert_allclose, assert_array_equal +from pytest import approx + from pymatgen.core.lattice import Lattice from pymatgen.util import coord -from pytest import approx class TestCoordUtils: diff --git a/tests/util/test_num.py b/tests/util/test_num.py index 8891e37ca0f..e391ba92412 100644 --- a/tests/util/test_num.py +++ b/tests/util/test_num.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from pymatgen.util.num import round_to_sigfigs diff --git a/tests/util/test_plotting.py b/tests/util/test_plotting.py index 85d91b457e8..ae0b6eca76b 100644 --- a/tests/util/test_plotting.py +++ b/tests/util/test_plotting.py @@ -1,6 +1,7 @@ from __future__ import annotations import matplotlib.pyplot as plt + from pymatgen.util.plotting import periodic_table_heatmap, van_arkel_triangle from pymatgen.util.testing import PymatgenTest diff --git a/tests/util/test_provenance.py b/tests/util/test_provenance.py index 9019ee50ebd..7cb2e9fa96d 100644 --- a/tests/util/test_provenance.py +++ b/tests/util/test_provenance.py @@ -7,6 +7,7 @@ import numpy as np import pytest + from pymatgen.core.structure import Molecule, Structure from pymatgen.util.provenance import Author, HistoryNode, StructureNL @@ -19,7 +20,7 @@ __date__ = "2/14/13" -class StructureNLCase(TestCase): +class TestStructureNL(TestCase): def setUp(self): # set up a Structure self.struct = Structure(np.eye(3, 3) * 3, ["Fe"], [[0, 0, 0]]) @@ -213,8 +214,8 @@ def test_as_from_dict(self): {"_my_data": "string"}, [self.valid_node, self.valid_node2], ) - b = StructureNL.from_dict(struct_nl.as_dict()) - assert struct_nl == b + round_trip_from_dict = StructureNL.from_dict(struct_nl.as_dict()) + assert struct_nl == round_trip_from_dict # complicated objects in the 'data' and 'nodes' field complicated_node = { "name": "complicated node", @@ -230,15 +231,15 @@ def test_as_from_dict(self): {"_my_data": {"structure": self.s2}}, [complicated_node, self.valid_node], ) - b = StructureNL.from_dict(struct_nl.as_dict()) + round_trip_from_dict = StructureNL.from_dict(struct_nl.as_dict()) assert ( - struct_nl == b + struct_nl == round_trip_from_dict ), "to/from dict is broken when object embedding is used! Apparently MontyEncoding is broken..." # Test molecule mol_nl = StructureNL(self.mol, self.hulk, references=self.pmg) - b = StructureNL.from_dict(mol_nl.as_dict()) - assert mol_nl == b + round_trip_from_dict = StructureNL.from_dict(mol_nl.as_dict()) + assert mol_nl == round_trip_from_dict def test_from_structures(self): s1 = Structure(np.eye(3) * 5, ["Fe"], [[0, 0, 0]]) diff --git a/tests/util/test_string.py b/tests/util/test_string.py index 1bb0785cef3..5c04a1d51ea 100644 --- a/tests/util/test_string.py +++ b/tests/util/test_string.py @@ -2,6 +2,7 @@ import numpy as np import pytest + from pymatgen.core import Structure from pymatgen.util.string import ( Stringify, diff --git a/tests/util/test_typing.py b/tests/util/test_typing.py index 9b56e8d4705..994e8731af6 100644 --- a/tests/util/test_typing.py +++ b/tests/util/test_typing.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, get_args import pytest + from pymatgen.core import Composition, DummySpecies, Element, Species from pymatgen.entries import Entry from pymatgen.util.typing import CompositionLike, EntryLike, PathLike, PbcLike, SpeciesLike diff --git a/tests/vis/test_plotters.py b/tests/vis/test_plotters.py index ab9c95ad547..ca0594b4c2d 100644 --- a/tests/vis/test_plotters.py +++ b/tests/vis/test_plotters.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np from monty.json import MontyDecoder + from pymatgen.analysis.xas.spectrum import XAS from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest from pymatgen.vis.plotters import SpectrumPlotter