Skip to content

Commit

Permalink
write test-created files to temporary directory (#3454)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh authored Nov 3, 2023
1 parent 8ce1cdd commit 2ccbfa1
Show file tree
Hide file tree
Showing 23 changed files with 114 additions and 170 deletions.
3 changes: 1 addition & 2 deletions pymatgen/analysis/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import scipy.constants as const
from monty.design_patterns import singleton
from monty.string import unicode2str

from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram
from pymatgen.core.composition import Composition
Expand Down Expand Up @@ -94,7 +93,7 @@ def __init__(self, filename):
self._chemsys_entries = defaultdict(list)
filename = os.path.join(os.path.dirname(__file__), filename)
with open(filename) as f:
reader = csv.reader(f, quotechar=unicode2str("|"))
reader = csv.reader(f, quotechar="|")
for row in reader:
comp = Composition(row[0])
cost_per_mol = float(row[1]) * comp.weight.to("kg") * const.N_A
Expand Down
3 changes: 1 addition & 2 deletions pymatgen/core/xcfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from monty.functools import lazy_property
from monty.json import MSONable
from monty.string import is_string

from pymatgen.core.libxcfunc import LibxcFunc

Expand Down Expand Up @@ -122,7 +121,7 @@ def asxc(cls, obj):
"""Convert object into Xcfunc."""
if isinstance(obj, cls):
return obj
if is_string(obj):
if isinstance(obj, str):
return cls.from_name(obj)
raise TypeError(f"Don't know how to convert <{type(obj)}:{obj}> to Xcfunc")

Expand Down
9 changes: 4 additions & 5 deletions pymatgen/entries/entry_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import TYPE_CHECKING, Literal

from monty.json import MontyDecoder, MontyEncoder, MSONable
from monty.string import unicode2str

from pymatgen.analysis.phase_diagram import PDEntry
from pymatgen.analysis.structure_matcher import SpeciesComparator, StructureMatcher
Expand Down Expand Up @@ -302,8 +301,8 @@ def to_csv(self, filename: str, latexify_names: bool = False) -> None:
with open(filename, "w") as f:
writer = csv.writer(
f,
delimiter=unicode2str(","),
quotechar=unicode2str('"'),
delimiter=",",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
)
writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"])
Expand All @@ -326,8 +325,8 @@ def from_csv(cls, filename: str):
with open(filename, encoding="utf-8") as f:
reader = csv.reader(
f,
delimiter=unicode2str(","),
quotechar=unicode2str('"'),
delimiter=",",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
)
entries = []
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/io/abinit/abitimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import matplotlib.pyplot as plt
import numpy as np
from monty.string import is_string, list_strings

from pymatgen.io.core import ParseError
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig
Expand Down Expand Up @@ -107,7 +106,8 @@ def parse(self, filenames):
Return: list of successfully read files.
"""
filenames = list_strings(filenames)
if isinstance(filenames, str):
filenames = [filenames]

read_ok = []
for fname in filenames:
Expand Down Expand Up @@ -667,16 +667,16 @@ def get_section(self, section_name):

def to_csv(self, fileobj=sys.stdout):
"""Write data on file fileobj using CSV format."""
openclose = is_string(fileobj)
is_str = isinstance(fileobj, str)

if openclose:
if is_str:
fileobj = open(fileobj, "w") # noqa: SIM115

for idx, section in enumerate(self.sections):
fileobj.write(section.to_csvline(with_header=(idx == 0)))
fileobj.flush()

if openclose:
if is_str:
fileobj.close()

def to_table(self, sort_key="wall_time", stop=None):
Expand Down Expand Up @@ -718,7 +718,7 @@ def get_dataframe(self, sort_key="wall_time", **kwargs):

def get_values(self, keys):
"""Return a list of values associated to a particular list of keys."""
if is_string(keys):
if isinstance(keys, str):
return [s.__dict__[keys] for s in self.sections]
values = []
for k in keys:
Expand Down
24 changes: 14 additions & 10 deletions pymatgen/io/abinit/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np
from monty.collections import AttrDict
from monty.json import MSONable
from monty.string import is_string, list_strings

from pymatgen.core.structure import Structure
from pymatgen.io.abinit import abiobjects as aobj
Expand Down Expand Up @@ -99,9 +98,7 @@


# Default values used if user does not specify them
_DEFAULTS = {
"kppa": 1000,
}
_DEFAULTS = {"kppa": 1000}


def as_structure(obj):
Expand All @@ -115,7 +112,7 @@ def as_structure(obj):
if isinstance(obj, Structure):
return obj

if is_string(obj):
if isinstance(obj, str):
return Structure.from_file(obj)

if isinstance(obj, Mapping):
Expand Down Expand Up @@ -151,7 +148,7 @@ def from_object(cls, obj):
"""
if isinstance(obj, cls):
return obj
if is_string(obj):
if isinstance(obj, str):
return cls(obj[0].upper())
raise TypeError(f"The object provided is not handled: type {type(obj).__name__}")

Expand Down Expand Up @@ -676,8 +673,10 @@ def remove_vars(self, keys, strict=True):
keys: string or list of strings with variable names.
strict: If True, KeyError is raised if at least one variable is not present.
"""
if isinstance(keys, str):
keys = [keys]
removed = {}
for key in list_strings(keys):
for key in keys:
if strict and key not in self:
raise KeyError(f"{key=} not in self:\n {list(self)}")
if key in self:
Expand Down Expand Up @@ -710,7 +709,7 @@ class BasicAbinitInput(AbstractInput, MSONable):
def __init__(
self,
structure,
pseudos,
pseudos: str | list[str] | list[Pseudo] | PseudoTable,
pseudo_dir=None,
comment=None,
abi_args=None,
Expand Down Expand Up @@ -745,11 +744,14 @@ def __init__(
self._vars = dict(args)
self.set_structure(structure)

if isinstance(pseudos, str):
pseudos = [pseudos]

if pseudo_dir is not None:
pseudo_dir = os.path.abspath(pseudo_dir)
if not os.path.exists(pseudo_dir):
raise self.Error(f"Directory {pseudo_dir} does not exist")
pseudos = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]
pseudos = [os.path.join(pseudo_dir, p) for p in pseudos]

try:
self._pseudos = PseudoTable.as_table(pseudos).get_pseudos_for_structure(self.structure)
Expand Down Expand Up @@ -1092,8 +1094,10 @@ def __init__(self, structure: Structure, pseudos, pseudo_dir="", ndtset=1):

else:
# String(s)
if isinstance(pseudos, str):
pseudos = [pseudos]
pseudo_dir = os.path.abspath(pseudo_dir)
pseudo_paths = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)]
pseudo_paths = [os.path.join(pseudo_dir, p) for p in pseudos]

missing = [p for p in pseudo_paths if not os.path.exists(p)]
if missing:
Expand Down
11 changes: 6 additions & 5 deletions pymatgen/io/abinit/pseudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from monty.itertools import iterator_from_slice
from monty.json import MontyDecoder, MSONable
from monty.os.path import find_exts
from monty.string import is_string, list_strings
from tabulate import tabulate

from pymatgen.core.periodic_table import Element
Expand Down Expand Up @@ -612,7 +611,7 @@ def _dict_from_lines(lines, key_nums, sep=None):
Raises:
ValueError if parsing fails.
"""
if is_string(lines):
if isinstance(lines, str):
lines = [lines]

if not isinstance(key_nums, collections.abc.Iterable):
Expand Down Expand Up @@ -1611,8 +1610,8 @@ def __init__(self, pseudos: Sequence[Pseudo]) -> None:
if not isinstance(pseudos, collections.abc.Iterable):
pseudos = [pseudos]

if len(pseudos) and is_string(pseudos[0]):
pseudos = list_strings(pseudos)
if isinstance(pseudos, str):
pseudos = [pseudos]

self._pseudos_with_z = defaultdict(list)

Expand Down Expand Up @@ -1772,7 +1771,9 @@ def select_symbols(self, symbols, ret_list=False):
Prepend the symbol string with "-", to exclude pseudos.
ret_list: if True a list of pseudos is returned instead of a PseudoTable
"""
symbols = list_strings(symbols)
if isinstance(symbols, str):
symbols = [symbols]

exclude = symbols[0].startswith("-")

if exclude:
Expand Down
3 changes: 1 addition & 2 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np
from monty.io import zopen
from monty.serialization import loadfn
from monty.string import remove_non_ascii

from pymatgen.core.composition import Composition
from pymatgen.core.lattice import Lattice
Expand Down Expand Up @@ -137,7 +136,7 @@ def _process_string(cls, string):
# remove empty lines
string = re.sub(r"^\s*\n", "", string, flags=re.MULTILINE)
# remove non_ascii
string = remove_non_ascii(string)
string = string.encode("ascii", "ignore").decode("ascii")
# since line breaks in .cif files are mostly meaningless,
# break up into a stream of tokens to parse, rejoining multiline
# strings (between semicolons)
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/vasp/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def __init__(

updates: dict[str, float] = {}
# select the KSPACING and smearing parameters based on the bandgap
if self.bandgap < 1e-4:
if self.bandgap < bandgap_tol:
updates.update(KSPACING=0.22, SIGMA=0.2, ISMEAR=2)
else:
rmin = max(1.5, 25.22 - 2.87 * bandgap) # Eq. 25
Expand Down
9 changes: 4 additions & 5 deletions pymatgen/util/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from io import StringIO

from monty.json import MontyDecoder, MontyEncoder
from monty.string import remove_non_ascii

try:
from pybtex import errors
Expand All @@ -29,10 +28,10 @@
__credits__ = "Dan Gunter"


MAX_HNODE_SIZE = 64000 # maximum size (bytes) of SNL HistoryNode
MAX_DATA_SIZE = 256000 # maximum size (bytes) of SNL data field
MAX_HNODE_SIZE = 64_000 # maximum size (bytes) of SNL HistoryNode
MAX_DATA_SIZE = 256_000 # maximum size (bytes) of SNL data field
MAX_HNODES = 100 # maximum number of HistoryNodes in SNL file
MAX_BIBTEX_CHARS = 20000 # maximum number of characters for BibTeX reference
MAX_BIBTEX_CHARS = 20_000 # maximum number of characters for BibTeX reference


def is_valid_bibtex(reference: str) -> bool:
Expand All @@ -46,7 +45,7 @@ def is_valid_bibtex(reference: str) -> bool:
"""
# str is necessary since pybtex seems to have an issue with unicode. The
# filter expression removes all non-ASCII characters.
sio = StringIO(remove_non_ascii(reference))
sio = StringIO(reference.encode("ascii", "ignore").decode("ascii"))
parser = bibtex.Parser()
errors.set_strict_mode(enable=False)
bib_data = parser.parse_stream(sio)
Expand Down
6 changes: 2 additions & 4 deletions tests/core/test_periodic_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
import os
import pickle
import unittest
from copy import deepcopy
Expand Down Expand Up @@ -405,13 +404,12 @@ def test_pickle(self):
cs = Species("Cs1+")
cl = Species("Cl1+")

with open("cscl.pickle", "wb") as file:
with open(f"{self.tmp_path}/cscl.pickle", "wb") as file:
pickle.dump((cs, cl), file)

with open("cscl.pickle", "rb") as file:
with open(f"{self.tmp_path}/cscl.pickle", "rb") as file:
tup = pickle.load(file)
assert tup == (cs, cl)
os.remove("cscl.pickle")

def test_get_crystal_field_spin(self):
assert Species("Fe", 2).get_crystal_field_spin() == 4
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,9 +2085,8 @@ def test_to_from_file_string(self):
assert m == self.mol
assert isinstance(m, Molecule)

self.mol.to(filename="CH4_testing.xyz")
assert os.path.isfile("CH4_testing.xyz")
os.remove("CH4_testing.xyz")
self.mol.to(filename=f"{self.tmp_path}/CH4_testing.xyz")
assert os.path.isfile(f"{self.tmp_path}/CH4_testing.xyz")

def test_extract_cluster(self):
species = self.mol.species * 2
Expand Down
15 changes: 6 additions & 9 deletions tests/core/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import copy
import os

import numpy as np
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -448,12 +447,11 @@ def test_variable_lattice(self):
assert all(np.allclose(struct.lattice.matrix, structures[i].lattice.matrix) for i, struct in enumerate(traj))

# Check if the file is written correctly when lattice is not constant.
traj.write_Xdatcar(filename="traj_test_XDATCAR")
traj.write_Xdatcar(filename=f"{self.tmp_path}/traj_test_XDATCAR")

# Load trajectory from written xdatcar and compare to original
written_traj = Trajectory.from_file("traj_test_XDATCAR", constant_lattice=False)
# Load trajectory from written XDATCAR and compare to original
written_traj = Trajectory.from_file(f"{self.tmp_path}/traj_test_XDATCAR", constant_lattice=False)
self._check_traj_equality(traj, written_traj)
os.remove("traj_test_XDATCAR")

def test_as_from_dict(self):
d = self.traj.as_dict()
Expand All @@ -465,9 +463,8 @@ def test_as_from_dict(self):
assert isinstance(traj, Trajectory)

def test_xdatcar_write(self):
self.traj.write_Xdatcar(filename="traj_test_XDATCAR")
self.traj.write_Xdatcar(filename=f"{self.tmp_path}/traj_test_XDATCAR")

# Load trajectory from written xdatcar and compare to original
written_traj = Trajectory.from_file("traj_test_XDATCAR")
# Load trajectory from written XDATCAR and compare to original
written_traj = Trajectory.from_file(f"{self.tmp_path}/traj_test_XDATCAR")
self._check_traj_equality(self.traj, written_traj)
os.remove("traj_test_XDATCAR")
Loading

0 comments on commit 2ccbfa1

Please sign in to comment.