diff --git a/castep_outputs/bin_parsers/__init__.py b/castep_outputs/bin_parsers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/castep_outputs/bin_parsers/cst_esp_file_parser.py b/castep_outputs/bin_parsers/cst_esp_file_parser.py
new file mode 100644
index 0000000..ed5f84e
--- /dev/null
+++ b/castep_outputs/bin_parsers/cst_esp_file_parser.py
@@ -0,0 +1,52 @@
+"""Parser for cst_esp files."""
+from __future__ import annotations
+
+from typing import BinaryIO, TypedDict
+
+from ..utilities.utility import to_type
+from .fortran_bin_parser import binary_file_reader
+
+
+class ESPData(TypedDict):
+    """Data from electrostatic potential."""
+
+    #: Number of spins in run.
+    n_spins: int
+    #: Grid size sampled at.
+    grid: tuple[int, int, int]
+    #: Data, one tuple of complex values per record read from the file.
+    val: list[tuple[complex, ...]]
+
+
+def parse_cst_esp_file(cst_esp_file: BinaryIO) -> ESPData:
+    """Parse castep `cst_esp` files.
+
+    Parameters
+    ----------
+    cst_esp_file : BinaryIO
+        File to parse.
+
+    Returns
+    -------
+    ESPData
+        Parsed data.
+
+    Notes
+    -----
+    The first two records hold ``n_spins`` and ``grid``; every
+    following record starts with 8 bytes of indices, which are
+    discarded, followed by the complex values themselves.
+    """
+    dtypes = {"n_spins": int, "grid": int}
+
+    accum = {"val": []}
+
+    reader = binary_file_reader(cst_esp_file)
+    # `dtypes` is the first `zip` argument so it stops before consuming a data record.
+    for (key, typ), datum in zip(dtypes.items(), reader):
+        accum[key] = to_type(datum, typ)
+
+    for datum in reader:
+        accum["val"].append(to_type(datum[8:], complex))  # Discard indices
+
+    return accum
diff --git a/castep_outputs/bin_parsers/fortran_bin_parser.py b/castep_outputs/bin_parsers/fortran_bin_parser.py
new file mode 100644
index 0000000..2f480b7
--- /dev/null
+++ b/castep_outputs/bin_parsers/fortran_bin_parser.py
@@ -0,0 +1,49 @@
+"""General parser for the Fortran Unformatted file format."""
+from os import SEEK_CUR
+from typing import BinaryIO, Generator
+
+FortranBinaryReader = Generator[bytes, int, None]
+
+
+def binary_file_reader(file: BinaryIO) -> FortranBinaryReader:
+    """Yield the elements of a Fortran unformatted file.
+
+    Each record is read as a 4-byte big-endian size, a payload of
+    that many bytes and a trailing 4-byte marker.
+
+    Parameters
+    ----------
+    file : BinaryIO
+        File to read; it must be seekable if records are skipped.
+
+    Yields
+    ------
+    bytes
+        Payload of the next record.
+
+    Notes
+    -----
+    A value passed in with ``send`` moves the read position before
+    the next record is yielded: positive ``n`` skips ``n`` records
+    forwards, negative ``n`` (or `True`, treated as ``-1``) steps
+    back ``abs(n)`` records from the one just yielded.
+    """
+    while bin_size := file.read(4):
+        size = int.from_bytes(bin_size, "big")
+        data = file.read(size)
+        skip = yield data
+        file.read(4)  # Trailing record marker, value unused here.
+        if skip:  # NB. Send proceeds to yield.
+            # `True` implies rewind 1
+            if skip < 0 or skip is True:
+                for _ in range(abs(skip)):
+                    # Rewind to record size before last read
+                    file.seek(-size-12, SEEK_CUR)
+                    size = int.from_bytes(file.read(4), "big")
+
+                # Rewind one extra (which will be yielded)
+                file.seek(-size-8, SEEK_CUR)
+            else:
+                for _ in range(skip):
+                    size = int.from_bytes(file.read(4), "big")
+                    file.seek(size+4, SEEK_CUR)
diff --git a/castep_outputs/parsers/__init__.py b/castep_outputs/parsers/__init__.py
index 62303d5..34a5050 100644
--- a/castep_outputs/parsers/__init__.py
+++ b/castep_outputs/parsers/__init__.py
@@ -3,6 +3,7 @@
 
 from collections.abc import Callable
 
+from ..bin_parsers.cst_esp_file_parser import parse_cst_esp_file
 from .bands_file_parser import parse_bands_file
 from .castep_file_parser import parse_castep_file
 from .cell_param_file_parser import parse_cell_param_file
@@ -30,6 +31,7 @@
     "parse_cell_param_file",
     "parse_cell_param_file",
     "parse_chdiff_fmt_file",
+    "parse_cst_esp_file",
     "parse_den_fmt_file",
     "parse_efield_file",
     "parse_elastic_file",
@@ -70,6 +72,7 @@
     "tddft": parse_tddft_file,
     "err": parse_err_file,
     "phonon": parse_phonon_file,
+    "cst_esp": parse_cst_esp_file,
 }
 
 #: Names of parsers/parsable file extensions (without ``"."``).
diff --git a/castep_outputs/utilities/utility.py b/castep_outputs/utilities/utility.py
index d15d28c..761489a 100644
--- a/castep_outputs/utilities/utility.py
+++ b/castep_outputs/utilities/utility.py
@@ -10,7 +10,9 @@
 from collections import defaultdict
 from collections.abc import Callable, Iterable, Iterator, MutableMapping, Sequence
 from copy import copy
+from functools import partial
 from itertools import filterfalse
+from struct import unpack
 from typing import Any, TextIO, TypeVar
 
 import castep_outputs.utilities.castep_res as REs
@@ -478,8 +480,39 @@ def _parse_logical(val: str) -> bool:
     return val.title() in ("T", "True", "1")
 
 
+def _parse_float_bytes(val: bytes):
+    """Unpack big-endian 8-byte floats; a lone value is returned bare."""
+    ans = unpack(f">{len(val)//8}d", val)
+    return ans if len(ans) != 1 else ans[0]
+
+
+def _parse_int_bytes(val: bytes):
+    """Unpack big-endian 4-byte ints; a lone value is returned bare."""
+    ans = unpack(f">{len(val)//4}i", val)
+    return ans if len(ans) != 1 else ans[0]
+
+
+def _parse_bool_bytes(val: bytes):
+    """Unpack big-endian 4-byte ints as bools; a lone value is returned bare."""
+    # Convert element-wise: `bool` of the whole tuple would have no `len`.
+    ans = tuple(map(bool, unpack(f">{len(val)//4}i", val)))
+    return ans if len(ans) != 1 else ans[0]
+
+
+def _parse_complex_bytes(val: bytes):
+    """Unpack big-endian (real, imag) pairs of 8-byte floats as complex."""
+    tmp = unpack(f">{len(val)//8}d", val)
+    ans = tuple(map(complex, tmp[::2], tmp[1::2]))
+    return ans if len(ans) != 1 else ans[0]
+
+
 _TYPE_PARSERS: dict[type, Callable] = {float: _parse_float_or_rational,
                                        bool: _parse_logical}
+_BYTE_PARSERS: dict[type, Callable] = {complex: _parse_complex_bytes,
+                                       float: _parse_float_bytes,
+                                       bool: _parse_bool_bytes,
+                                       int: _parse_int_bytes,
+                                       str: partial(str, encoding="ascii")}
 
 
 @functools.singledispatch
@@ -511,10 +544,18 @@ def _(data_in: str, typ: type[T]) -> T:
 @to_type.register(tuple)
 @to_type.register(list)
 def _(data_in, typ: type[T]) -> tuple[T, ...]:
-    parser: Callable[[str], T] = _TYPE_PARSERS.get(typ, typ)
+    parse_dict = _BYTE_PARSERS if data_in and isinstance(data_in[0], bytes) else _TYPE_PARSERS
+    parser: Callable[[str], T] = parse_dict.get(typ, typ)
     return tuple(parser(x) for x in data_in)
 
 
+@to_type.register(bytes)
+def _(data_in, typ: type[T]) -> T | tuple[T, ...]:
+    """Convert raw bytes using `_BYTE_PARSERS`, falling back to `typ` itself."""
+    parser: Callable[[bytes], T] = _BYTE_PARSERS.get(typ, typ)
+    return parser(data_in)
+
+
 def fix_data_types(in_dict: MutableMapping[str, Any], type_dict: dict[str, type]):
     """
     Apply correct types to elements of `in_dict` by mapping given in `type_dict`.
diff --git a/test/test_binary_reader.py b/test/test_binary_reader.py
new file mode 100644
index 0000000..596717d
--- /dev/null
+++ b/test/test_binary_reader.py
@@ -0,0 +1,94 @@
+import struct
+from functools import singledispatch
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from castep_outputs.bin_parsers.fortran_bin_parser import binary_file_reader
+from castep_outputs.utilities.utility import to_type
+
+DATA_FILES = Path(__file__).parent / "data_files"
+
+
+@singledispatch
+def fort_unformat(inp):
+    """Encode `inp` as a single size-delimited record."""
+    raise NotImplementedError(f"Cannot convert {type(inp).__name__}.")
+
+
+@fort_unformat.register(list)
+@fort_unformat.register(tuple)
+def _(inp):
+    return add_size(b"".join(map(to_bytes, inp)))
+
+
+@fort_unformat.register(float)
+@fort_unformat.register(int)
+@fort_unformat.register(str)
+def _(inp):
+    return add_size(to_bytes(inp))
+
+
+@singledispatch
+def to_bytes(inp):
+    """Pack a single value as big-endian bytes."""
+    raise NotImplementedError(f"Cannot convert {type(inp).__name__}.")
+
+
+@to_bytes.register
+def _(inp: float):
+    return struct.pack(">d", inp)
+
+
+@to_bytes.register
+def _(inp: int):
+    return struct.pack(">i", inp)
+
+
+@to_bytes.register
+def _(inp: str):
+    return struct.pack(f">{len(inp)}s", bytes(inp, encoding="utf-8"))
+
+
+def add_size(inp: bytes):
+    """Wrap `inp` in leading and trailing 4-byte size markers."""
+    size = to_bytes(len(inp))
+    return b"".join((size, inp, size))
+
+
+def to_unformat_file(*inp: Any):
+    """Build an in-memory file with one record per argument."""
+    return BytesIO(b"".join(map(fort_unformat, inp)))
+
+
+def test_binary_file_reader():
+    """Round-trip each supported type through the reader and `to_type`."""
+    # `True` exercises the bool byte parser (packed via the `int` handlers).
+    sample = (1, 2, 3, 2, 3.1, "Hello", True, (1., 3., 6.))
+
+    data, types = to_unformat_file(*sample), (*map(type, sample[:-1]), float)
+    reader = binary_file_reader(data)
+
+    for datum, typ, expected in zip(reader, types, sample):
+        assert to_type(datum, typ) == expected
+
+
+def test_actual_read():
+    """Check the header records, then the indices and length of every data record."""
+    header = ((int, 2), (int, (27, 27, 27)))
+
+    with (DATA_FILES / "test.cst_esp").open("rb") as file:
+        records = binary_file_reader(file)
+        for (kind, want), raw in zip(header, records):
+            assert to_type(raw, kind) == want
+        for row, raw in enumerate(records, 1):
+            nx, ny = to_type(raw[:8], int)  # Leading 8 bytes hold two int indices
+            assert (nx, ny) == (1, row)
+            values = to_type(raw[8:], complex)
+            assert len(values) == 27
+
+
+if __name__ == "__main__":
+    pytest.main()