Merge pull request #36 from CMIP-REF/34-handle-fractional
lewisjared authored Dec 10, 2024
2 parents f65b0de + b36e493 commit 80ce4a3
Showing 7 changed files with 134 additions and 28 deletions.
3 changes: 3 additions & 0 deletions changelog/36.feature.md
@@ -0,0 +1,3 @@
Added options to skip any datasets that fail validation and to specify the number of cores to
use when ingesting datasets.
These behaviours can be enabled via the `--skip-invalid` and `--n-jobs` options, respectively.
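For context, a minimal sketch of the programmatic path these options map onto via the adapter API changed in this PR (the `"cmip6"` source-type string and the data path are assumptions for illustration):

```python
from pathlib import Path

from ref.datasets import get_dataset_adapter

# n_jobs is forwarded to the adapter's parallel file scan;
# skip_invalid drops (but logs) datasets whose metadata fails validation.
adapter = get_dataset_adapter("cmip6", n_jobs=4)
catalog = adapter.find_local_datasets(Path("~/CMIP6").expanduser())
catalog = adapter.validate_data_catalog(catalog, skip_invalid=True)
```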
22 changes: 16 additions & 6 deletions packages/ref/src/ref/cli/datasets.py
@@ -104,12 +104,16 @@ def list_columns(


@app.command()
def ingest(
def ingest( # noqa: PLR0913
ctx: typer.Context,
file_or_directory: Path,
source_type: SourceDatasetType = typer.Option(help="Type of source dataset"),
solve: bool = typer.Option(False, help="Run metrics after ingestion"),
dry_run: bool = typer.Option(False, help="Do not execute any metrics"),
source_type: Annotated[SourceDatasetType, typer.Option(help="Type of source dataset")],
solve: Annotated[bool, typer.Option(help="Solve for new metric executions after ingestion")] = False,
dry_run: Annotated[bool, typer.Option(help="Do not ingest datasets into the database")] = False,
n_jobs: Annotated[int | None, typer.Option(help="Number of jobs to run in parallel")] = None,
skip_invalid: Annotated[
bool, typer.Option(help="Ignore (but log) any datasets that don't pass validation")
] = False,
) -> None:
"""
Ingest a dataset
@@ -119,17 +123,23 @@ def ingest(
config = ctx.obj.config
db = ctx.obj.database

file_or_directory = Path(file_or_directory).expanduser()
logger.info(f"ingesting {file_or_directory}")

adapter = get_dataset_adapter(source_type.value)
kwargs = {}

if n_jobs is not None:
kwargs["n_jobs"] = n_jobs

adapter = get_dataset_adapter(source_type.value, **kwargs)

# Create a data catalog from the specified file or directory
if not file_or_directory.exists():
logger.error(f"File or directory {file_or_directory} does not exist")
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file_or_directory)

data_catalog = adapter.find_local_datasets(file_or_directory)
adapter.validate_data_catalog(data_catalog)
data_catalog = adapter.validate_data_catalog(data_catalog, skip_invalid=skip_invalid)

logger.info(
f"Found {len(data_catalog)} files for {len(data_catalog[adapter.slug_column].unique())} datasets"
6 changes: 3 additions & 3 deletions packages/ref/src/ref/datasets/__init__.py
@@ -2,15 +2,15 @@
Dataset handling utilities
"""

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from ref_core.datasets import SourceDatasetType

if TYPE_CHECKING:
from ref.datasets.base import DatasetAdapter


def get_dataset_adapter(source_type: str) -> "DatasetAdapter":
def get_dataset_adapter(source_type: str, **kwargs: Any) -> "DatasetAdapter":
"""
Get the appropriate adapter for the specified source type
@@ -27,6 +27,6 @@ def get_dataset_adapter(source_type: str) -> "DatasetAdapter":
if source_type.lower() == SourceDatasetType.CMIP6.value:
from ref.datasets.cmip6 import CMIP6DatasetAdapter

return CMIP6DatasetAdapter()
return CMIP6DatasetAdapter(**kwargs)
else:
raise ValueError(f"Unknown source type: {source_type}")
50 changes: 42 additions & 8 deletions packages/ref/src/ref/datasets/base.py
@@ -2,12 +2,36 @@
from typing import Protocol

import pandas as pd
from loguru import logger

from ref.config import Config
from ref.database import Database
from ref.models.dataset import Dataset


def _log_duplicate_metadata(
data_catalog: pd.DataFrame, unique_metadata: pd.DataFrame, slug_column: str
) -> None:
# Drop out the rows where the values are the same
invalid_datasets = unique_metadata[unique_metadata.gt(1).any(axis=1)]
# Drop out the columns where the values are the same
invalid_datasets = invalid_datasets[invalid_datasets.columns[invalid_datasets.gt(1).any(axis=0)]]

for instance_id in invalid_datasets.index:
# Get the columns where the values are different
invalid_dataset_nunique = invalid_datasets.loc[instance_id]
invalid_dataset_columns = invalid_dataset_nunique[invalid_dataset_nunique.gt(1)].index.tolist()

# Include time_range in the list of invalid columns to make debugging easier
invalid_dataset_columns.append("time_range")

data_catalog_subset = data_catalog[data_catalog[slug_column] == instance_id]

logger.error(
f"Dataset {instance_id} has varying metadata:\n{data_catalog_subset[invalid_dataset_columns]}"
)


class DatasetAdapter(Protocol):
"""
An adapter to provide a common interface for different dataset types
@@ -43,14 +67,21 @@ def register_dataset(
"""
...

def validate_data_catalog(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
def validate_data_catalog(self, data_catalog: pd.DataFrame, skip_invalid: bool = False) -> pd.DataFrame:
"""
Validate a data catalog
Parameters
----------
data_catalog
Data catalog to validate
skip_invalid
If True, ignore datasets with invalid metadata and remove them from the resulting data catalog.
Raises
------
ValueError
If `skip_invalid` is False (default) and the data catalog contains validation errors.
Returns
-------
@@ -70,13 +101,16 @@ def validate_data_catalog(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
data_catalog[list(self.dataset_specific_metadata)].groupby(self.slug_column).nunique()
)
if unique_metadata.gt(1).any(axis=1).any():
# Drop out the rows where the values are the same
invalid_datasets = unique_metadata[unique_metadata.gt(1).any(axis=1)]
# Drop out the columns where the values are the same
invalid_datasets = invalid_datasets[invalid_datasets.gt(1)].dropna(axis=1)
raise ValueError(
f"Dataset specific metadata varies by dataset.\nUnique values: {invalid_datasets}"
)
_log_duplicate_metadata(data_catalog, unique_metadata, self.slug_column)

if skip_invalid:
data_catalog = data_catalog[
~data_catalog[self.slug_column].isin(
unique_metadata[unique_metadata.gt(1).any(axis=1)].index
)
]
else:
raise ValueError("Dataset specific metadata varies by dataset")

return data_catalog

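A toy illustration of the variance check in `validate_data_catalog` (column names invented for the example): `groupby(...).nunique()` flags any dataset whose supposedly dataset-level metadata takes more than one value, and `skip_invalid=True` filters those datasets out instead of raising.

```python
import pandas as pd

catalog = pd.DataFrame(
    {
        "instance_id": ["ds1", "ds1", "ds2", "ds2"],
        "grid_label": ["gn", "gr", "gn", "gn"],  # varies within ds1 -> invalid
        "path": ["a.nc", "b.nc", "c.nc", "d.nc"],
    }
)

# A count > 1 for any dataset-level column marks that dataset as inconsistent
unique_metadata = catalog[["instance_id", "grid_label"]].groupby("instance_id").nunique()
invalid = unique_metadata[unique_metadata.gt(1).any(axis=1)].index  # Index(['ds1'])

# With skip_invalid=True the offending dataset is dropped rather than raising ValueError
filtered = catalog[~catalog["instance_id"].isin(invalid)]
```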
46 changes: 43 additions & 3 deletions packages/ref/src/ref/datasets/cmip6.py
@@ -22,13 +22,47 @@ def _parse_datetime(dt_str: pd.Series[str]) -> pd.Series[datetime | Any]:
"""
Pandas tries to coerce everything to its own datetime format, which is not what we want here.
"""

def _inner(date_string: str | None) -> datetime | None:
if not date_string:
return None

# Try to parse the date string first without and then with fractional seconds
try:
dt = datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
except ValueError:
dt = datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")

return dt

return pd.Series(
[datetime.strptime(dt, "%Y-%m-%d %H:%M:%S") if dt else None for dt in dt_str],
[_inner(dt) for dt in dt_str],
index=dt_str.index,
dtype="object",
)


def _apply_fixes(data_catalog: pd.DataFrame) -> pd.DataFrame:
def _fix_parent_variant_label(group: pd.DataFrame) -> pd.DataFrame:
if group["parent_variant_label"].nunique() == 1:
return group
group["parent_variant_label"] = group["variant_label"].iloc[0]

return group

data_catalog = data_catalog.groupby("instance_id").apply(_fix_parent_variant_label).reset_index(drop=True)

# EC-Earth3 uses "D" as a suffix for the branch_time_in_child and branch_time_in_parent columns
data_catalog["branch_time_in_child"] = pd.to_numeric(
data_catalog["branch_time_in_child"].astype(str).str.replace("D", ""), errors="raise"
)
data_catalog["branch_time_in_parent"] = pd.to_numeric(
data_catalog["branch_time_in_parent"].astype(str).str.replace("D", ""), errors="raise"
)

return data_catalog


class CMIP6DatasetAdapter(DatasetAdapter):
"""
Adapter for CMIP6 datasets
@@ -75,6 +109,9 @@ class CMIP6DatasetAdapter(DatasetAdapter):

file_specific_metadata = ("start_time", "end_time", "path")

def __init__(self, n_jobs: int = 1):
self.n_jobs = n_jobs

def pretty_subset(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
"""
Get a subset of the data_catalog to pretty print
@@ -127,8 +164,7 @@ def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
paths=[str(file_or_directory)],
depth=10,
include_patterns=["*.nc"],
# TODO: This is hardcoded to 1 because of >1 fails during unittests
joblib_parallel_kwargs={"n_jobs": 1},
joblib_parallel_kwargs={"n_jobs": self.n_jobs},
).build(parsing_func=ecgtools.parsers.parse_cmip6)

datasets = builder.df
@@ -153,6 +189,10 @@ def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
lambda row: "CMIP6." + ".".join([row[item] for item in drs_items]), axis=1
)

# Temporary fix for some datasets
# TODO: Replace with a standalone package that contains metadata fixes for CMIP6 datasets
datasets = _apply_fixes(datasets)

return datasets

def register_dataset(
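To make the EC-Earth3 branch-time fix in `_apply_fixes` concrete, a small sketch of the `"D"`-suffix cleanup on invented values:

```python
import pandas as pd

# Hypothetical EC-Earth3-style branch_time values; some carry a trailing "D"
branch_time = pd.Series(["60265.0D", "0D", "21915.0"])
cleaned = pd.to_numeric(branch_time.astype(str).str.replace("D", ""), errors="raise")
# -> 60265.0, 0.0, 21915.0 as floats
```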
14 changes: 13 additions & 1 deletion packages/ref/tests/unit/datasets/test_cmip6.py
@@ -1,7 +1,9 @@
import datetime

import pandas as pd
import pytest

from ref.datasets.cmip6 import CMIP6DatasetAdapter
from ref.datasets.cmip6 import CMIP6DatasetAdapter, _parse_datetime


@pytest.fixture
@@ -15,6 +17,16 @@ def check(df: pd.DataFrame, basename: str):
return check


def test_parse_datetime():
pd.testing.assert_series_equal(
_parse_datetime(pd.Series(["2021-01-01 00:00:00", "1850-01-17 00:29:59.999993", None])),
pd.Series(
[datetime.datetime(2021, 1, 1, 0, 0), datetime.datetime(1850, 1, 17, 0, 29, 59, 999993), None],
dtype="object",
),
)


class TestCMIP6Adapter:
def test_catalog_empty(self, db):
adapter = CMIP6DatasetAdapter()
21 changes: 14 additions & 7 deletions packages/ref/tests/unit/datasets/test_datasets.py
@@ -11,7 +11,7 @@
class MockDatasetAdapter(DatasetAdapter):
dataset_model: pd.DataFrame
slug_column: str = "dataset_slug"
dataset_specific_metadata: tuple[str, ...] = ("metadata1", "metadata2")
dataset_specific_metadata: tuple[str, ...] = ("metadata1", "metadata2", "dataset_slug")
file_specific_metadata: tuple[str, ...] = ("file_name", "file_size")

def pretty_subset(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
@@ -24,10 +24,11 @@ def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
"dataset_slug": [f"{file_or_directory.stem}_001", f"{file_or_directory.stem}_001"],
"metadata1": ["value1", "value1"],
"metadata2": ["value2", "value2"],
"time_range": ["2020-01-01", "2020-01-01"],
"file_name": [file_or_directory.name, file_or_directory.name + "_2"],
"file_size": [100, 100],
}
return pd.DataFrame(data).set_index(self.slug_column)
return pd.DataFrame(data)

def register_dataset(self, config, db, data_catalog_dataset: pd.DataFrame) -> pd.DataFrame | None:
# Returning the input as a stand-in "registered" dataset
@@ -60,20 +61,26 @@ def test_validate_data_catalog_missing_columns():
adapter.validate_data_catalog(data_catalog.drop(columns=["file_name"]))


def test_validate_data_catalog_metadata_variance():
def test_validate_data_catalog_metadata_variance(caplog):
adapter = MockDatasetAdapter()
data_catalog = adapter.find_local_datasets(Path("path/to/dataset"))
# file_name differs between files within the same dataset
adapter.dataset_specific_metadata = ("metadata1", "metadata2", "file_name")
adapter.dataset_specific_metadata = (*adapter.dataset_specific_metadata, "file_name")

exp_df = pd.DataFrame(columns=["file_name"], index=["dataset_001"], data=[2])
exp_df.index.name = "dataset_slug"
exp_message = "Dataset dataset_001 has varying metadata:\n file_name time_range\n0 dataset 2020-01-01\n1 dataset_2 2020-01-01" # noqa: E501

with pytest.raises(
ValueError,
match=f"Dataset specific metadata varies by dataset.\nUnique values: {exp_df}",
match="Dataset specific metadata varies by dataset",
):
adapter.validate_data_catalog(data_catalog)
assert len(caplog.records) == 1
assert caplog.records[0].message == exp_message

caplog.clear()
assert len(adapter.validate_data_catalog(data_catalog, skip_invalid=True)) == 0
assert len(caplog.records) == 1
assert caplog.records[0].message == exp_message


@pytest.mark.parametrize(
