Skip to content

Commit

Permalink
Add basic fortran binary file parser
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Feb 12, 2025
1 parent f7e2e6d commit e9c2cb7
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 1 deletion.
Empty file.
60 changes: 60 additions & 0 deletions castep_outputs/bin_parsers/cst_esp_file_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Parser for cst_esp files."""
from __future__ import annotations

from typing import BinaryIO, TypedDict

from ..utilities.utility import to_type
from .fortran_bin_parser import binary_file_reader


class ESPData(TypedDict):
"""Data from electrostatic potential."""

#: Number of spins in run.
n_spins: int
#: Grid size sampled at.
grid: int
#: ESP Data.
esp: tuple[tuple[tuple[complex, ...]]]
#: MGGA
mgga: tuple[tuple[tuple[complex, ...]]]


def parse_cst_esp_file(cst_esp_file: BinaryIO) -> ESPData:
"""Parse castep `cst_esp` files.
Parameters
----------
cst_esp_file : BinaryIO
File to parse.
Returns
-------
ESPData
Parsed data.
"""
dtypes = {"n_spins": int, "grid": int}

accum = {"esp": []}

reader = binary_file_reader(cst_esp_file)
for (key, typ), datum in zip(dtypes.items(), reader):
accum[key] = to_type(datum, typ)

prev_nx = None
curr = []
for datum in reader:
nx, ny = to_type(datum[:8], int)
if prev_nx != nx and curr:
accum["esp"].append(curr)
curr = []
curr.append(to_type(datum[8:], complex))
prev_nx = nx

accum["esp"].append(curr)

size = accum["grid"][0]
if len(accum["esp"]) > size: # Have MGGA pot
accum["esp"], accum["mgga"] = accum["esp"][:size], accum["esp"][size:]

return accum
27 changes: 27 additions & 0 deletions castep_outputs/bin_parsers/fortran_bin_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""General parser for the Fortran Unformatted file format."""
from os import SEEK_CUR
from typing import BinaryIO, Generator

FortranBinaryReader = Generator[bytes, int, None]

def binary_file_reader(file: BinaryIO) -> FortranBinaryReader:
"""Yield the elements of a Fortran unformatted file."""
while bin_size := file.read(4):
size = int.from_bytes(bin_size, "big")
data = file.read(size)
skip = yield data
file.read(4)
if skip: # NB. Send proceeds to yield.
# `True` implies rewind 1
if skip < 0 or skip is True:
for _ in range(abs(skip)):
# Rewind to record size before last read
file.seek(-size-12, SEEK_CUR)
size = int.from_bytes(file.read(4), "big")

# Rewind one extra (which will be yielded)
file.seek(-size-8, SEEK_CUR)
else:
for _ in range(skip):
size = int.from_bytes(file.read(4), "big")
file.seek(size+4, SEEK_CUR)
3 changes: 3 additions & 0 deletions castep_outputs/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections.abc import Callable

from ..bin_parsers.cst_esp_file_parser import parse_cst_esp_file
from .bands_file_parser import parse_bands_file
from .castep_file_parser import parse_castep_file
from .cell_param_file_parser import parse_cell_param_file
Expand Down Expand Up @@ -30,6 +31,7 @@
"parse_cell_param_file",
"parse_cell_param_file",
"parse_chdiff_fmt_file",
"parse_cst_esp_file",
"parse_den_fmt_file",
"parse_efield_file",
"parse_elastic_file",
Expand Down Expand Up @@ -70,6 +72,7 @@
"tddft": parse_tddft_file,
"err": parse_err_file,
"phonon": parse_phonon_file,
"cst_esp": parse_cst_esp_file
}

#: Names of parsers/parsable file extensions (without ``"."``).
Expand Down
33 changes: 32 additions & 1 deletion castep_outputs/utilities/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableMapping, Sequence
from copy import copy
from functools import partial
from itertools import filterfalse
from struct import unpack
from typing import Any, TextIO, TypeVar

import castep_outputs.utilities.castep_res as REs
Expand Down Expand Up @@ -478,8 +480,30 @@ def _parse_logical(val: str) -> bool:
return val.title() in ("T", "True", "1")


def _parse_float_bytes(val: bytes):
ans = unpack(f">{len(val)//8}d", val)
return ans if len(ans) != 1 else ans[0]

def _parse_int_bytes(val: bytes):
ans = unpack(f">{len(val)//4}i", val)
return ans if len(ans) != 1 else ans[0]

def _parse_bool_bytes(val: bytes):
ans = bool(unpack(f">{len(val)//4}i", val))
return ans if len(ans) != 1 else ans[0]

def _parse_complex_bytes(val: bytes):
tmp = unpack(f">{len(val)//8}d", val)
ans = tuple(map(complex, tmp[::2], tmp[1::2]))
return ans if len(ans) != 1 else ans[0]

_TYPE_PARSERS: dict[type, Callable] = {float: _parse_float_or_rational,
bool: _parse_logical}
_BYTE_PARSERS: dict[type, Callable] = {complex: _parse_complex_bytes,
float: _parse_float_bytes,
bool: _parse_bool_bytes,
int: _parse_int_bytes,
str: partial(str, encoding="ascii")}


@functools.singledispatch
Expand Down Expand Up @@ -511,10 +535,17 @@ def _(data_in: str, typ: type[T]) -> T:
@to_type.register(tuple)
@to_type.register(list)
def _(data_in, typ: type[T]) -> tuple[T, ...]:
parser: Callable[[str], T] = _TYPE_PARSERS.get(typ, typ)
parse_dict = _BYTE_PARSERS if data_in and isinstance(data_in[0], bytes) else _TYPE_PARSERS
parser: Callable[[str], T] = parse_dict.get(typ, typ)
return tuple(parser(x) for x in data_in)


@to_type.register(bytes)
def _(data_in, typ: type[T]) -> T | tuple[T, ...]:
parser: Callable[[bytes], T] = _BYTE_PARSERS.get(typ, typ)
return parser(data_in)


def fix_data_types(in_dict: MutableMapping[str, Any], type_dict: dict[str, type]):
"""
Apply correct types to elements of `in_dict` by mapping given in `type_dict`.
Expand Down
87 changes: 87 additions & 0 deletions test/test_binary_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import struct
from functools import singledispatch
from io import BytesIO
from pathlib import Path
from typing import Any

import pytest

from castep_outputs.bin_parsers.fortran_bin_parser import binary_file_reader
from castep_outputs.utilities.utility import to_type

DATA_FILES = Path(__file__).parent / "data_files"

@singledispatch
def fort_unformat(inp):
raise NotImplementedError(f"Cannot convert {type(inp).__name__}.")


@fort_unformat.register(list)
@fort_unformat.register(tuple)
def _(inp):
return add_size(b"".join(map(to_bytes, inp)))


@fort_unformat.register(float)
@fort_unformat.register(int)
@fort_unformat.register(str)
def _(inp):
return add_size(to_bytes(inp))


@singledispatch
def to_bytes(inp):
raise NotImplementedError(f"Cannot convert {type(inp).__name__}.")


@to_bytes.register
def _(inp: float):
return struct.pack(">d", inp)


@to_bytes.register
def _(inp: int):
return struct.pack(">i", inp)


@to_bytes.register
def _(inp: str):
return struct.pack(f">{len(inp)}s", bytes(inp, encoding="utf-8"))


def add_size(inp: bytes):
size = to_bytes(len(inp))
return b"".join((size, inp, size))


def to_unformat_file(*inp: Any):
return BytesIO(b"".join(map(fort_unformat, inp)))


def test_binary_file_reader():
sample = (1, 2, 3, 2, 3.1, "Hello", (1., 3., 6.))

data, types = to_unformat_file(*sample), (*map(type, sample[:-1]), float)
reader = binary_file_reader(data)

for datum, typ, expected in zip(reader, types, sample):
assert to_type(datum, typ) == expected


def test_actual_read():
dtypes = ((int, 2), (int, (27, 27, 27)))

with (DATA_FILES / "test.cst_esp").open("rb") as file:
reader = binary_file_reader(file)
for (typ, expected), datum in zip(dtypes, reader):
assert to_type(datum, typ) == expected
for (y, datum) in enumerate(reader, 1):
nx, ny = to_type(datum[:8], int)
assert nx == 1
assert ny == y
res = to_type(datum[8:], complex)
assert len(res) == 27


if __name__ == "__main__":
pytest.main()

0 comments on commit e9c2cb7

Please sign in to comment.