diff --git a/src/sed/core/config.py b/src/sed/core/config.py
index 647b1c33..328353e4 100644
--- a/src/sed/core/config.py
+++ b/src/sed/core/config.py
@@ -242,3 +242,78 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict:
             dictionary[k] = v

     return dictionary
+
+
+def _parse_env_file(file_path: Path) -> dict:
+    """Helper function to parse a .env file into a dictionary.
+
+    Args:
+        file_path (Path): Path to the .env file
+
+    Returns:
+        dict: Dictionary of environment variables from the file
+    """
+    env_content = {}
+    if file_path.exists():
+        with open(file_path) as f:
+            for line in f:
+                line = line.strip()
+                if line and not line.startswith("#") and "=" in line:
+                    key, val = line.split("=", 1)
+                    env_content[key.strip()] = val.strip()
+    return env_content
+
+
+def read_env_var(var_name: str) -> str | None:
+    """Read an environment variable from multiple locations in order:
+    1. OS environment variables
+    2. .env file in current directory
+    3. .env file in user config directory
+
+    Args:
+        var_name (str): Name of the environment variable to read
+
+    Returns:
+        str | None: Value of the environment variable or None if not found
+    """
+    # First check OS environment variables
+    value = os.getenv(var_name)
+    if value is not None:
+        logger.debug(f"Found {var_name} in OS environment variables")
+        return value
+
+    # Then check .env in current directory
+    local_vars = _parse_env_file(Path(".env"))
+    if var_name in local_vars:
+        logger.debug(f"Found {var_name} in ./.env file")
+        return local_vars[var_name]
+
+    # Finally check .env in user config directory
+    user_vars = _parse_env_file(USER_CONFIG_PATH / ".env")
+    if var_name in user_vars:
+        logger.debug(f"Found {var_name} in user config .env file")
+        return user_vars[var_name]
+
+    logger.debug(f"Environment variable {var_name} not found in any location")
+    return None
+
+
+def save_env_var(var_name: str, value: str) -> None:
+    """Save an environment variable to the .env file in the user config directory.
+    If the file exists, other variables are preserved. If not, a new file is created.
+
+    Args:
+        var_name (str): Name of the environment variable to save
+        value (str): Value to save for the environment variable
+    """
+    env_path = USER_CONFIG_PATH / ".env"
+    env_content = _parse_env_file(env_path)
+
+    # Update or add new variable
+    env_content[var_name] = value
+
+    # Write all variables back to file
+    with open(env_path, "w") as f:
+        for key, val in env_content.items():
+            f.write(f"{key}={val}\n")
+    logger.debug(f"Environment variable {var_name} saved to .env file")
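Note (reviewer sketch, not part of the patch): a round-trip usage of the new helpers; `MY_VAR` is a made-up name for illustration.

    from sed.core.config import read_env_var, save_env_var

    # Written to <user config dir>/.env, preserving any other entries in the file
    save_env_var("MY_VAR", "some_value")

    # Lookup order: os.environ first, then ./.env, then the user config .env
    assert read_env_var("MY_VAR") == "some_value"
    assert read_env_var("MISSING_VAR") is None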
diff --git a/src/sed/core/config_model.py b/src/sed/core/config_model.py
index 880e5864..6379b639 100644
--- a/src/sed/core/config_model.py
+++ b/src/sed/core/config_model.py
@@ -14,7 +14,6 @@
 from pydantic import HttpUrl
 from pydantic import NewPath
 from pydantic import PositiveInt
-from pydantic import SecretStr

 from sed.loader.loader_interface import get_names_of_all_loaders
@@ -323,7 +322,6 @@ class MetadataModel(BaseModel):
     model_config = ConfigDict(extra="forbid")

     archiver_url: Optional[HttpUrl] = None
-    token: Optional[SecretStr] = None
     epics_pvs: Optional[Sequence[str]] = None
     fa_in_channel: Optional[str] = None
     fa_hor_channel: Optional[str] = None
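Note (reviewer sketch): because `MetadataModel` sets `extra="forbid"`, a config that still carries the removed `token` key should now fail validation. A minimal check, assuming the remaining model fields stay optional (the URL is a placeholder):

    from sed.core.config_model import MetadataModel

    MetadataModel(archiver_url="https://example.com")  # accepted
    # MetadataModel(archiver_url="https://example.com", token="x")
    # -> raises a ValidationError, since extra keys are forbidden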
- """ - # Create a DataFrameCreator instance and get the h5 file + """Creates the electron and timed buffer files from the raw H5 file.""" + logger.debug(f"Processing file: {paths['raw'].stem}") + start_time = time.time() + # Create DataFrameCreator and get get dataframe df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df - # forward fill all the non-electron channels + # Forward fill non-electron channels + logger.debug(f"Forward filling {len(self.fill_channels)} channels") df[self.fill_channels] = df[self.fill_channels].ffill() # Save electron resolved dataframe electron_channels = get_channels(self._config, "per_electron") dtypes = get_dtypes(self._config, df.columns.values) - df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet( - paths["electron"], - ) + electron_df = df.dropna(subset=electron_channels).astype(dtypes).reset_index() + logger.debug(f"Saving electron buffer with shape: {electron_df.shape}") + electron_df.to_parquet(paths["electron"]) # Create and save timed dataframe df_timed = self._create_timed_dataframe(df) dtypes = get_dtypes(self._config, df_timed.columns.values) - df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"]) + timed_df = df_timed.astype(dtypes).reset_index() + logger.debug(f"Saving timed buffer with shape: {timed_df.shape}") + timed_df.to_parquet(paths["timed"]) - logger.debug(f"Processed {paths['raw'].stem}") + logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s") def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None: """ diff --git a/src/sed/loader/flash/dataframe.py b/src/sed/loader/flash/dataframe.py index 6501c82a..f50abe10 100644 --- a/src/sed/loader/flash/dataframe.py +++ b/src/sed/loader/flash/dataframe.py @@ -14,6 +14,9 @@ from sed.loader.flash.utils import get_channels from sed.loader.flash.utils import InvalidFileError +from sed.core.logging import setup_logging + +logger = setup_logging("flash_dataframe_creator") class DataFrameCreator: @@ -34,6 +37,7 @@ def __init__(self, config_dataframe: dict, h5_path: Path) -> None: config_dataframe (dict): The configuration dictionary with only the dataframe key. h5_path (Path): Path to the h5 file. """ + logger.debug(f"Initializing DataFrameCreator for file: {h5_path}") self.h5_file = h5py.File(h5_path, "r") self.multi_index = get_channels(index=True) self._config = config_dataframe @@ -76,6 +80,7 @@ def get_dataset_array( tuple[pd.Index, np.ndarray | h5py.Dataset]: A tuple containing the train ID pd.Index and the channel's data. """ + logger.debug(f"Getting dataset array for channel: {channel}") # Get the data from the necessary h5 file and channel index_key, dataset_key = self.get_index_dataset_key(channel) @@ -85,6 +90,7 @@ def get_dataset_array( if slice_: slice_index = self._config["channels"][channel].get("slice", None) if slice_index is not None: + logger.debug(f"Slicing dataset with index: {slice_index}") dataset = np.take(dataset, slice_index, axis=1) # If np_array is size zero, fill with NaNs, fill it with NaN values # of the same shape as index @@ -291,10 +297,14 @@ def df(self) -> pd.DataFrame: Returns: pd.DataFrame: The combined pandas DataFrame. 
""" - + logger.debug("Creating combined DataFrame") self.validate_channel_keys() - # been tested with merge, join and concat - # concat offers best performance, almost 3 times faster + df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index() - # all the negative pulse values are dropped as they are invalid - return df[df.index.get_level_values("pulseId") >= 0] + logger.debug(f"Created DataFrame with shape: {df.shape}") + + # Filter negative pulse values + df = df[df.index.get_level_values("pulseId") >= 0] + logger.debug(f"Filtered DataFrame shape: {df.shape}") + + return df diff --git a/src/sed/loader/flash/loader.py b/src/sed/loader/flash/loader.py index b1d4ff81..c2cf79b9 100644 --- a/src/sed/loader/flash/loader.py +++ b/src/sed/loader/flash/loader.py @@ -207,14 +207,14 @@ def get_files_from_run_id( # type: ignore[override] # Return the list of found files return [str(file.resolve()) for file in files] - def parse_metadata(self, scicat_token: str = None) -> dict: + def parse_metadata(self, token: str = None) -> dict: """Uses the MetadataRetriever class to fetch metadata from scicat for each run. Returns: dict: Metadata dictionary - scicat_token (str, optional):: The scicat token to use for fetching metadata + token (str, optional):: The scicat token to use for fetching metadata """ - metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token) + metadata_retriever = MetadataRetriever(self._config["metadata"], token) metadata = metadata_retriever.get_metadata( beamtime_id=self._config["core"]["beamtime_id"], runs=self.runs, @@ -329,7 +329,9 @@ def read_dataframe( debug (bool, optional): Whether to run buffer creation in serial. Defaults to False. remove_invalid_files (bool, optional): Whether to exclude invalid files. Defaults to False. - scicat_token (str, optional): The scicat token to use for fetching metadata. + token (str, optional): The scicat token to use for fetching metadata. If provided, + will be saved to .env file for future use. If not provided, will check environment + variables when collect_metadata is True. filter_timed_by_electron (bool, optional): When True, the timed dataframe will only contain data points where valid electron events were detected. When False, all timed data points are included regardless of electron detection. Defaults to True. @@ -341,13 +343,14 @@ def read_dataframe( Raises: ValueError: If neither 'runs' nor 'files'/'raw_dir' is provided. FileNotFoundError: If the conversion fails for some files or no data is available. + ValueError: If collect_metadata is True and no token is available. 
""" detector = kwds.pop("detector", "") force_recreate = kwds.pop("force_recreate", False) processed_dir = kwds.pop("processed_dir", None) debug = kwds.pop("debug", False) remove_invalid_files = kwds.pop("remove_invalid_files", False) - scicat_token = kwds.pop("scicat_token", None) + token = kwds.pop("token", None) filter_timed_by_electron = kwds.pop("filter_timed_by_electron", True) if len(kwds) > 0: @@ -401,7 +404,7 @@ def read_dataframe( if self.instrument == "wespe": df, df_timed = wespe_convert(df, df_timed) - self.metadata.update(self.parse_metadata(scicat_token) if collect_metadata else {}) + self.metadata.update(self.parse_metadata(token) if collect_metadata else {}) self.metadata.update(bh.metadata) print(f"loading complete in {time.time() - t0: .2f} s") diff --git a/src/sed/loader/flash/metadata.py b/src/sed/loader/flash/metadata.py index 50fd69b1..578fa9fd 100644 --- a/src/sed/loader/flash/metadata.py +++ b/src/sed/loader/flash/metadata.py @@ -4,10 +4,14 @@ """ from __future__ import annotations -import warnings - import requests +from sed.core.config import read_env_var +from sed.core.config import save_env_var +from sed.core.logging import setup_logging + +logger = setup_logging("flash_metadata_retriever") + class MetadataRetriever: """ @@ -15,28 +19,37 @@ class MetadataRetriever: on beamtime and run IDs. """ - def __init__(self, metadata_config: dict, scicat_token: str = None) -> None: + def __init__(self, metadata_config: dict, token: str = None) -> None: """ Initializes the MetadataRetriever class. Args: - metadata_config (dict): Takes a dict containing - at least url, and optionally token for the scicat instance. - scicat_token (str, optional): The token to use for fetching metadata. + metadata_config (dict): Takes a dict containing at least url for the scicat instance. + token (str, optional): The token to use for fetching metadata. If provided, + will be saved to .env file for future use. """ - self.token = metadata_config.get("scicat_token", None) - if scicat_token: - self.token = scicat_token - self.url = metadata_config.get("scicat_url", None) + # Token handling + if token: + self.token = token + save_env_var("SCICAT_TOKEN", self.token) + else: + # Try to load token from config or .env file + self.token = read_env_var("SCICAT_TOKEN") + + if not self.token: + raise ValueError( + "Token is required for metadata collection. Either provide a token " + "parameter or set the SCICAT_TOKEN environment variable.", + ) - if not self.token or not self.url: - raise ValueError("No URL or token provided for fetching metadata from scicat.") + self.url = metadata_config.get("archiver_url") + if not self.url: + raise ValueError("No URL provided for fetching metadata from scicat.") self.headers = { "Content-Type": "application/json", "Accept": "application/json", } - self.token = metadata_config["scicat_token"] def get_metadata( self, @@ -59,19 +72,18 @@ def get_metadata( Raises: Exception: If the request to retrieve metadata fails. 
""" - # If metadata is not provided, initialize it as an empty dictionary + logger.debug(f"Fetching metadata for beamtime {beamtime_id}, runs: {runs}") + if metadata is None: metadata = {} - # Iterate over the list of runs for run in runs: pid = f"{beamtime_id}/{run}" - # Retrieve metadata for each run and update the overall metadata dictionary + logger.debug(f"Retrieving metadata for PID: {pid}") metadata_run = self._get_metadata_per_run(pid) - metadata.update( - metadata_run, - ) # TODO: Not correct for multiple runs + metadata.update(metadata_run) # TODO: Not correct for multiple runs + logger.debug(f"Retrieved metadata with {len(metadata)} entries") return metadata def _get_metadata_per_run(self, pid: str) -> dict: @@ -91,14 +103,17 @@ def _get_metadata_per_run(self, pid: str) -> dict: headers2["Authorization"] = f"Bearer {self.token}" try: + logger.debug(f"Attempting to fetch metadata with new URL format for PID: {pid}") dataset_response = requests.get( self._create_new_dataset_url(pid), headers=headers2, timeout=10, ) dataset_response.raise_for_status() + # Check if response is an empty object because wrong url for older implementation if not dataset_response.content: + logger.debug("Empty response, trying old URL format") dataset_response = requests.get( self._create_old_dataset_url(pid), headers=headers2, @@ -107,9 +122,9 @@ def _get_metadata_per_run(self, pid: str) -> dict: # If the dataset request is successful, return the retrieved metadata # as a JSON object return dataset_response.json() + except requests.exceptions.RequestException as exception: - # If the request fails, raise warning - print(warnings.warn(f"Failed to retrieve metadata for PID {pid}: {str(exception)}")) + logger.warning(f"Failed to retrieve metadata for PID {pid}: {str(exception)}") return {} # Return an empty dictionary for this run def _create_old_dataset_url(self, pid: str) -> str: diff --git a/tests/loader/flash/test_flash_metadata.py b/tests/loader/flash/test_flash_metadata.py index 5b30b0da..165a6e8a 100644 --- a/tests/loader/flash/test_flash_metadata.py +++ b/tests/loader/flash/test_flash_metadata.py @@ -1,10 +1,17 @@ """Tests for FlashLoader metadata functionality""" from __future__ import annotations +import os + import pytest +from sed.core.config import read_env_var +from sed.core.config import save_env_var +from sed.core.config import USER_CONFIG_PATH from sed.loader.flash.metadata import MetadataRetriever +ENV_PATH = USER_CONFIG_PATH / ".env" + @pytest.fixture def mock_requests(requests_mock) -> None: @@ -13,50 +20,77 @@ def mock_requests(requests_mock) -> None: requests_mock.get(dataset_url, json={"fake": "data"}, status_code=200) -# Test cases for MetadataRetriever -def test_get_metadata(mock_requests: None) -> None: # noqa: ARG001 +def test_get_metadata_with_explicit_token(mock_requests: None) -> None: # noqa: ARG001 + metadata_config = { + "archiver_url": "https://example.com", + } + retriever = MetadataRetriever(metadata_config, token="explicit_test_token") + metadata = retriever.get_metadata("11013410", ["43878"]) + assert isinstance(metadata, dict) + assert metadata == {"fake": "data"} + assert ENV_PATH.exists() + assert read_env_var("SCICAT_TOKEN") == "explicit_test_token" + os.remove(ENV_PATH) + + +def test_get_metadata_with_env_token(mock_requests: None) -> None: # noqa: ARG001 + save_env_var("SCICAT_TOKEN", "env_test_token") metadata_config = { - "scicat_url": "https://example.com", - "scicat_token": "fake_token", + "archiver_url": "https://example.com", } retriever = 
     metadata = retriever.get_metadata("11013410", ["43878"])
     assert isinstance(metadata, dict)
     assert metadata == {"fake": "data"}
+    os.remove(ENV_PATH)
+
+
+def test_get_metadata_no_token() -> None:
+    metadata_config = {
+        "archiver_url": "https://example.com",
+    }
+    with pytest.raises(ValueError, match="Token is required for metadata collection"):
+        MetadataRetriever(metadata_config)
+
+
+def test_get_metadata_no_url() -> None:
+    metadata_config: dict = {}
+    with pytest.raises(ValueError, match="No URL provided for fetching metadata"):
+        MetadataRetriever(metadata_config, token="test_token")
+    os.remove(ENV_PATH)


 def test_get_metadata_with_existing_metadata(mock_requests: None) -> None:  # noqa: ARG001
     metadata_config = {
-        "scicat_url": "https://example.com",
-        "scicat_token": "fake_token",
+        "archiver_url": "https://example.com",
     }
-    retriever = MetadataRetriever(metadata_config)
+    retriever = MetadataRetriever(metadata_config, token="test_token")
     existing_metadata = {"existing": "metadata"}
     metadata = retriever.get_metadata("11013410", ["43878"], existing_metadata)
     assert isinstance(metadata, dict)
     assert metadata == {"existing": "metadata", "fake": "data"}
+    os.remove(ENV_PATH)


 def test_get_metadata_per_run(mock_requests: None) -> None:  # noqa: ARG001
     metadata_config = {
-        "scicat_url": "https://example.com",
-        "scicat_token": "fake_token",
+        "archiver_url": "https://example.com",
     }
-    retriever = MetadataRetriever(metadata_config)
+    retriever = MetadataRetriever(metadata_config, token="test_token")
     metadata = retriever._get_metadata_per_run("11013410/43878")
     assert isinstance(metadata, dict)
     assert metadata == {"fake": "data"}
+    os.remove(ENV_PATH)


 def test_create_dataset_url_by_PID() -> None:
     metadata_config = {
-        "scicat_url": "https://example.com",
-        "scicat_token": "fake_token",
+        "archiver_url": "https://example.com",
     }
-    retriever = MetadataRetriever(metadata_config)
-    # Assuming the dataset follows the new format
+    retriever = MetadataRetriever(metadata_config, token="test_token")
     pid = "11013410/43878"
     url = retriever._create_new_dataset_url(pid)
     expected_url = "https://example.com/Datasets/11013410%2F43878"
     assert isinstance(url, str)
     assert url == expected_url
+    os.remove(ENV_PATH)
diff --git a/tests/test_config.py b/tests/test_config.py
index 7f8a43b8..75d780c6 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -13,7 +13,10 @@
 from sed.core.config import complete_dictionary
 from sed.core.config import load_config
 from sed.core.config import parse_config
+from sed.core.config import read_env_var
 from sed.core.config import save_config
+from sed.core.config import save_env_var
+from sed.core.config import USER_CONFIG_PATH

 test_dir = os.path.dirname(__file__)
 test_config_dir = Path(f"{test_dir}/data/loader/")
@@ -231,3 +234,130 @@ def test_invalid_config_wrong_values():
             verify_config=True,
         )
     assert "Invalid value 9999 for gid. Group not found." in str(e.value)
+
+
+def test_env_var_read_write(tmp_path, monkeypatch):
+    """Test reading and writing environment variables."""
+    # Mock USER_CONFIG_PATH to use a temporary directory
+    monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path)
+
+    # Test writing a new variable
+    save_env_var("TEST_VAR", "test_value")
+    assert read_env_var("TEST_VAR") == "test_value"
+
+    # Test writing multiple variables
+    save_env_var("TEST_VAR2", "test_value2")
+    assert read_env_var("TEST_VAR") == "test_value"
+    assert read_env_var("TEST_VAR2") == "test_value2"
+
+    # Test overwriting an existing variable
+    save_env_var("TEST_VAR", "new_value")
+    assert read_env_var("TEST_VAR") == "new_value"
+    assert read_env_var("TEST_VAR2") == "test_value2"  # Other variables unchanged
+
+    # Test reading non-existent variable
+    assert read_env_var("NON_EXISTENT_VAR") is None
+
+
+def test_env_var_read_no_file(tmp_path, monkeypatch):
+    """Test reading environment variables when .env file doesn't exist."""
+    # Mock USER_CONFIG_PATH to use an empty temporary directory
+    monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path)
+
+    # Test reading from non-existent file
+    assert read_env_var("TEST_VAR") is None
+
+
+def test_env_var_special_characters(tmp_path, monkeypatch):
+    """Test reading and writing environment variables with special characters."""
+    # Mock USER_CONFIG_PATH so the test does not touch the real user config
+    monkeypatch.setattr("sed.core.config.USER_CONFIG_PATH", tmp_path)
+    test_cases = {
+        "TEST_URL": "http://example.com/path?query=value",
+        "TEST_PATH": "/path/to/something/with/spaces and special=chars",
+        "TEST_QUOTES": "value with 'single' and \"double\" quotes",
+    }
+
+    for var_name, value in test_cases.items():
+        save_env_var(var_name, value)
+        assert read_env_var(var_name) == value
+
+
+@pytest.fixture
+def cleanup_env_files():
+    """Cleanup any .env files before and after tests"""
+    # Clean up any existing .env files
+    for path in [Path(".env"), USER_CONFIG_PATH / ".env"]:
+        if path.exists():
+            path.unlink()
+
+    yield
+
+    # Clean up after tests
+    for path in [Path(".env"), USER_CONFIG_PATH / ".env"]:
+        if path.exists():
+            path.unlink()
+
+
+def test_env_var_precedence(cleanup_env_files):  # noqa: ARG001
+    """Test that environment variables are read in correct order of precedence"""
+    # Set up test values in different locations
+    os.environ["TEST_VAR"] = "os_value"
+
+    with open(".env", "w") as f:
+        f.write("TEST_VAR=local_value\n")
+
+    save_env_var("TEST_VAR", "user_value")  # Saves to USER_CONFIG_PATH
+
+    # Should get OS value first
+    assert read_env_var("TEST_VAR") == "os_value"
+
+    # Remove from OS env and should get local value
+    del os.environ["TEST_VAR"]
+    assert read_env_var("TEST_VAR") == "local_value"
+
+    # Remove local .env and should get user config value
+    Path(".env").unlink()
+    assert read_env_var("TEST_VAR") == "user_value"
+
+    # Remove user config and should get None
+    (USER_CONFIG_PATH / ".env").unlink()
+    assert read_env_var("TEST_VAR") is None
+
+
+def test_env_var_save_and_load(cleanup_env_files):  # noqa: ARG001
+    """Test saving and loading environment variables"""
+    # Save a variable
+    save_env_var("TEST_VAR", "test_value")
+
+    # Should be able to read it back
+    assert read_env_var("TEST_VAR") == "test_value"
+
+    # Save another variable - should preserve existing ones
+    save_env_var("OTHER_VAR", "other_value")
+    assert read_env_var("TEST_VAR") == "test_value"
+    assert read_env_var("OTHER_VAR") == "other_value"
+
+
+def test_env_var_not_found(cleanup_env_files):  # noqa: ARG001
+    """Test behavior when environment variable is not found"""
+    assert read_env_var("NONEXISTENT_VAR") is None
+
+
+def test_env_file_format(cleanup_env_files):  # noqa: ARG001
+    """Test that .env file parsing handles different formats correctly"""
+    with open(".env", "w") as f:
+        f.write(
+            """
+            TEST_VAR=value1
+            SPACES_VAR = value2
+            EMPTY_VAR=
+            #COMMENT=value3
+            INVALID_LINE
+            """,
+        )
+
+    assert read_env_var("TEST_VAR") == "value1"
+    assert read_env_var("SPACES_VAR") == "value2"
+    assert read_env_var("EMPTY_VAR") == ""
+    assert read_env_var("COMMENT") is None
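Note (reviewer sketch): for reference, a .env file in the shape the parser accepts — placeholder values; keys and values are stripped of surrounding whitespace, and with the comment-skipping fix in `_parse_env_file`, lines starting with `#` or lacking a `=` are ignored:

    SCICAT_TOKEN=abc123
    SPACES_VAR = value2
    EMPTY_VAR=
    # a comment line, ignored
    NOT_AN_ASSIGNMENT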