Add PandA HDF writer #90

Closed · wants to merge 8 commits
Changes from 2 commits
4 changes: 4 additions & 0 deletions src/ophyd_async/panda/writers/__init__.py
@@ -0,0 +1,4 @@
from .hdf_writer import PandaHDFWriter
from .panda_hdf import PandaHDF

__all__ = ["PandaHDFWriter", "PandaHDF"]
112 changes: 112 additions & 0 deletions src/ophyd_async/panda/writers/hdf_writer.py
@@ -0,0 +1,112 @@
import asyncio
from typing import AsyncIterator, Dict, List, Optional

from bluesky.protocols import Asset, Descriptor, Hints

from ophyd_async.core import (
    DEFAULT_TIMEOUT,
    AsyncStatus,
    DetectorWriter,
    DirectoryProvider,
    NameProvider,
    set_and_wait_for_value,
    wait_for_value,
)

from .panda_hdf import PandaHDF, _HDFDataset, _HDFFile


class PandaHDFWriter(DetectorWriter):
    def __init__(
        self,
        hdf: PandaHDF,
        directory_provider: DirectoryProvider,
        name_provider: NameProvider,
        **scalar_datasets_paths: str,
Collaborator: I think that we should probably check the ":CAPTURE" PV for each dataset, then only emit the StreamResource for the dataset if it isn't "No".

Contributor (Author): Right, that makes sense.
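A possible shape for that check, sketched against this PR's signal helper (the ":CAPTURE" suffix, the "No" sentinel, and the helper name are assumptions from the comment above, not confirmed PandA PV names):

from typing import Dict

from ophyd_async.epics.signal.signal import epics_signal_r


async def filter_captured_datasets(
    prefix: str, dataset_paths: Dict[str, str]
) -> Dict[str, str]:
    # Keep only datasets whose <prefix>:<path>:CAPTURE PV is not "No",
    # so open() only emits StreamResources for datasets being captured.
    enabled: Dict[str, str] = {}
    for name, path in dataset_paths.items():
        capture = epics_signal_r(str, f"{prefix}:{path}:CAPTURE")
        await capture.connect()
        if await capture.get_value() != "No":
            enabled[name] = path
    return enabled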

    ) -> None:
        self.hdf = hdf
        self._directory_provider = directory_provider
        self._name_provider = name_provider
        self._scalar_datasets_paths = scalar_datasets_paths
        self._capture_status: Optional[AsyncStatus] = None
        self._datasets: List[_HDFDataset] = []
        self._file: Optional[_HDFFile] = None
        self._multiplier = 1

    async def open(self, multiplier: int = 1) -> Dict[str, Descriptor]:
        self._file = None
        info = self._directory_provider()
        await asyncio.gather(
            self.hdf.file_path.set(info.directory_path),
            self.hdf.file_name.set(f"{info.filename_prefix}.h5"),
        )

        # Overwrite num_capture to go forever
        await self.hdf.num_capture.set(0)
        # Wait for it to start, stashing the status that tells us when it finishes
        self._capture_status = await set_and_wait_for_value(self.hdf.capture, True)
        name = self._name_provider()
        if multiplier > 1:
            raise ValueError(
                "All PandA datasets should be scalar, multiplier should be 1"
            )
        self._multiplier = multiplier
        self._datasets = []
        # Add all the scalar datasets
        for ds_name, ds_path in self._scalar_datasets_paths.items():
            self._datasets.append(
                _HDFDataset(
                    f"{name}-{ds_name}",
                    ds_path,
                    (),
                    multiplier,
                )
            )
        describe = {
            ds.name: Descriptor(
                source=self.hdf.full_file_name.source,
                shape=ds.shape,
                dtype="array" if ds.shape else "number",
                external="STREAM:",
            )
            for ds in self._datasets
        }
        return describe

    async def wait_for_index(
        self, index: int, timeout: Optional[float] = DEFAULT_TIMEOUT
    ):
        def matcher(value: int) -> bool:
            return value // self._multiplier >= index

        matcher.__name__ = f"index_at_least_{index}"
        await wait_for_value(self.hdf.num_written, matcher, timeout=timeout)

    async def get_indices_written(self) -> int:
        num_written = await self.hdf.num_written.get_value()
        return num_written // self._multiplier

    async def collect_stream_docs(self, indices_written: int) -> AsyncIterator[Asset]:
        # TODO: fail if we get dropped frames
        await self.hdf.flush_now.set(True)
        if indices_written:
            if not self._file:
                self._file = _HDFFile(
                    await self.hdf.full_file_name.get_value(), self._datasets
                )
                for doc in self._file.stream_resources():
                    yield "stream_resource", doc
            for doc in self._file.stream_data(indices_written):
                yield "stream_datum", doc

    async def close(self):
        # Already done a caput callback in _capture_status, so can't do one here
        await self.hdf.capture.set(False, wait=False)
        await wait_for_value(self.hdf.capture, False, DEFAULT_TIMEOUT)
        if self._capture_status:
            # We kicked off an open, so wait for it to return
            await self._capture_status

    @property
    def hints(self) -> Hints:
        return {"fields": [self._name_provider()]}
64 changes: 64 additions & 0 deletions src/ophyd_async/panda/writers/panda_hdf.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Iterator, List, Sequence

from event_model import StreamDatum, StreamResource, compose_stream_resource

from ophyd_async.core.device import Device
from ophyd_async.epics.signal.signal import epics_signal_r, epics_signal_rw


class PandaHDF(Device):
    def __init__(self, prefix: str, name: str = "") -> None:
        # Define some signals
        self.file_path = epics_signal_rw(str, prefix + ":HDF5:FilePath")
        self.file_name = epics_signal_rw(str, prefix + ":HDF5:FileName")
        self.full_file_name = epics_signal_r(str, prefix + ":HDF5:FullFileName")
        self.num_capture = epics_signal_rw(int, prefix + ":HDF5:NumCapture")
        self.num_written = epics_signal_r(int, prefix + ":HDF5:NumWritten_RBV")
        self.capture = epics_signal_rw(
            bool, prefix + ":HDF5:Capturing", prefix + ":HDF5:Capture"
        )
        self.flush_now = epics_signal_rw(bool, prefix + ":HDF5:FlushNow")
        super().__init__(name)


@dataclass
class _HDFDataset:
    name: str
    path: str
    shape: Sequence[int]
    multiplier: int


class _HDFFile:
    def __init__(self, full_file_name: str, datasets: List[_HDFDataset]) -> None:
        self._last_emitted = 0
        self._bundles = [
            compose_stream_resource(
                spec="AD_HDF5_SWMR_SLICE",
                root="/",
                data_key=ds.name,
                resource_path=full_file_name,
                resource_kwargs={
                    "path": ds.path,
                    "multiplier": ds.multiplier,
                },
            )
            for ds in datasets
        ]

    def stream_resources(self) -> Iterator[StreamResource]:
        for bundle in self._bundles:
            yield bundle.stream_resource_doc

    def stream_data(self, indices_written: int) -> Iterator[StreamDatum]:
        # Indices are relative to resource
        if indices_written > self._last_emitted:
            indices = dict(
                start=self._last_emitted,
                stop=indices_written,
            )
            self._last_emitted = indices_written
            for bundle in self._bundles:
                yield bundle.compose_stream_datum(indices)
        return None
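_HDFFile tracks _last_emitted so that repeated stream_data() calls only cover frames written since the previous call. A quick sketch of those semantics (the file path and dataset names are illustrative; composing the documents does not touch the file on disk):

from ophyd_async.panda.writers.panda_hdf import _HDFDataset, _HDFFile

datasets = [_HDFDataset("panda-counter", "COUNTER1.OUT", (), 1)]
hdf_file = _HDFFile("/tmp/scan.h5", datasets)

resources = list(hdf_file.stream_resources())  # one StreamResource per dataset
first = list(hdf_file.stream_data(5))          # one StreamDatum for indices [0, 5)
again = list(hdf_file.stream_data(5))          # [] - nothing new since last call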
91 changes: 91 additions & 0 deletions tests/panda/test_writer.py
@@ -0,0 +1,91 @@
from unittest.mock import patch

import pytest
from bluesky.protocols import Descriptor

from ophyd_async.core import (
DeviceCollector,
StaticDirectoryProvider,
set_and_wait_for_value,
)
from ophyd_async.epics.signal.signal import SignalR, epics_signal_rw
from ophyd_async.panda.writers.hdf_writer import PandaHDFWriter
from ophyd_async.panda.writers.panda_hdf import PandaHDF


@pytest.fixture
async def sim_writer(tmp_path) -> PandaHDFWriter:
    dir_prov = StaticDirectoryProvider(str(tmp_path), "test")
    async with DeviceCollector(sim=True):
        hdf = PandaHDF("TEST-PANDA")
        writer = PandaHDFWriter(hdf, dir_prov, lambda: "test-panda")
    return writer


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_returns_descriptors(sim_writer):
    description = await sim_writer.open()
    assert isinstance(description, dict)
    for key, entry in description.items():
        assert isinstance(key, str)
        # Descriptor is a TypedDict, which rejects isinstance checks,
        # so check against plain dict
        assert isinstance(entry, dict)
        assert "source" in entry
        assert entry.get("dtype") == "number"
        assert entry.get("external") == "STREAM:"


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_close_sets_capture(sim_writer):
    return_val = await sim_writer.open()
    assert isinstance(return_val, dict)
    capturing = await sim_writer.hdf.capture.get_value()
    assert capturing is True
    await sim_writer.close()
    capturing = await sim_writer.hdf.capture.get_value()
    assert capturing is False


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_sets_file_path(sim_writer, tmp_path):
    path = await sim_writer.hdf.file_path.get_value()
    assert path == ""
    await sim_writer.open()
    path = await sim_writer.hdf.file_path.get_value()
    assert path == str(tmp_path)
    name = await sim_writer.hdf.file_name.get_value()
    assert name == "test.h5"


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_get_indices_written(sim_writer):
    written = await sim_writer.get_indices_written()
    assert written == 0, f"{written} != 0"

    async def get_twentyfive():
        return 25

    with patch("ophyd_async.core.SignalR.get_value", wraps=get_twentyfive):
        written = await sim_writer.get_indices_written()
        assert written == 25, f"{written} != 25"


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_wait_for_index(sim_writer):
    assert type(sim_writer.hdf.num_written) is SignalR
    # num_written is normally a read-only SignalR, so it can't be set from
    # ophyd; replace it with a SignalRW for testing
    sim_writer.hdf.num_written = epics_signal_rw(int, "TEST-PANDA:HDF5:NumWritten")
    await sim_writer.hdf.num_written.connect(sim=True)
    await set_and_wait_for_value(sim_writer.hdf.num_written, 25)
    assert (await sim_writer.hdf.num_written.get_value()) == 25
    await sim_writer.wait_for_index(25, timeout=1)
    with pytest.raises(TimeoutError):
        await sim_writer.wait_for_index(27, timeout=1)


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_collect_stream_docs(sim_writer):
    assert sim_writer._file is None

    [item async for item in sim_writer.collect_stream_docs(1)]
    assert sim_writer._file