Add basic Fortran binary file parser
oerc0122 committed Feb 13, 2025
1 parent 5a13e21 commit ac4f029
Showing 13 changed files with 23,188 additions and 10 deletions.
17 changes: 17 additions & 0 deletions castep_outputs/bin_parsers/__init__.py
@@ -0,0 +1,17 @@
"""List of parsers for binary file formats."""
from __future__ import annotations

from collections.abc import Callable

from .cst_esp_file_parser import parse_cst_esp_file

#: Dictionary of available parsers.
PARSERS: dict[str, Callable] = {
    "cst_esp": parse_cst_esp_file,
}

#: Names of parsers/parsable file extensions (without ``"."``).
CASTEP_OUTPUT_NAMES: tuple[str, ...] = tuple(PARSERS.keys())

#: Parsable file extensions (including the leading ``"."``).
CASTEP_FILE_FORMATS: tuple[str, ...] = tuple(f".{typ}" for typ in CASTEP_OUTPUT_NAMES)
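
The registry mirrors the existing text-parser registry, so callers can dispatch purely on file extension. A minimal sketch of a look-up (not part of the commit):

from castep_outputs.bin_parsers import CASTEP_FILE_FORMATS, PARSERS

print(CASTEP_FILE_FORMATS)   # (".cst_esp",)
parser = PARSERS["cst_esp"]  # the parse_cst_esp_file callable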
60 changes: 60 additions & 0 deletions castep_outputs/bin_parsers/cst_esp_file_parser.py
@@ -0,0 +1,60 @@
"""Parser for cst_esp files."""
from __future__ import annotations

from typing import BinaryIO, TypedDict

from ..utilities.utility import to_type
from .fortran_bin_parser import binary_file_reader


class ESPData(TypedDict):
    """Data from electrostatic potential."""

    #: Number of spins in run.
    n_spins: int
    #: Dimensions of the grid the potential is sampled on.
    grid: tuple[int, ...]
    #: ESP data.
    esp: tuple[tuple[tuple[complex, ...], ...], ...]
    #: MGGA potential data (only present for meta-GGA runs).
    mgga: tuple[tuple[tuple[complex, ...], ...], ...]


def parse_cst_esp_file(cst_esp_file: BinaryIO) -> ESPData:
    """Parse castep `cst_esp` files.

    Parameters
    ----------
    cst_esp_file : BinaryIO
        File to parse.

    Returns
    -------
    ESPData
        Parsed data.
    """
    dtypes = {"n_spins": int, "grid": int}

    accum = {"esp": []}

    reader = binary_file_reader(cst_esp_file)
    for (key, typ), datum in zip(dtypes.items(), reader):
        accum[key] = to_type(datum, typ)

    prev_nx = None
    curr = []
    for datum in reader:
        nx, ny = to_type(datum[:8], int)
        if prev_nx != nx and curr:
            accum["esp"].append(curr)
            curr = []
        curr.append(to_type(datum[8:], complex))
        prev_nx = nx

    accum["esp"].append(curr)

    size = accum["grid"][0]
    if len(accum["esp"]) > size:  # Have MGGA pot
        accum["esp"], accum["mgga"] = accum["esp"][:size], accum["esp"][size:]

    return accum
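
For orientation (not part of the commit), the parser can be driven directly as below; the file name is hypothetical and the printed shapes follow the accumulation logic above:

from pathlib import Path

from castep_outputs.bin_parsers import parse_cst_esp_file

with Path("SiO2.cst_esp").open("rb") as file:   # hypothetical file name
    esp_data = parse_cst_esp_file(file)

print(esp_data["n_spins"], esp_data["grid"])
print(len(esp_data["esp"]))   # one block of complex values per x-plane of the grid
if "mgga" in esp_data:        # only present when an MGGA potential follows the ESP
    print(len(esp_data["mgga"]))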
27 changes: 27 additions & 0 deletions castep_outputs/bin_parsers/fortran_bin_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""General parser for the Fortran Unformatted file format."""
from os import SEEK_CUR
from typing import BinaryIO, Generator

FortranBinaryReader = Generator[bytes, int, None]

def binary_file_reader(file: BinaryIO) -> FortranBinaryReader:
    """Yield the records of a Fortran unformatted file one by one.

    ``send``-ing an integer skips that many records forward before the next
    yield; a negative value (or ``True``) rewinds by that many records instead.
    """
    while bin_size := file.read(4):
        size = int.from_bytes(bin_size, "big")
        data = file.read(size)
        skip = yield data
        file.read(4)  # Consume trailing record marker.
        if skip:  # NB. Send proceeds to yield.
            # `True` implies rewind 1
            if skip < 0 or skip is True:
                for _ in range(abs(skip)):
                    # Rewind to record size before last read
                    file.seek(-size - 12, SEEK_CUR)
                    size = int.from_bytes(file.read(4), "big")

                # Rewind one extra (which will be yielded)
                file.seek(-size - 8, SEEK_CUR)
            else:
                for _ in range(skip):
                    size = int.from_bytes(file.read(4), "big")
                    file.seek(size + 4, SEEK_CUR)
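
To make the record framing concrete, here is a self-contained sketch (not part of the commit) that writes three records to an in-memory buffer, each wrapped in the 4-byte big-endian length markers the reader expects, then skips and rewinds with send:

import io
from struct import pack

from castep_outputs.bin_parsers.fortran_bin_parser import binary_file_reader


def record(payload: bytes) -> bytes:
    # Frame a payload with the leading and trailing length markers.
    marker = pack(">i", len(payload))
    return marker + payload + marker


buf = io.BytesIO(b"".join(record(pack(">d", x)) for x in (1.0, 2.0, 3.0)))

reader = binary_file_reader(buf)
first = next(reader)        # record 1
third = reader.send(1)      # skip one record forward, yields record 3
second = reader.send(True)  # True rewinds one record, yields record 2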
28 changes: 20 additions & 8 deletions castep_outputs/cli/castep_outputs_main.py
Original file line number Diff line number Diff line change
@@ -4,16 +4,20 @@
import io
import logging
import sys
from collections import ChainMap
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any, TextIO

from ..bin_parsers import PARSERS as BIN_PARSERS
from ..parsers import PARSERS
from ..utilities.constants import OutFormats
from ..utilities.dumpers import get_dumpers
from ..utilities.utility import flatten_dict, json_safe, normalise
from .args import extract_parsables, parse_args

ALL_PARSERS = ChainMap(PARSERS, BIN_PARSERS)


def parse_single(in_file: str | Path | TextIO,
                 parser: Callable[[TextIO], list[dict[str, Any]]] | None = None,
@@ -57,18 +61,26 @@ def parse_single(in_file: str | Path | TextIO,
    if parser is None and isinstance(in_file, Path):
        ext = in_file.suffix.strip(".")

        if ext not in PARSERS:
        if ext not in ALL_PARSERS:
            raise KeyError(f"Parser for file {in_file} (assumed type: {ext}) not found")

        parser = PARSERS[ext]
        parser = ALL_PARSERS[ext]

    assert parser is not None

    if isinstance(in_file, io.TextIOBase):
        data = parser(in_file)
    elif isinstance(in_file, Path):
        with in_file.open(mode="r", encoding="utf-8") as file:
            data = parser(file)
    if parser in PARSERS.values():
        if isinstance(in_file, io.TextIOBase):
            data = parser(in_file)
        elif isinstance(in_file, Path):
            with in_file.open(mode="r", encoding="utf-8") as file:
                data = parser(file)
    elif parser in BIN_PARSERS.values():
        if isinstance(in_file, io.IOBase):
            data = parser(in_file)
        elif isinstance(in_file, Path):
            with in_file.open(mode="rb") as file:
                data = parser(file)

    if out_format == "json" or testing:
        data = normalise(data, {dict: json_safe, complex: json_safe})
@@ -112,7 +124,7 @@ def parse_all(

    data = {}
    for typ, paths in files.items():
        parser = PARSERS[typ]
        parser = ALL_PARSERS[typ]
        for path in paths:
            data[path] = parse_single(path, parser, out_format, loglevel=loglevel, testing=testing)

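Because ALL_PARSERS is a ChainMap, look-ups search the text PARSERS first and only fall through to BIN_PARSERS, so a text parser would take precedence if an extension were ever registered in both. A toy illustration of that ordering (placeholder strings stand in for the parser callables):

from collections import ChainMap

ALL = ChainMap({"castep": "text parser"},
               {"castep": "binary parser", "cst_esp": "binary parser"})

print(ALL["castep"])   # "text parser"   -- first mapping wins
print(ALL["cst_esp"])  # "binary parser" -- falls through to the second mapping
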
33 changes: 32 additions & 1 deletion castep_outputs/utilities/utility.py
@@ -10,7 +10,9 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, MutableMapping, Sequence
from copy import copy
from functools import partial
from itertools import filterfalse
from struct import unpack
from typing import Any, TextIO, TypeVar

import castep_outputs.utilities.castep_res as REs
@@ -478,8 +480,30 @@ def _parse_logical(val: str) -> bool:
    return val.title() in ("T", "True", "1")


def _parse_float_bytes(val: bytes) -> float | tuple[float, ...]:
    ans = unpack(f">{len(val)//8}d", val)
    return ans if len(ans) != 1 else ans[0]


def _parse_int_bytes(val: bytes) -> int | tuple[int, ...]:
    ans = unpack(f">{len(val)//4}i", val)
    return ans if len(ans) != 1 else ans[0]


def _parse_bool_bytes(val: bytes) -> bool | tuple[bool, ...]:
    ans = tuple(map(bool, unpack(f">{len(val)//4}i", val)))
    return ans if len(ans) != 1 else ans[0]


def _parse_complex_bytes(val: bytes) -> complex | tuple[complex, ...]:
    tmp = unpack(f">{len(val)//8}d", val)
    ans = tuple(map(complex, tmp[::2], tmp[1::2]))
    return ans if len(ans) != 1 else ans[0]


_TYPE_PARSERS: dict[type, Callable] = {float: _parse_float_or_rational,
                                       bool: _parse_logical}
_BYTE_PARSERS: dict[type, Callable] = {complex: _parse_complex_bytes,
                                       float: _parse_float_bytes,
                                       bool: _parse_bool_bytes,
                                       int: _parse_int_bytes,
                                       str: partial(str, encoding="ascii")}


@functools.singledispatch
@@ -511,10 +535,17 @@ def _(data_in: str, typ: type[T]) -> T:
@to_type.register(tuple)
@to_type.register(list)
def _(data_in, typ: type[T]) -> tuple[T, ...]:
    parser: Callable[[str], T] = _TYPE_PARSERS.get(typ, typ)
    parse_dict = _BYTE_PARSERS if data_in and isinstance(data_in[0], bytes) else _TYPE_PARSERS
    parser: Callable[[str], T] = parse_dict.get(typ, typ)
    return tuple(parser(x) for x in data_in)


@to_type.register(bytes)
def _(data_in, typ: type[T]) -> T | tuple[T, ...]:
    parser: Callable[[bytes], T] = _BYTE_PARSERS.get(typ, typ)
    return parser(data_in)


def fix_data_types(in_dict: MutableMapping[str, Any], type_dict: dict[str, type]):
    """
    Apply correct types to elements of `in_dict` by the mapping given in `type_dict`.
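
With the bytes overload registered, to_type can be fed raw big-endian records directly; a minimal sketch with arbitrary values (not part of the commit):

from struct import pack

from castep_outputs.utilities.utility import to_type

to_type(pack(">i", 3), int)               # 3
to_type(pack(">3i", 1, 2, 3), int)        # (1, 2, 3)
to_type(pack(">2d", 0.5, 1.5), float)     # (0.5, 1.5)
to_type(pack(">2d", 1.0, -2.0), complex)  # (1-2j)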
