-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic fortran binary file parser
- Loading branch information
Showing
6 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Parser for cst_esp files.""" | ||
from __future__ import annotations | ||
|
||
from typing import BinaryIO, TypedDict | ||
|
||
from ..utilities.utility import to_type | ||
from .fortran_bin_parser import binary_file_reader | ||
|
||
|
||
class ESPData(TypedDict): | ||
"""Data from electrostatic potential.""" | ||
|
||
#: Number of spins in run. | ||
n_spins: int | ||
#: Grid size sampled at. | ||
grid: int | ||
#: Data. | ||
val: tuple[tuple[tuple[complex, ...]]] | ||
|
||
|
||
def parse_cst_esp_file(cst_esp_file: BinaryIO) -> ESPData: | ||
"""Parse castep `cst_esp` files. | ||
Parameters | ||
---------- | ||
cst_esp_file : BinaryIO | ||
File to parse. | ||
Returns | ||
------- | ||
ESPData | ||
Parsed data. | ||
""" | ||
dtypes = {"n_spins": int, "grid": int} | ||
|
||
accum = {"val": []} | ||
|
||
reader = binary_file_reader(cst_esp_file) | ||
for (key, typ), datum in zip(dtypes.items(), reader): | ||
accum[key] = to_type(datum, typ) | ||
|
||
for datum in reader: | ||
accum["val"].append(to_type(datum[8:], complex)) # Discard indices | ||
|
||
return accum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""General parser for the Fortran Unformatted file format.""" | ||
from os import SEEK_CUR | ||
from typing import BinaryIO, Generator | ||
|
||
FortranBinaryReader = Generator[bytes, int, None] | ||
|
||
def binary_file_reader(file: BinaryIO) -> FortranBinaryReader: | ||
"""Yield the elements of a Fortran unformatted file.""" | ||
while bin_size := file.read(4): | ||
size = int.from_bytes(bin_size, "big") | ||
data = file.read(size) | ||
skip = yield data | ||
file.read(4) | ||
if skip: # NB. Send proceeds to yield. | ||
# `True` implies rewind 1 | ||
if skip < 0 or skip is True: | ||
for _ in range(abs(skip)): | ||
# Rewind to record size before last read | ||
file.seek(-size-12, SEEK_CUR) | ||
size = int.from_bytes(file.read(4), "big") | ||
|
||
# Rewind one extra (which will be yielded) | ||
file.seek(-size-8, SEEK_CUR) | ||
else: | ||
for _ in range(skip): | ||
size = int.from_bytes(file.read(4), "big") | ||
file.seek(size+4, SEEK_CUR) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import struct | ||
from functools import singledispatch | ||
from io import BytesIO | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import pytest | ||
|
||
from castep_outputs.bin_parsers.fortran_bin_parser import binary_file_reader | ||
from castep_outputs.utilities.utility import to_type | ||
|
||
DATA_FILES = Path(__file__).parent / "data_files" | ||
|
||
@singledispatch | ||
def fort_unformat(inp): | ||
raise NotImplementedError(f"Cannot convert {type(inp).__name__}.") | ||
|
||
|
||
@fort_unformat.register(list) | ||
@fort_unformat.register(tuple) | ||
def _(inp): | ||
return add_size(b"".join(map(to_bytes, inp))) | ||
|
||
|
||
@fort_unformat.register(float) | ||
@fort_unformat.register(int) | ||
@fort_unformat.register(str) | ||
def _(inp): | ||
return add_size(to_bytes(inp)) | ||
|
||
|
||
@singledispatch | ||
def to_bytes(inp): | ||
raise NotImplementedError(f"Cannot convert {type(inp).__name__}.") | ||
|
||
|
||
@to_bytes.register | ||
def _(inp: float): | ||
return struct.pack(">d", inp) | ||
|
||
|
||
@to_bytes.register | ||
def _(inp: int): | ||
return struct.pack(">i", inp) | ||
|
||
|
||
@to_bytes.register | ||
def _(inp: str): | ||
return struct.pack(f">{len(inp)}s", bytes(inp, encoding="utf-8")) | ||
|
||
|
||
def add_size(inp: bytes): | ||
size = to_bytes(len(inp)) | ||
return b"".join((size, inp, size)) | ||
|
||
|
||
def to_unformat_file(*inp: Any): | ||
return BytesIO(b"".join(map(fort_unformat, inp))) | ||
|
||
|
||
def test_binary_file_reader(): | ||
sample = (1, 2, 3, 2, 3.1, "Hello", (1., 3., 6.)) | ||
|
||
data, types = to_unformat_file(*sample), (*map(type, sample[:-1]), float) | ||
reader = binary_file_reader(data) | ||
|
||
for datum, typ, expected in zip(reader, types, sample): | ||
assert to_type(datum, typ) == expected | ||
|
||
|
||
def test_actual_read(): | ||
dtypes = ((int, 2), (int, (27, 27, 27))) | ||
|
||
with (DATA_FILES / "test.cst_esp").open("rb") as file: | ||
reader = binary_file_reader(file) | ||
for (typ, expected), datum in zip(dtypes, reader): | ||
assert to_type(datum, typ) == expected | ||
for (y, datum) in enumerate(reader, 1): | ||
nx, ny = to_type(datum[:8], int) | ||
assert nx == 1 | ||
assert ny == y | ||
res = to_type(datum[8:], complex) | ||
assert len(res) == 27 | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |