test: Fix up tests
lewisjared committed Nov 20, 2024
1 parent f9921fb commit 79ca672
Showing 9 changed files with 144 additions and 212 deletions.
89 changes: 0 additions & 89 deletions packages/ref/alembic/versions/0e4a1c4da55d_dataset_rework.py

This file was deleted.

6 changes: 3 additions & 3 deletions packages/ref/src/ref/cli/ingest.py
@@ -15,14 +15,14 @@
 from ref.cli.solve import solve as solve_cli
 from ref.config import Config
 from ref.database import Database
-from ref.datasets import get_dataset_adapter, validate_data_catalog
+from ref.datasets import get_dataset_adapter
 from ref.models.dataset import Dataset

 app = typer.Typer()
 console = Console()


-def validate_prefix(config: Config, raw_path: str) -> Path:
+def validate_path(config: Config, raw_path: str) -> Path:
     """
     Validate the prefix of a dataset against the data directory
     """
@@ -88,7 +88,7 @@ def ingest(
         raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file_or_directory)

     data_catalog = adapter.find_datasets(file_or_directory)
-    validate_data_catalog(adapter, data_catalog)
+    adapter.validate_data_catalog(data_catalog)

     logger.info(f"Found {len(data_catalog)} files for {len(data_catalog.index.unique())} datasets")
     pretty_print_df(adapter.pretty_subset(data_catalog))
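For orientation, here is a minimal sketch of how the reworked ingest call site fits together after this commit; the source type string and the data path are invented placeholder values, and only functions visible in the diff are used.

from ref.datasets import get_dataset_adapter

# "cmip6" and the path below are placeholder values for illustration only
adapter = get_dataset_adapter("cmip6")
data_catalog = adapter.find_datasets("/data/CMIP6")

# Validation is now a method on the adapter rather than the removed
# module-level validate_data_catalog() helper
adapter.validate_data_catalog(data_catalog)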
45 changes: 0 additions & 45 deletions packages/ref/src/ref/datasets/__init__.py
@@ -4,57 +4,12 @@

 from typing import TYPE_CHECKING

-import pandas as pd
 from ref_core.datasets import SourceDatasetType

 if TYPE_CHECKING:
     from ref.datasets.base import DatasetAdapter


-def validate_data_catalog(adapter: "DatasetAdapter", data_catalog: pd.DataFrame) -> pd.DataFrame:
-    """
-    Validate the data catalog against the adapter
-
-    Parameters
-    ----------
-    adapter
-        Dataset adapter
-    data_catalog
-        Data catalog to validate
-
-    Returns
-    -------
-    :
-        Validated data catalog
-    """
-    # Check if the data catalog contains the required columns
-    missing_columns = set(adapter.dataset_specific_metadata + adapter.file_specific_metadata) - set(
-        data_catalog.columns
-    )
-    if missing_columns:
-        raise ValueError(f"Data catalog is missing required columns: {missing_columns}")
-
-    # Verify that the dataset specific columns don't vary by dataset
-    unique_metadata = (
-        data_catalog[list(adapter.dataset_specific_metadata)].groupby(adapter.slug_column).nunique()
-    )
-
-    # Verify that the dataset specific columns don't vary by dataset by counting the unique values
-    # for each dataset and checking if there are any that have more than one unique value.
-
-    unique_metadata = (
-        data_catalog[list(adapter.dataset_specific_metadata)].groupby(adapter.slug_column).nunique()
-    )
-    if unique_metadata.gt(1).any(axis=1).any():
-        # Drop out the rows where the values are the same
-        invalid_datasets = unique_metadata[unique_metadata.gt(1).any(axis=1)]
-        # Drop out the columns where the values are the same
-        invalid_datasets = invalid_datasets[invalid_datasets.gt(1)].dropna(axis=1)
-        raise ValueError(f"Dataset specific metadata varies by dataset.\nUnique values: {invalid_datasets}")
-
-    return data_catalog
-
-
 def get_dataset_adapter(source_type: str) -> "DatasetAdapter":
     """
     Get the appropriate adapter for the specified source type
40 changes: 39 additions & 1 deletion packages/ref/src/ref/datasets/base.py
@@ -17,7 +17,7 @@ class DatasetAdapter(Protocol):

     slug_column: str
     dataset_specific_metadata: tuple[str, ...]
-    file_specific_metadata: tuple[str, ...]
+    file_specific_metadata: tuple[str, ...] = ()

     def pretty_subset(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
         """
@@ -41,3 +41,41 @@ def register_dataset(
         Register a dataset in the database using the data catalog
         """
         ...
+
+    def validate_data_catalog(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
+        """
+        Validate a data catalog
+
+        Parameters
+        ----------
+        data_catalog
+            Data catalog to validate
+
+        Returns
+        -------
+        :
+            Validated data catalog
+        """
+        # Check if the data catalog contains the required columns
+        missing_columns = set(self.dataset_specific_metadata + self.file_specific_metadata) - set(
+            data_catalog.columns
+        )
+        if missing_columns:
+            raise ValueError(f"Data catalog is missing required columns: {missing_columns}")
+
+        # Verify that the dataset specific columns don't vary by dataset by counting the unique values
+        # for each dataset and checking if there are any that have more than one unique value.
+
+        unique_metadata = (
+            data_catalog[list(self.dataset_specific_metadata)].groupby(self.slug_column).nunique()
+        )
+        if unique_metadata.gt(1).any(axis=1).any():
+            # Drop out the rows where the values are the same
+            invalid_datasets = unique_metadata[unique_metadata.gt(1).any(axis=1)]
+            # Drop out the columns where the values are the same
+            invalid_datasets = invalid_datasets[invalid_datasets.gt(1)].dropna(axis=1)
+            raise ValueError(
+                f"Dataset specific metadata varies by dataset.\nUnique values: {invalid_datasets}"
+            )
+
+        return data_catalog
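To make the check concrete, here is a small standalone sketch of the same groupby/nunique logic on a toy catalog; the column names and values are invented for illustration and are not the real CMIP6 schema.

import pandas as pd

# Two files that claim to belong to the same dataset slug but disagree on a
# dataset-specific column ("grid"); this is what validate_data_catalog rejects.
data_catalog = pd.DataFrame(
    {
        "instance_id": ["CMIP6.x.y.z", "CMIP6.x.y.z"],  # slug column (illustrative)
        "grid": ["gn", "gr"],  # dataset-specific metadata (illustrative)
        "path": ["file_1.nc", "file_2.nc"],  # file-specific metadata
    }
)

dataset_specific_metadata = ("instance_id", "grid")
slug_column = "instance_id"

# Count unique values of each dataset-specific column within each slug
unique_metadata = data_catalog[list(dataset_specific_metadata)].groupby(slug_column).nunique()

# Any column with more than one unique value per slug means inconsistent metadata
print(unique_metadata.gt(1).any(axis=1).any())  # True, so the adapter would raise ValueError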
25 changes: 16 additions & 9 deletions packages/ref/src/ref/datasets/cmip6.py
@@ -10,7 +10,7 @@
 from loguru import logger
 from ref_core.exceptions import RefException

-from ref.cli.ingest import validate_prefix
+from ref.cli.ingest import validate_path
 from ref.config import Config
 from ref.database import Database
 from ref.datasets.base import DatasetAdapter
@@ -71,7 +71,7 @@ class CMIP6DatasetAdapter(DatasetAdapter):
         slug_column,
     )

-    file_specific_metadata = ("start_time", "end_time", "time_range", "path")
+    file_specific_metadata = ("start_time", "end_time", "path")

     def pretty_subset(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
         """
@@ -171,12 +171,15 @@ def register_dataset(
         :
             Registered dataset if successful, else None
         """
+        self.validate_data_catalog(data_catalog_dataset)
+
         unique_slugs = data_catalog_dataset[self.slug_column].unique()
         if len(unique_slugs) != 1:
             raise RefException(f"Found multiple datasets in the same directory: {unique_slugs}")
         slug = unique_slugs[0]

-        dataset, created = db.get_or_create(CMIP6Dataset, slug=slug)
+        dataset_metadata = data_catalog_dataset[list(self.dataset_specific_metadata)].iloc[0].to_dict()
+        dataset, created = db.get_or_create(CMIP6Dataset, slug=slug, **dataset_metadata)

         if not created:
             logger.warning(f"{dataset} already exists in the database. Skipping")
@@ -185,11 +188,15 @@
         db.session.flush()

         for dataset_file in data_catalog_dataset.to_dict(orient="records"):
-            dataset_file["dataset_id"] = dataset.id
-
-            raw_path = dataset_file.pop("path")
-            prefix = validate_prefix(config, raw_path)
-
-            db.session.add(CMIP6File.build(prefix=str(prefix), **dataset_file))  # type: ignore
+            path = validate_path(config, dataset_file.pop("path"))
+
+            db.session.add(
+                CMIP6File(
+                    path=str(path),
+                    dataset_id=dataset.id,
+                    start_time=dataset_file.pop("start_time"),
+                    end_time=dataset_file.pop("end_time"),
+                )
+            )  # type: ignore

         return dataset
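As a rough illustration of the registration flow above, this is what the dataset-level metadata extraction and the per-file records look like for a toy single-dataset catalog. The column names and values are invented and are not the project's actual CMIP6 schema.

import pandas as pd

# Toy single-dataset catalog; real catalogs come from adapter.find_datasets()
data_catalog_dataset = pd.DataFrame(
    {
        "instance_id": ["CMIP6.x.y.z", "CMIP6.x.y.z"],  # slug column (illustrative)
        "grid": ["gn", "gn"],  # dataset-specific metadata (illustrative)
        "start_time": ["2000-01-01", "2001-01-01"],
        "end_time": ["2000-12-31", "2001-12-31"],
        "path": ["tas_2000.nc", "tas_2001.nc"],
    }
)
dataset_specific_metadata = ("instance_id", "grid")

# Dataset-level metadata, now passed to db.get_or_create(CMIP6Dataset, slug=slug, **dataset_metadata)
dataset_metadata = data_catalog_dataset[list(dataset_specific_metadata)].iloc[0].to_dict()
print(dataset_metadata)  # {'instance_id': 'CMIP6.x.y.z', 'grid': 'gn'}

# One CMIP6File row is created per catalog record, using only the file-specific fields
for dataset_file in data_catalog_dataset.to_dict(orient="records"):
    print(dataset_file["path"], dataset_file["start_time"], dataset_file["end_time"])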