Merge pull request #547 from OpenCOMPES/flash-metadata-fixes
Flash metadata fixes
zain-sohail authored Jan 15, 2025
2 parents a99b945 + 25a935b commit 82fc11f
Showing 8 changed files with 329 additions and 69 deletions.
75 changes: 75 additions & 0 deletions src/sed/core/config.py
@@ -242,3 +242,78 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict:
                dictionary[k] = v

    return dictionary


def _parse_env_file(file_path: Path) -> dict:
    """Helper function to parse a .env file into a dictionary.
    Args:
        file_path (Path): Path to the .env file
    Returns:
        dict: Dictionary of environment variables from the file
    """
    env_content = {}
    if file_path.exists():
        with open(file_path) as f:
            for line in f:
                line = line.strip()
                if line and "=" in line:
                    key, val = line.split("=", 1)
                    env_content[key.strip()] = val.strip()
    return env_content
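
For reference, the parser above expects plain KEY=value lines. An illustrative .env file (the variable names and values are placeholders, not part of this diff) might look like:

SCICAT_TOKEN=abc123
BEAMTIME_DIR=/path/to/beamtime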


def read_env_var(var_name: str) -> str | None:
    """Read an environment variable from multiple locations in order:
    1. OS environment variables
    2. .env file in current directory
    3. .env file in user config directory
    Args:
        var_name (str): Name of the environment variable to read
    Returns:
        str | None: Value of the environment variable or None if not found
    """
    # First check OS environment variables
    value = os.getenv(var_name)
    if value is not None:
        logger.debug(f"Found {var_name} in OS environment variables")
        return value

    # Then check .env in current directory
    local_vars = _parse_env_file(Path(".env"))
    if var_name in local_vars:
        logger.debug(f"Found {var_name} in ./.env file")
        return local_vars[var_name]

    # Finally check .env in user config directory
    user_vars = _parse_env_file(USER_CONFIG_PATH / ".env")
    if var_name in user_vars:
        logger.debug(f"Found {var_name} in user config .env file")
        return user_vars[var_name]

    logger.debug(f"Environment variable {var_name} not found in any location")
    return None


def save_env_var(var_name: str, value: str) -> None:
    """Save an environment variable to the .env file in the user config directory.
    If the file exists, preserves other variables. If not, creates a new file.
    Args:
        var_name (str): Name of the environment variable to save
        value (str): Value to save for the environment variable
    """
    env_path = USER_CONFIG_PATH / ".env"
    env_content = _parse_env_file(env_path)

    # Update or add new variable
    env_content[var_name] = value

    # Write all variables back to file
    with open(env_path, "w") as f:
        for key, val in env_content.items():
            f.write(f"{key}={val}\n")
    logger.debug(f"Environment variable {var_name} saved to .env file")
2 changes: 0 additions & 2 deletions src/sed/core/config_model.py
@@ -14,7 +14,6 @@
from pydantic import HttpUrl
from pydantic import NewPath
from pydantic import PositiveInt
from pydantic import SecretStr

from sed.loader.loader_interface import get_names_of_all_loaders

@@ -323,7 +322,6 @@ class MetadataModel(BaseModel):
    model_config = ConfigDict(extra="forbid")

    archiver_url: Optional[HttpUrl] = None
    token: Optional[SecretStr] = None
    epics_pvs: Optional[Sequence[str]] = None
    fa_in_channel: Optional[str] = None
    fa_hor_channel: Optional[str] = None
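
For context (an illustration only, not part of this diff): with the token field removed from MetadataModel, a metadata config section carries only the remaining keys, and the SciCat token is supplied via the environment instead, e.g.:

metadata_config = {
    "archiver_url": "https://example.com/archiver",  # placeholder URL
    "epics_pvs": ["EXAMPLE:PV:1"],                   # placeholder PV names
}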
39 changes: 18 additions & 21 deletions src/sed/loader/flash/buffer_handler.py
@@ -2,6 +2,7 @@

import os
from pathlib import Path
import time

import dask.dataframe as dd
import pyarrow.parquet as pq
@@ -20,7 +21,7 @@

DF_TYP = ["electron", "timed"]

logger = setup_logging(__name__)
logger = setup_logging("flash_buffer_handler")


class BufferFilePaths:
@@ -135,16 +136,15 @@ def __init__(
    def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
        """
        Checks the schema of the Parquet files.
        Raises:
            ValueError: If the schema of the Parquet files does not match the configuration.
        """
        logger.debug(f"Checking schema for {len(files)} files")
        existing = [file for file in files if file.exists()]
        parquet_schemas = [pq.read_schema(file) for file in existing]

        for filename, schema in zip(existing, parquet_schemas):
            schema_set = set(schema.names)
            if schema_set != expected_schema_set:
                logger.error(f"Schema mismatch in file: {filename}")
                missing_in_parquet = expected_schema_set - schema_set
                missing_in_config = schema_set - expected_schema_set

@@ -159,6 +159,7 @@ def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
f"{' '.join(errors)}. "
"Please check the configuration file or set force_recreate to True.",
)
logger.debug("Schema check passed successfully")

    def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame:
        """Creates the timed dataframe, optionally filtering by electron events.
@@ -185,35 +186,31 @@ def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame:
        return df_timed.loc[:, :, 0]

    def _save_buffer_file(self, paths: dict[str, Path]) -> None:
        """
        Creates the electron and timed buffer files from the raw H5 file.
        First the dataframe is accessed and forward filled in the non-electron channels.
        Then the data types are set. For the electron dataframe, all values not in the electron
        channels are dropped. For the timed dataframe, only the train and pulse channels are taken
        and it is pulse resolved (no longer electron resolved). Both are saved as parquet files.
        Args:
            paths (dict[str, Path]): Dictionary containing the paths to the H5 and buffer files.
        """
        # Create a DataFrameCreator instance and get the h5 file
        """Creates the electron and timed buffer files from the raw H5 file."""
        logger.debug(f"Processing file: {paths['raw'].stem}")
        start_time = time.time()
        # Create DataFrameCreator and get the dataframe
        df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df

        # forward fill all the non-electron channels
        # Forward fill non-electron channels
        logger.debug(f"Forward filling {len(self.fill_channels)} channels")
        df[self.fill_channels] = df[self.fill_channels].ffill()

        # Save electron resolved dataframe
        electron_channels = get_channels(self._config, "per_electron")
        dtypes = get_dtypes(self._config, df.columns.values)
        df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet(
            paths["electron"],
        )
        electron_df = df.dropna(subset=electron_channels).astype(dtypes).reset_index()
        logger.debug(f"Saving electron buffer with shape: {electron_df.shape}")
        electron_df.to_parquet(paths["electron"])

        # Create and save timed dataframe
        df_timed = self._create_timed_dataframe(df)
        dtypes = get_dtypes(self._config, df_timed.columns.values)
        df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])
        timed_df = df_timed.astype(dtypes).reset_index()
        logger.debug(f"Saving timed buffer with shape: {timed_df.shape}")
        timed_df.to_parquet(paths["timed"])

        logger.debug(f"Processed {paths['raw'].stem}")
        logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s")

    def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
        """
20 changes: 15 additions & 5 deletions src/sed/loader/flash/dataframe.py
@@ -14,6 +14,9 @@

from sed.loader.flash.utils import get_channels
from sed.loader.flash.utils import InvalidFileError
from sed.core.logging import setup_logging

logger = setup_logging("flash_dataframe_creator")


class DataFrameCreator:
@@ -34,6 +37,7 @@ def __init__(self, config_dataframe: dict, h5_path: Path) -> None:
            config_dataframe (dict): The configuration dictionary with only the dataframe key.
            h5_path (Path): Path to the h5 file.
        """
        logger.debug(f"Initializing DataFrameCreator for file: {h5_path}")
        self.h5_file = h5py.File(h5_path, "r")
        self.multi_index = get_channels(index=True)
        self._config = config_dataframe
@@ -76,6 +80,7 @@ def get_dataset_array(
            tuple[pd.Index, np.ndarray | h5py.Dataset]: A tuple containing the train ID
            pd.Index and the channel's data.
        """
        logger.debug(f"Getting dataset array for channel: {channel}")
        # Get the data from the necessary h5 file and channel
        index_key, dataset_key = self.get_index_dataset_key(channel)

@@ -85,6 +90,7 @@
        if slice_:
            slice_index = self._config["channels"][channel].get("slice", None)
            if slice_index is not None:
                logger.debug(f"Slicing dataset with index: {slice_index}")
                dataset = np.take(dataset, slice_index, axis=1)
        # If the dataset is size zero, fill it with NaN values
        # of the same shape as index
@@ -291,10 +297,14 @@ def df(self) -> pd.DataFrame:
        Returns:
            pd.DataFrame: The combined pandas DataFrame.
        """

        logger.debug("Creating combined DataFrame")
        self.validate_channel_keys()
        # been tested with merge, join and concat
        # concat offers best performance, almost 3 times faster

        df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index()
        # all the negative pulse values are dropped as they are invalid
        return df[df.index.get_level_values("pulseId") >= 0]
        logger.debug(f"Created DataFrame with shape: {df.shape}")

        # Filter negative pulse values
        df = df[df.index.get_level_values("pulseId") >= 0]
        logger.debug(f"Filtered DataFrame shape: {df.shape}")

        return df
15 changes: 9 additions & 6 deletions src/sed/loader/flash/loader.py
@@ -207,14 +207,14 @@ def get_files_from_run_id( # type: ignore[override]
        # Return the list of found files
        return [str(file.resolve()) for file in files]

    def parse_metadata(self, scicat_token: str = None) -> dict:
    def parse_metadata(self, token: str = None) -> dict:
        """Uses the MetadataRetriever class to fetch metadata from scicat for each run.
        Returns:
            dict: Metadata dictionary
            scicat_token (str, optional): The scicat token to use for fetching metadata
            token (str, optional): The scicat token to use for fetching metadata
        """
        metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token)
        metadata_retriever = MetadataRetriever(self._config["metadata"], token)
        metadata = metadata_retriever.get_metadata(
            beamtime_id=self._config["core"]["beamtime_id"],
            runs=self.runs,
@@ -329,7 +329,9 @@ def read_dataframe(
            debug (bool, optional): Whether to run buffer creation in serial. Defaults to False.
            remove_invalid_files (bool, optional): Whether to exclude invalid files.
                Defaults to False.
            scicat_token (str, optional): The scicat token to use for fetching metadata.
            token (str, optional): The scicat token to use for fetching metadata. If provided,
                will be saved to .env file for future use. If not provided, will check environment
                variables when collect_metadata is True.
            filter_timed_by_electron (bool, optional): When True, the timed dataframe will only
                contain data points where valid electron events were detected. When False, all
                timed data points are included regardless of electron detection. Defaults to True.
@@ -341,13 +343,14 @@
        Raises:
            ValueError: If neither 'runs' nor 'files'/'raw_dir' is provided.
            FileNotFoundError: If the conversion fails for some files or no data is available.
            ValueError: If collect_metadata is True and no token is available.
        """
        detector = kwds.pop("detector", "")
        force_recreate = kwds.pop("force_recreate", False)
        processed_dir = kwds.pop("processed_dir", None)
        debug = kwds.pop("debug", False)
        remove_invalid_files = kwds.pop("remove_invalid_files", False)
        scicat_token = kwds.pop("scicat_token", None)
        token = kwds.pop("token", None)
        filter_timed_by_electron = kwds.pop("filter_timed_by_electron", True)

        if len(kwds) > 0:
@@ -401,7 +404,7 @@ def read_dataframe(
        if self.instrument == "wespe":
            df, df_timed = wespe_convert(df, df_timed)

        self.metadata.update(self.parse_metadata(scicat_token) if collect_metadata else {})
        self.metadata.update(self.parse_metadata(token) if collect_metadata else {})
        self.metadata.update(bh.metadata)

        print(f"loading complete in {time.time() - t0: .2f} s")