From 299d04bd99a9fc573c0e9a8f6befe26c5d11b61c Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sat, 11 Jan 2025 00:07:44 +0100 Subject: [PATCH 01/10] fix metadata issues, and add more debug logs --- src/sed/loader/flash/buffer_handler.py | 62 +++++++++++++++-------- src/sed/loader/flash/dataframe.py | 20 ++++++-- src/sed/loader/flash/metadata.py | 40 ++++++++------- tests/loader/flash/test_flash_metadata.py | 16 +++--- 4 files changed, 86 insertions(+), 52 deletions(-) diff --git a/src/sed/loader/flash/buffer_handler.py b/src/sed/loader/flash/buffer_handler.py index a1e5a258..b7cab37e 100644 --- a/src/sed/loader/flash/buffer_handler.py +++ b/src/sed/loader/flash/buffer_handler.py @@ -2,6 +2,7 @@ import os from pathlib import Path +import time import dask.dataframe as dd import pyarrow.parquet as pq @@ -20,7 +21,7 @@ DF_TYP = ["electron", "timed"] -logger = setup_logging(__name__) +logger = setup_logging("flash_buffer_handler") class BufferFilePaths: @@ -135,16 +136,15 @@ def __init__( def _schema_check(self, files: list[Path], expected_schema_set: set) -> None: """ Checks the schema of the Parquet files. - - Raises: - ValueError: If the schema of the Parquet files does not match the configuration. """ + logger.debug(f"Checking schema for {len(files)} files") existing = [file for file in files if file.exists()] parquet_schemas = [pq.read_schema(file) for file in existing] for filename, schema in zip(existing, parquet_schemas): schema_set = set(schema.names) if schema_set != expected_schema_set: + logger.error(f"Schema mismatch in file: {filename}") missing_in_parquet = expected_schema_set - schema_set missing_in_config = schema_set - expected_schema_set @@ -159,6 +159,7 @@ def _schema_check(self, files: list[Path], expected_schema_set: set) -> None: f"{' '.join(errors)}. " "Please check the configuration file or set force_recreate to True.", ) + logger.debug("Schema check passed successfully") def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame: """Creates the timed dataframe, optionally filtering by electron events. @@ -184,36 +185,57 @@ def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame: # Take only first electron per event return df_timed.loc[:, :, 0] - def _save_buffer_file(self, paths: dict[str, Path]) -> None: - """ - Creates the electron and timed buffer files from the raw H5 file. - First the dataframe is accessed and forward filled in the non-electron channels. - Then the data types are set. For the electron dataframe, all values not in the electron - channels are dropped. For the timed dataframe, only the train and pulse channels are taken - and it pulse resolved (no longer electron resolved). Both are saved as parquet files. + def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame: + """Creates the timed dataframe, optionally filtering by electron events. Args: - paths (dict[str, Path]): Dictionary containing the paths to the H5 and buffer files. 
+            df (dd.DataFrame): The input dataframe containing all data
+
+        Returns:
+            dd.DataFrame: The timed dataframe
         """
-        # Create a DataFrameCreator instance and get the h5 file
+        # Get channels that should be in timed dataframe
+        timed_channels = self.fill_channels
+
+        if self.filter_timed_by_electron:
+            # Get electron channels to use for filtering
+            electron_channels = get_channels(self._config, "per_electron")
+            # Filter rows where electron data exists
+            df_timed = df.dropna(subset=electron_channels)[timed_channels]
+        else:
+            # Take all timed data rows without filtering
+            df_timed = df[timed_channels]
+
+        # Take only first electron per event
+        return df_timed.loc[:, :, 0]
+
+    def _save_buffer_file(self, paths: dict[str, Path]) -> None:
+        """Creates the electron and timed buffer files from the raw H5 file."""
+        logger.debug(f"Processing file: {paths['raw'].stem}")
+        start_time = time.time()
+        # Create DataFrameCreator and get the dataframe
         df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df
 
-        # forward fill all the non-electron channels
+        # Forward fill non-electron channels
+        logger.debug(f"Forward filling {len(self.fill_channels)} channels")
         df[self.fill_channels] = df[self.fill_channels].ffill()
 
         # Save electron resolved dataframe
         electron_channels = get_channels(self._config, "per_electron")
         dtypes = get_dtypes(self._config, df.columns.values)
-        df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet(
-            paths["electron"],
-        )
+        electron_df = df.dropna(subset=electron_channels).astype(dtypes).reset_index()
+        logger.debug(f"Saving electron buffer with shape: {electron_df.shape}")
+        electron_df.to_parquet(paths["electron"])
 
         # Create and save timed dataframe
         df_timed = self._create_timed_dataframe(df)
         dtypes = get_dtypes(self._config, df_timed.columns.values)
-        df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])
+        timed_df = df_timed.astype(dtypes).reset_index()
+        logger.debug(f"Saving timed buffer with shape: {timed_df.shape}")
+        timed_df.to_parquet(paths["timed"])
 
-        logger.debug(f"Processed {paths['raw'].stem}")
+
+        logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s")
 
     def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
         """
@@ -286,7 +308,6 @@ def _get_dataframes(self) -> None:
                 config=self._config,
             )
             self.metadata.update(meta)
-
     def process_and_load_dataframe(
         self,
         h5_paths: list[Path],
@@ -337,3 +358,4 @@ def process_and_load_dataframe(
         self._get_dataframes()
 
         return self.df["electron"], self.df["timed"]
+
diff --git a/src/sed/loader/flash/dataframe.py b/src/sed/loader/flash/dataframe.py
index 6501c82a..b8a5bfcc 100644
--- a/src/sed/loader/flash/dataframe.py
+++ b/src/sed/loader/flash/dataframe.py
@@ -14,6 +14,9 @@
 from sed.loader.flash.utils import get_channels
 from sed.loader.flash.utils import InvalidFileError
+from sed.core.logging import setup_logging
+
+logger = setup_logging("flash_dataframe_creator")
 
 
 class DataFrameCreator:
@@ -34,6 +37,7 @@ def __init__(self, config_dataframe: dict, h5_path: Path) -> None:
             config_dataframe (dict): The configuration dictionary with only the dataframe key.
             h5_path (Path): Path to the h5 file.
         """
+        logger.debug(f"Initializing DataFrameCreator for file: {h5_path}")
         self.h5_file = h5py.File(h5_path, "r")
         self.multi_index = get_channels(index=True)
         self._config = config_dataframe
@@ -76,6 +80,7 @@ def get_dataset_array(
             tuple[pd.Index, np.ndarray | h5py.Dataset]: A tuple containing the train ID
             pd.Index and the channel's data.
""" + logger.debug(f"Getting dataset array for channel: {channel}") # Get the data from the necessary h5 file and channel index_key, dataset_key = self.get_index_dataset_key(channel) @@ -85,6 +90,7 @@ def get_dataset_array( if slice_: slice_index = self._config["channels"][channel].get("slice", None) if slice_index is not None: + logger.debug(f"Slicing dataset with index: {slice_index}") dataset = np.take(dataset, slice_index, axis=1) # If np_array is size zero, fill with NaNs, fill it with NaN values # of the same shape as index @@ -291,10 +297,14 @@ def df(self) -> pd.DataFrame: Returns: pd.DataFrame: The combined pandas DataFrame. """ - + logger.debug("Creating combined DataFrame") self.validate_channel_keys() - # been tested with merge, join and concat - # concat offers best performance, almost 3 times faster + df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index() - # all the negative pulse values are dropped as they are invalid - return df[df.index.get_level_values("pulseId") >= 0] + logger.debug(f"Created DataFrame with shape: {df.shape}") + + # Filter negative pulse values + df = df[df.index.get_level_values("pulseId") >= 0] + logger.debug(f"Filtered DataFrame shape: {df.shape}") + + return df diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index 50fd69b1..2317b006 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -7,6 +7,9 @@ import warnings import requests +from sed.core.logging import setup_logging + +logger = setup_logging("flash_metadata_retriever") class MetadataRetriever: @@ -15,19 +18,19 @@ class MetadataRetriever: on beamtime and run IDs. """ - def __init__(self, metadata_config: dict, scicat_token: str = None) -> None: + def __init__(self, metadata_config: dict, token: str = None) -> None: """ Initializes the MetadataRetriever class. Args: metadata_config (dict): Takes a dict containing at least url, and optionally token for the scicat instance. - scicat_token (str, optional): The token to use for fetching metadata. + token (str, optional): The token to use for fetching metadata. """ - self.token = metadata_config.get("scicat_token", None) - if scicat_token: - self.token = scicat_token - self.url = metadata_config.get("scicat_url", None) + self.token = metadata_config.get("token", None) + if token: + self.token = token + self.url = metadata_config.get("archiver_url", None) if not self.token or not self.url: raise ValueError("No URL or token provided for fetching metadata from scicat.") @@ -36,7 +39,7 @@ def __init__(self, metadata_config: dict, scicat_token: str = None) -> None: "Content-Type": "application/json", "Accept": "application/json", } - self.token = metadata_config["scicat_token"] + self.token = metadata_config["token"] def get_metadata( self, @@ -59,19 +62,18 @@ def get_metadata( Raises: Exception: If the request to retrieve metadata fails. 
""" - # If metadata is not provided, initialize it as an empty dictionary + logger.debug(f"Fetching metadata for beamtime {beamtime_id}, runs: {runs}") + if metadata is None: metadata = {} - # Iterate over the list of runs for run in runs: pid = f"{beamtime_id}/{run}" - # Retrieve metadata for each run and update the overall metadata dictionary + logger.debug(f"Retrieving metadata for PID: {pid}") metadata_run = self._get_metadata_per_run(pid) - metadata.update( - metadata_run, - ) # TODO: Not correct for multiple runs + metadata.update(metadata_run) + logger.debug(f"Retrieved metadata with {len(metadata)} entries") return metadata def _get_metadata_per_run(self, pid: str) -> dict: @@ -91,26 +93,26 @@ def _get_metadata_per_run(self, pid: str) -> dict: headers2["Authorization"] = f"Bearer {self.token}" try: + logger.debug(f"Attempting to fetch metadata with new URL format for PID: {pid}") dataset_response = requests.get( self._create_new_dataset_url(pid), headers=headers2, timeout=10, ) dataset_response.raise_for_status() - # Check if response is an empty object because wrong url for older implementation + if not dataset_response.content: + logger.debug("Empty response, trying old URL format") dataset_response = requests.get( self._create_old_dataset_url(pid), headers=headers2, timeout=10, ) - # If the dataset request is successful, return the retrieved metadata - # as a JSON object return dataset_response.json() + except requests.exceptions.RequestException as exception: - # If the request fails, raise warning - print(warnings.warn(f"Failed to retrieve metadata for PID {pid}: {str(exception)}")) - return {} # Return an empty dictionary for this run + logger.warning(f"Failed to retrieve metadata for PID {pid}: {str(exception)}") + return {} def _create_old_dataset_url(self, pid: str) -> str: return "{burl}/{url}/%2F{npid}".format( diff --git a/tests/loader/flash/test_flash_metadata.py b/tests/loader/flash/test_flash_metadata.py index 5b30b0da..84ec3900 100644 --- a/tests/loader/flash/test_flash_metadata.py +++ b/tests/loader/flash/test_flash_metadata.py @@ -16,8 +16,8 @@ def mock_requests(requests_mock) -> None: # Test cases for MetadataRetriever def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { - "scicat_url": "https://example.com", - "scicat_token": "fake_token", + "archiver_url": "https://example.com", + "token": "fake_token", } retriever = MetadataRetriever(metadata_config) metadata = retriever.get_metadata("11013410", ["43878"]) @@ -27,8 +27,8 @@ def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001 def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { - "scicat_url": "https://example.com", - "scicat_token": "fake_token", + "archiver_url": "https://example.com", + "token": "fake_token", } retriever = MetadataRetriever(metadata_config) existing_metadata = {"existing": "metadata"} @@ -39,8 +39,8 @@ def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # no def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { - "scicat_url": "https://example.com", - "scicat_token": "fake_token", + "archiver_url": "https://example.com", + "token": "fake_token", } retriever = MetadataRetriever(metadata_config) metadata = retriever._get_metadata_per_run("11013410/43878") @@ -50,8 +50,8 @@ def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001 def test_create_dataset_url_by_PID() -> None: metadata_config = { - "scicat_url": 
"https://example.com", - "scicat_token": "fake_token", + "archiver_url": "https://example.com", + "token": "fake_token", } retriever = MetadataRetriever(metadata_config) # Assuming the dataset follows the new format From e2dfcc5efaff7991fa7ff783802c6d77f94dbc76 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sat, 11 Jan 2025 00:43:01 +0100 Subject: [PATCH 02/10] fix lint errors --- src/sed/loader/flash/buffer_handler.py | 3 +-- src/sed/loader/flash/dataframe.py | 6 +++--- src/sed/loader/flash/metadata.py | 6 +++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/sed/loader/flash/buffer_handler.py b/src/sed/loader/flash/buffer_handler.py index b7cab37e..d27ae27f 100644 --- a/src/sed/loader/flash/buffer_handler.py +++ b/src/sed/loader/flash/buffer_handler.py @@ -234,7 +234,6 @@ def _save_buffer_file(self, paths: dict[str, Path]) -> None: logger.debug(f"Saving timed buffer with shape: {timed_df.shape}") timed_df.to_parquet(paths["timed"]) - logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s") def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None: @@ -308,6 +307,7 @@ def _get_dataframes(self) -> None: config=self._config, ) self.metadata.update(meta) + def process_and_load_dataframe( self, h5_paths: list[Path], @@ -358,4 +358,3 @@ def process_and_load_dataframe( self._get_dataframes() return self.df["electron"], self.df["timed"] - diff --git a/src/sed/loader/flash/dataframe.py b/src/sed/loader/flash/dataframe.py index b8a5bfcc..f50abe10 100644 --- a/src/sed/loader/flash/dataframe.py +++ b/src/sed/loader/flash/dataframe.py @@ -299,12 +299,12 @@ def df(self) -> pd.DataFrame: """ logger.debug("Creating combined DataFrame") self.validate_channel_keys() - + df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index() logger.debug(f"Created DataFrame with shape: {df.shape}") - + # Filter negative pulse values df = df[df.index.get_level_values("pulseId") >= 0] logger.debug(f"Filtered DataFrame shape: {df.shape}") - + return df diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index 2317b006..f5162850 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -63,7 +63,7 @@ def get_metadata( Exception: If the request to retrieve metadata fails. 
""" logger.debug(f"Fetching metadata for beamtime {beamtime_id}, runs: {runs}") - + if metadata is None: metadata = {} @@ -100,7 +100,7 @@ def _get_metadata_per_run(self, pid: str) -> dict: timeout=10, ) dataset_response.raise_for_status() - + if not dataset_response.content: logger.debug("Empty response, trying old URL format") dataset_response = requests.get( @@ -109,7 +109,7 @@ def _get_metadata_per_run(self, pid: str) -> dict: timeout=10, ) return dataset_response.json() - + except requests.exceptions.RequestException as exception: logger.warning(f"Failed to retrieve metadata for PID {pid}: {str(exception)}") return {} From 24df1f0ecaf9fa6fca4b47b933a84495a63f2f89 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sat, 11 Jan 2025 00:50:47 +0100 Subject: [PATCH 03/10] remove repitition --- src/sed/loader/flash/buffer_handler.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/sed/loader/flash/buffer_handler.py b/src/sed/loader/flash/buffer_handler.py index d27ae27f..d56de29f 100644 --- a/src/sed/loader/flash/buffer_handler.py +++ b/src/sed/loader/flash/buffer_handler.py @@ -185,30 +185,6 @@ def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame: # Take only first electron per event return df_timed.loc[:, :, 0] - def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame: - """Creates the timed dataframe, optionally filtering by electron events. - - Args: - df (dd.DataFrame): The input dataframe containing all data - - Returns: - dd.DataFrame: The timed dataframe - """ - # Get channels that should be in timed dataframe - timed_channels = self.fill_channels - - if self.filter_timed_by_electron: - # Get electron channels to use for filtering - electron_channels = get_channels(self._config, "per_electron") - # Filter rows where electron data exists - df_timed = df.dropna(subset=electron_channels)[timed_channels] - else: - # Take all timed data rows without filtering - df_timed = df[timed_channels] - - # Take only first electron per event - return df_timed.loc[:, :, 0] - def _save_buffer_file(self, paths: dict[str, Path]) -> None: """Creates the electron and timed buffer files from the raw H5 file.""" logger.debug(f"Processing file: {paths['raw'].stem}") From 5eba2e3dc2a858ce4fa5ee4c066dffe44897ecd0 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sat, 11 Jan 2025 00:51:29 +0100 Subject: [PATCH 04/10] remove warnings --- src/sed/loader/flash/metadata.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index f5162850..293ad2cc 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -4,8 +4,6 @@ """ from __future__ import annotations -import warnings - import requests from sed.core.logging import setup_logging From d66d34ddf1d9bc7c57f9eead655f397a8752cdb4 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sun, 12 Jan 2025 18:20:41 +0100 Subject: [PATCH 05/10] update token handling with env variables --- .cspell/custom-dictionary.txt | 1 + src/sed/core/config_model.py | 2 -- src/sed/loader/flash/loader.py | 15 ++++++++----- src/sed/loader/flash/metadata.py | 38 ++++++++++++++++++++++++-------- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index 264069be..6391f3a0 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -89,6 +89,7 @@ dictionarized dictmerge DLDAUX DOOCS +dotenv dpkg dropna dset diff --git a/src/sed/core/config_model.py 
b/src/sed/core/config_model.py index 880e5864..6379b639 100644 --- a/src/sed/core/config_model.py +++ b/src/sed/core/config_model.py @@ -14,7 +14,6 @@ from pydantic import HttpUrl from pydantic import NewPath from pydantic import PositiveInt -from pydantic import SecretStr from sed.loader.loader_interface import get_names_of_all_loaders @@ -323,7 +322,6 @@ class MetadataModel(BaseModel): model_config = ConfigDict(extra="forbid") archiver_url: Optional[HttpUrl] = None - token: Optional[SecretStr] = None epics_pvs: Optional[Sequence[str]] = None fa_in_channel: Optional[str] = None fa_hor_channel: Optional[str] = None diff --git a/src/sed/loader/flash/loader.py b/src/sed/loader/flash/loader.py index b1d4ff81..c2cf79b9 100644 --- a/src/sed/loader/flash/loader.py +++ b/src/sed/loader/flash/loader.py @@ -207,14 +207,14 @@ def get_files_from_run_id( # type: ignore[override] # Return the list of found files return [str(file.resolve()) for file in files] - def parse_metadata(self, scicat_token: str = None) -> dict: + def parse_metadata(self, token: str = None) -> dict: """Uses the MetadataRetriever class to fetch metadata from scicat for each run. Returns: dict: Metadata dictionary - scicat_token (str, optional):: The scicat token to use for fetching metadata + token (str, optional):: The scicat token to use for fetching metadata """ - metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token) + metadata_retriever = MetadataRetriever(self._config["metadata"], token) metadata = metadata_retriever.get_metadata( beamtime_id=self._config["core"]["beamtime_id"], runs=self.runs, @@ -329,7 +329,9 @@ def read_dataframe( debug (bool, optional): Whether to run buffer creation in serial. Defaults to False. remove_invalid_files (bool, optional): Whether to exclude invalid files. Defaults to False. - scicat_token (str, optional): The scicat token to use for fetching metadata. + token (str, optional): The scicat token to use for fetching metadata. If provided, + will be saved to .env file for future use. If not provided, will check environment + variables when collect_metadata is True. filter_timed_by_electron (bool, optional): When True, the timed dataframe will only contain data points where valid electron events were detected. When False, all timed data points are included regardless of electron detection. Defaults to True. @@ -341,13 +343,14 @@ def read_dataframe( Raises: ValueError: If neither 'runs' nor 'files'/'raw_dir' is provided. FileNotFoundError: If the conversion fails for some files or no data is available. + ValueError: If collect_metadata is True and no token is available. 
""" detector = kwds.pop("detector", "") force_recreate = kwds.pop("force_recreate", False) processed_dir = kwds.pop("processed_dir", None) debug = kwds.pop("debug", False) remove_invalid_files = kwds.pop("remove_invalid_files", False) - scicat_token = kwds.pop("scicat_token", None) + token = kwds.pop("token", None) filter_timed_by_electron = kwds.pop("filter_timed_by_electron", True) if len(kwds) > 0: @@ -401,7 +404,7 @@ def read_dataframe( if self.instrument == "wespe": df, df_timed = wespe_convert(df, df_timed) - self.metadata.update(self.parse_metadata(scicat_token) if collect_metadata else {}) + self.metadata.update(self.parse_metadata(token) if collect_metadata else {}) self.metadata.update(bh.metadata) print(f"loading complete in {time.time() - t0: .2f} s") diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index 293ad2cc..d7279e43 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -4,7 +4,13 @@ """ from __future__ import annotations +import os +from pathlib import Path + import requests +from dotenv import load_dotenv +from dotenv import set_key + from sed.core.logging import setup_logging logger = setup_logging("flash_metadata_retriever") @@ -21,23 +27,37 @@ def __init__(self, metadata_config: dict, token: str = None) -> None: Initializes the MetadataRetriever class. Args: - metadata_config (dict): Takes a dict containing - at least url, and optionally token for the scicat instance. - token (str, optional): The token to use for fetching metadata. + metadata_config (dict): Takes a dict containing at least url for the scicat instance. + token (str, optional): The token to use for fetching metadata. If provided, + will be saved to .env file for future use. """ - self.token = metadata_config.get("token", None) + # Token handling if token: - self.token = token - self.url = metadata_config.get("archiver_url", None) + # Save token to .env file in user's home directory + env_path = Path.home() / ".sed" / ".env" + env_path.parent.mkdir(parents=True, exist_ok=True) + set_key(str(env_path), "SCICAT_TOKEN", token) + else: + # Try to load token from config or environment + self.token = metadata_config.get("token") + if not self.token: + load_dotenv(Path.home() / ".sed" / ".env") + self.token = os.getenv("SCICAT_TOKEN") + + if not self.token: + raise ValueError( + "Token is required for metadata collection. 
Either provide a token " + "parameter or set the SCICAT_TOKEN environment variable.", + ) - if not self.token or not self.url: - raise ValueError("No URL or token provided for fetching metadata from scicat.") + self.url = metadata_config.get("archiver_url") + if not self.url: + raise ValueError("No URL provided for fetching metadata from scicat.") self.headers = { "Content-Type": "application/json", "Accept": "application/json", } - self.token = metadata_config["token"] def get_metadata( self, From 9a7a8b03d2e3e0aa5f092df8dd234f21843c6917 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sun, 12 Jan 2025 18:31:55 +0100 Subject: [PATCH 06/10] update test --- tests/loader/flash/test_flash_metadata.py | 56 ++++++++++++++++++++--- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/tests/loader/flash/test_flash_metadata.py b/tests/loader/flash/test_flash_metadata.py index 84ec3900..2f30537c 100644 --- a/tests/loader/flash/test_flash_metadata.py +++ b/tests/loader/flash/test_flash_metadata.py @@ -1,6 +1,8 @@ """Tests for FlashLoader metadata functionality""" from __future__ import annotations +from pathlib import Path + import pytest from sed.loader.flash.metadata import MetadataRetriever @@ -13,11 +15,38 @@ def mock_requests(requests_mock) -> None: requests_mock.get(dataset_url, json={"fake": "data"}, status_code=200) -# Test cases for MetadataRetriever -def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001 +@pytest.fixture +def mock_env_token(monkeypatch, tmp_path) -> None: + # Create a temporary .env file + env_path = tmp_path / ".env" + env_path.write_text("SCICAT_TOKEN=env_test_token") + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + +def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa: ARG001 + metadata_config = { + "archiver_url": "https://example.com", + } + retriever = MetadataRetriever(metadata_config, token="explicit_test_token") + metadata = retriever.get_metadata("11013410", ["43878"]) + assert isinstance(metadata, dict) + assert metadata == {"fake": "data"} + + +def test_get_metadata_with_config_token(mock_requests: None) -> None: # noqa: ARG001 + metadata_config = { + "archiver_url": "https://example.com", + "token": "config_test_token", + } + retriever = MetadataRetriever(metadata_config) + metadata = retriever.get_metadata("11013410", ["43878"]) + assert isinstance(metadata, dict) + assert metadata == {"fake": "data"} + + +def test_get_metadata_with_env_token(mock_requests: None, mock_env_token: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", - "token": "fake_token", } retriever = MetadataRetriever(metadata_config) metadata = retriever.get_metadata("11013410", ["43878"]) @@ -25,10 +54,24 @@ def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001 assert metadata == {"fake": "data"} +def test_get_metadata_no_token() -> None: + metadata_config = { + "archiver_url": "https://example.com", + } + with pytest.raises(ValueError, match="Token is required for metadata collection"): + MetadataRetriever(metadata_config) + + +def test_get_metadata_no_url() -> None: + metadata_config: dict = {} + with pytest.raises(ValueError, match="No URL provided for fetching metadata"): + MetadataRetriever(metadata_config, token="test_token") + + def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", - "token": "fake_token", + "token": "test_token", } retriever = MetadataRetriever(metadata_config) 
existing_metadata = {"existing": "metadata"} @@ -40,7 +83,7 @@ def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # no def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", - "token": "fake_token", + "token": "test_token", } retriever = MetadataRetriever(metadata_config) metadata = retriever._get_metadata_per_run("11013410/43878") @@ -51,10 +94,9 @@ def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001 def test_create_dataset_url_by_PID() -> None: metadata_config = { "archiver_url": "https://example.com", - "token": "fake_token", + "token": "test_token", } retriever = MetadataRetriever(metadata_config) - # Assuming the dataset follows the new format pid = "11013410/43878" url = retriever._create_new_dataset_url(pid) expected_url = "https://example.com/Datasets/11013410%2F43878" From 9d22be7234d368e64354e36d81ce8e1b411bb058 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sun, 12 Jan 2025 18:47:49 +0100 Subject: [PATCH 07/10] add the dotenv package --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6a852d9d..168d279d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "joblib>=1.2.0", "pyarrow>=14.0.1,<17.0", "pydantic>=2.8.2", + "python-dotenv>=1.0.1", ] [project.urls] From 80ef9c165b08d2aecb876e663279232097ffafe2 Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Sun, 12 Jan 2025 23:47:57 +0100 Subject: [PATCH 08/10] read write env variables without extra package, tests added --- pyproject.toml | 1 - src/sed/core/config.py | 51 +++++++++++++++++++++++ src/sed/loader/flash/metadata.py | 20 +++------ tests/loader/flash/test_flash_metadata.py | 46 +++++++++----------- tests/test_config.py | 47 +++++++++++++++++++++ 5 files changed, 123 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 168d279d..6a852d9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,6 @@ dependencies = [ "joblib>=1.2.0", "pyarrow>=14.0.1,<17.0", "pydantic>=2.8.2", - "python-dotenv>=1.0.1", ] [project.urls] diff --git a/src/sed/core/config.py b/src/sed/core/config.py index 647b1c33..a870e536 100644 --- a/src/sed/core/config.py +++ b/src/sed/core/config.py @@ -242,3 +242,54 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict: dictionary[k] = v return dictionary + + +def read_env_var(var_name: str) -> str | None: + """Read an environment variable from the .env file in the user config directory. + + Args: + var_name (str): Name of the environment variable to read + + Returns: + str | None: Value of the environment variable or None if not found + """ + env_path = USER_CONFIG_PATH / ".env" + if not env_path.exists(): + logger.debug(f"Environment variable {var_name} not found in .env file") + return None + + with open(env_path) as f: + for line in f: + if line.startswith(f"{var_name}="): + return line.strip().split("=", 1)[1] + logger.debug(f"Environment variable {var_name} not found in .env file") + return None + + +def save_env_var(var_name: str, value: str) -> None: + """Save an environment variable to the .env file in the user config directory. + If the file exists, preserves other variables. If not, creates a new file. 
+ + Args: + var_name (str): Name of the environment variable to save + value (str): Value to save for the environment variable + """ + env_path = USER_CONFIG_PATH / ".env" + env_content = {} + + # Read existing variables if file exists + if env_path.exists(): + with open(env_path) as f: + for line in f: + if "=" in line: + key, val = line.strip().split("=", 1) + env_content[key] = val + + # Update or add new variable + env_content[var_name] = value + + # Write all variables back to file + with open(env_path, "w") as f: + for key, val in env_content.items(): + f.write(f"{key}={val}\n") + logger.debug(f"Environment variable {var_name} saved to .env file") diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index d7279e43..9703907b 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -4,13 +4,10 @@ """ from __future__ import annotations -import os -from pathlib import Path - import requests -from dotenv import load_dotenv -from dotenv import set_key +from sed.core.config import read_env_var +from sed.core.config import save_env_var from sed.core.logging import setup_logging logger = setup_logging("flash_metadata_retriever") @@ -33,16 +30,11 @@ def __init__(self, metadata_config: dict, token: str = None) -> None: """ # Token handling if token: - # Save token to .env file in user's home directory - env_path = Path.home() / ".sed" / ".env" - env_path.parent.mkdir(parents=True, exist_ok=True) - set_key(str(env_path), "SCICAT_TOKEN", token) + self.token = token + save_env_var("SCICAT_TOKEN", self.token) else: - # Try to load token from config or environment - self.token = metadata_config.get("token") - if not self.token: - load_dotenv(Path.home() / ".sed" / ".env") - self.token = os.getenv("SCICAT_TOKEN") + # Try to load token from config or .env file + self.token = read_env_var("SCICAT_TOKEN") if not self.token: raise ValueError( diff --git a/tests/loader/flash/test_flash_metadata.py b/tests/loader/flash/test_flash_metadata.py index 2f30537c..165a6e8a 100644 --- a/tests/loader/flash/test_flash_metadata.py +++ b/tests/loader/flash/test_flash_metadata.py @@ -1,12 +1,17 @@ """Tests for FlashLoader metadata functionality""" from __future__ import annotations -from pathlib import Path +import os import pytest +from sed.core.config import read_env_var +from sed.core.config import save_env_var +from sed.core.config import USER_CONFIG_PATH from sed.loader.flash.metadata import MetadataRetriever +ENV_PATH = USER_CONFIG_PATH / ".env" + @pytest.fixture def mock_requests(requests_mock) -> None: @@ -15,14 +20,6 @@ def mock_requests(requests_mock) -> None: requests_mock.get(dataset_url, json={"fake": "data"}, status_code=200) -@pytest.fixture -def mock_env_token(monkeypatch, tmp_path) -> None: - # Create a temporary .env file - env_path = tmp_path / ".env" - env_path.write_text("SCICAT_TOKEN=env_test_token") - monkeypatch.setattr(Path, "home", lambda: tmp_path) - - def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", @@ -31,20 +28,13 @@ def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa: metadata = retriever.get_metadata("11013410", ["43878"]) assert isinstance(metadata, dict) assert metadata == {"fake": "data"} + assert ENV_PATH.exists() + assert read_env_var("SCICAT_TOKEN") == "explicit_test_token" + os.remove(ENV_PATH) -def test_get_metadata_with_config_token(mock_requests: None) -> None: # noqa: ARG001 - metadata_config = { 
- "archiver_url": "https://example.com", - "token": "config_test_token", - } - retriever = MetadataRetriever(metadata_config) - metadata = retriever.get_metadata("11013410", ["43878"]) - assert isinstance(metadata, dict) - assert metadata == {"fake": "data"} - - -def test_get_metadata_with_env_token(mock_requests: None, mock_env_token: None) -> None: # noqa: ARG001 +def test_get_metadata_with_env_token(mock_requests: None) -> None: # noqa: ARG001 + save_env_var("SCICAT_TOKEN", "env_test_token") metadata_config = { "archiver_url": "https://example.com", } @@ -52,6 +42,7 @@ def test_get_metadata_with_env_token(mock_requests: None, mock_env_token: None) metadata = retriever.get_metadata("11013410", ["43878"]) assert isinstance(metadata, dict) assert metadata == {"fake": "data"} + os.remove(ENV_PATH) def test_get_metadata_no_token() -> None: @@ -66,39 +57,40 @@ def test_get_metadata_no_url() -> None: metadata_config: dict = {} with pytest.raises(ValueError, match="No URL provided for fetching metadata"): MetadataRetriever(metadata_config, token="test_token") + os.remove(ENV_PATH) def test_get_metadata_with_existing_metadata(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", - "token": "test_token", } - retriever = MetadataRetriever(metadata_config) + retriever = MetadataRetriever(metadata_config, token="test_token") existing_metadata = {"existing": "metadata"} metadata = retriever.get_metadata("11013410", ["43878"], existing_metadata) assert isinstance(metadata, dict) assert metadata == {"existing": "metadata", "fake": "data"} + os.remove(ENV_PATH) def test_get_metadata_per_run(mock_requests: None) -> None: # noqa: ARG001 metadata_config = { "archiver_url": "https://example.com", - "token": "test_token", } - retriever = MetadataRetriever(metadata_config) + retriever = MetadataRetriever(metadata_config, token="test_token") metadata = retriever._get_metadata_per_run("11013410/43878") assert isinstance(metadata, dict) assert metadata == {"fake": "data"} + os.remove(ENV_PATH) def test_create_dataset_url_by_PID() -> None: metadata_config = { "archiver_url": "https://example.com", - "token": "test_token", } - retriever = MetadataRetriever(metadata_config) + retriever = MetadataRetriever(metadata_config, token="test_token") pid = "11013410/43878" url = retriever._create_new_dataset_url(pid) expected_url = "https://example.com/Datasets/11013410%2F43878" assert isinstance(url, str) assert url == expected_url + os.remove(ENV_PATH) diff --git a/tests/test_config.py b/tests/test_config.py index 7f8a43b8..321b35d6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,9 @@ from sed.core.config import complete_dictionary from sed.core.config import load_config from sed.core.config import parse_config +from sed.core.config import read_env_var from sed.core.config import save_config +from sed.core.config import save_env_var test_dir = os.path.dirname(__file__) test_config_dir = Path(f"{test_dir}/data/loader/") @@ -231,3 +233,48 @@ def test_invalid_config_wrong_values(): verify_config=True, ) assert "Invalid value 9999 for gid. Group not found." 
in str(e.value) + + +def test_env_var_read_write(tmp_path, monkeypatch): + """Test reading and writing environment variables.""" + # Mock USER_CONFIG_PATH to use a temporary directory + monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path) + + # Test writing a new variable + save_env_var("TEST_VAR", "test_value") + assert read_env_var("TEST_VAR") == "test_value" + + # Test writing multiple variables + save_env_var("TEST_VAR2", "test_value2") + assert read_env_var("TEST_VAR") == "test_value" + assert read_env_var("TEST_VAR2") == "test_value2" + + # Test overwriting an existing variable + save_env_var("TEST_VAR", "new_value") + assert read_env_var("TEST_VAR") == "new_value" + assert read_env_var("TEST_VAR2") == "test_value2" # Other variables unchanged + + # Test reading non-existent variable + assert read_env_var("NON_EXISTENT_VAR") is None + + +def test_env_var_read_no_file(tmp_path, monkeypatch): + """Test reading environment variables when .env file doesn't exist.""" + # Mock USER_CONFIG_PATH to use an empty temporary directory + monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path) + + # Test reading from non-existent file + assert read_env_var("TEST_VAR") is None + + +def test_env_var_special_characters(): + """Test reading and writing environment variables with special characters.""" + test_cases = { + "TEST_URL": "http://example.com/path?query=value", + "TEST_PATH": "/path/to/something/with/spaces and special=chars", + "TEST_QUOTES": "value with 'single' and \"double\" quotes", + } + + for var_name, value in test_cases.items(): + save_env_var(var_name, value) + assert read_env_var(var_name) == value From 9b13c991ecc9a481c3d579fb928dadba2dbc6faf Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Tue, 14 Jan 2025 14:02:59 +0100 Subject: [PATCH 09/10] search for .env in cwd and os env variables --- src/sed/core/config.py | 64 ++++++++++++++++++++++----------- tests/test_config.py | 81 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 20 deletions(-) diff --git a/src/sed/core/config.py b/src/sed/core/config.py index a870e536..328353e4 100644 --- a/src/sed/core/config.py +++ b/src/sed/core/config.py @@ -244,8 +244,31 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict: return dictionary +def _parse_env_file(file_path: Path) -> dict: + """Helper function to parse a .env file into a dictionary. + + Args: + file_path (Path): Path to the .env file + + Returns: + dict: Dictionary of environment variables from the file + """ + env_content = {} + if file_path.exists(): + with open(file_path) as f: + for line in f: + line = line.strip() + if line and "=" in line: + key, val = line.split("=", 1) + env_content[key.strip()] = val.strip() + return env_content + + def read_env_var(var_name: str) -> str | None: - """Read an environment variable from the .env file in the user config directory. + """Read an environment variable from multiple locations in order: + 1. OS environment variables + 2. .env file in current directory + 3. 
.env file in user config directory Args: var_name (str): Name of the environment variable to read @@ -253,16 +276,25 @@ def read_env_var(var_name: str) -> str | None: Returns: str | None: Value of the environment variable or None if not found """ - env_path = USER_CONFIG_PATH / ".env" - if not env_path.exists(): - logger.debug(f"Environment variable {var_name} not found in .env file") - return None - - with open(env_path) as f: - for line in f: - if line.startswith(f"{var_name}="): - return line.strip().split("=", 1)[1] - logger.debug(f"Environment variable {var_name} not found in .env file") + # First check OS environment variables + value = os.getenv(var_name) + if value is not None: + logger.debug(f"Found {var_name} in OS environment variables") + return value + + # Then check .env in current directory + local_vars = _parse_env_file(Path(".env")) + if var_name in local_vars: + logger.debug(f"Found {var_name} in ./.env file") + return local_vars[var_name] + + # Finally check .env in user config directory + user_vars = _parse_env_file(USER_CONFIG_PATH / ".env") + if var_name in user_vars: + logger.debug(f"Found {var_name} in user config .env file") + return user_vars[var_name] + + logger.debug(f"Environment variable {var_name} not found in any location") return None @@ -275,15 +307,7 @@ def save_env_var(var_name: str, value: str) -> None: value (str): Value to save for the environment variable """ env_path = USER_CONFIG_PATH / ".env" - env_content = {} - - # Read existing variables if file exists - if env_path.exists(): - with open(env_path) as f: - for line in f: - if "=" in line: - key, val = line.strip().split("=", 1) - env_content[key] = val + env_content = _parse_env_file(env_path) # Update or add new variable env_content[var_name] = value diff --git a/tests/test_config.py b/tests/test_config.py index 321b35d6..75d780c6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -16,6 +16,7 @@ from sed.core.config import read_env_var from sed.core.config import save_config from sed.core.config import save_env_var +from sed.core.config import USER_CONFIG_PATH test_dir = os.path.dirname(__file__) test_config_dir = Path(f"{test_dir}/data/loader/") @@ -278,3 +279,83 @@ def test_env_var_special_characters(): for var_name, value in test_cases.items(): save_env_var(var_name, value) assert read_env_var(var_name) == value + + +@pytest.fixture +def cleanup_env_files(): + """Cleanup any .env files before and after tests""" + # Clean up any existing .env files + for path in [Path(".env"), USER_CONFIG_PATH / ".env"]: + if path.exists(): + path.unlink() + + yield + + # Clean up after tests + for path in [Path(".env"), USER_CONFIG_PATH / ".env"]: + if path.exists(): + path.unlink() + + +def test_env_var_precedence(cleanup_env_files): # noqa: ARG001 + """Test that environment variables are read in correct order of precedence""" + # Set up test values in different locations + os.environ["TEST_VAR"] = "os_value" + + with open(".env", "w") as f: + f.write("TEST_VAR=local_value\n") + + save_env_var("TEST_VAR", "user_value") # Saves to USER_CONFIG_PATH + + # Should get OS value first + assert read_env_var("TEST_VAR") == "os_value" + + # Remove from OS env and should get local value + del os.environ["TEST_VAR"] + assert read_env_var("TEST_VAR") == "local_value" + + # Remove local .env and should get user config value + Path(".env").unlink() + assert read_env_var("TEST_VAR") == "user_value" + + # Remove user config and should get None + (USER_CONFIG_PATH / ".env").unlink() + assert 
read_env_var("TEST_VAR") is None + + +def test_env_var_save_and_load(cleanup_env_files): # noqa: ARG001 + """Test saving and loading environment variables""" + # Save a variable + save_env_var("TEST_VAR", "test_value") + + # Should be able to read it back + assert read_env_var("TEST_VAR") == "test_value" + + # Save another variable - should preserve existing ones + save_env_var("OTHER_VAR", "other_value") + assert read_env_var("TEST_VAR") == "test_value" + assert read_env_var("OTHER_VAR") == "other_value" + + +def test_env_var_not_found(cleanup_env_files): # noqa: ARG001 + """Test behavior when environment variable is not found""" + assert read_env_var("NONEXISTENT_VAR") is None + + +def test_env_file_format(cleanup_env_files): # noqa: ARG001 + """Test that .env file parsing handles different formats correctly""" + with open(".env", "w") as f: + f.write( + """ + TEST_VAR=value1 + SPACES_VAR = value2 + EMPTY_VAR= + #COMMENT=value3 + INVALID_LINE + """, + ) + + assert read_env_var("TEST_VAR") == "value1" + assert read_env_var("SPACES_VAR") == "value2" + assert read_env_var("EMPTY_VAR") == "" + assert read_env_var("COMMENT") is None From 25a935beb8bf9eb48ce71ea392df94a1f621ba2a Mon Sep 17 00:00:00 2001 From: Zain Sohail Date: Tue, 14 Jan 2025 14:08:21 +0100 Subject: [PATCH 10/10] bring back comments --- .cspell/custom-dictionary.txt | 1 - src/sed/loader/flash/metadata.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index 6391f3a0..264069be 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -89,7 +89,6 @@ dictionarized dictmerge DLDAUX DOOCS -dotenv dpkg dropna dset diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index 9703907b..578fa9fd 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -81,7 +81,7 @@ def get_metadata( pid = f"{beamtime_id}/{run}" logger.debug(f"Retrieving metadata for PID: {pid}") metadata_run = self._get_metadata_per_run(pid) - metadata.update(metadata_run) + metadata.update(metadata_run) # TODO: Not correct for multiple runs logger.debug(f"Retrieved metadata with {len(metadata)} entries") return metadata @@ -111,6 +111,7 @@ def _get_metadata_per_run(self, pid: str) -> dict: ) dataset_response.raise_for_status() + # Check if response is an empty object because wrong url for older implementation if not dataset_response.content: logger.debug("Empty response, trying old URL format") dataset_response = requests.get( @@ -118,11 +119,13 @@ def _get_metadata_per_run(self, pid: str) -> dict: headers=headers2, timeout=10, ) + # If the dataset request is successful, return the retrieved metadata + # as a JSON object return dataset_response.json() except requests.exceptions.RequestException as exception: logger.warning(f"Failed to retrieve metadata for PID {pid}: {str(exception)}") - return {} + return {} # Return an empty dictionary for this run def _create_old_dataset_url(self, pid: str) -> str: return "{burl}/{url}/%2F{npid}".format(
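
Usage sketch for the token flow this series converges on (illustrative only, not part of
any patch above; the archiver URL, token string, and beamtime/run IDs are the placeholder
values used in the tests):

    from sed.loader.flash.metadata import MetadataRetriever

    metadata_config = {"archiver_url": "https://example.com"}  # placeholder from the tests

    # First run: pass the token explicitly; MetadataRetriever persists it as
    # SCICAT_TOKEN via save_env_var, in the .env file under the user config directory.
    retriever = MetadataRetriever(metadata_config, token="explicit_test_token")

    # Later runs: omit the token; read_env_var resolves SCICAT_TOKEN from the OS
    # environment, then ./.env, then the user-config .env, in that order, and
    # MetadataRetriever raises ValueError if none of these provides it.
    retriever = MetadataRetriever(metadata_config)
    metadata = retriever.get_metadata(beamtime_id="11013410", runs=["43878"])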