diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3fe607f..622fa7f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -33,6 +33,7 @@ jobs: env: REF_DATA_ROOT: ${{ github.workspace }}/.esgpull/data REF_OUTPUT_ROOT: ${{ github.workspace }}/out + REF_DATABASE_URL: "sqlite:///${{ github.workspace }}/.ref/db/ref.db" steps: - name: Check out repository uses: actions/checkout@v4 @@ -48,7 +49,10 @@ jobs: echo "Rerun after cache generation in tests job" exit 1 - name: docs - run: uv run mkdocs build --strict + run: | + mkdir -p ${{ github.workspace }}/.ref/db + uv run ref ingest --source-type cmip6 .esgpull/data + uv run mkdocs build --strict tests: strategy: diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9a3cdef..32df0aa 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -19,7 +19,8 @@ build: - asdf global uv latest - uv sync --frozen # Fetch test data from ESGF (needed by notebooks) - - uv run esgpull self install $READTHEDOCS_REPOSITORY_PATH/.esgf + - uv run esgpull self install $READTHEDOCS_REPOSITORY_PATH/.esgpull - uv run python scripts/fetch_test_data.py + - uv run ref ingest $READTHEDOCS_REPOSITORY_PATH/.esgpull/data # Run a strict build - - NO_COLOR=1 REF_DATA_ROOT=$READTHEDOCS_REPOSITORY_PATH/.esgf/data uv run mkdocs build --strict --site-dir $READTHEDOCS_OUTPUT/html + - NO_COLOR=1 REF_DATA_ROOT=$READTHEDOCS_REPOSITORY_PATH/.esgpull/data uv run mkdocs build --strict --site-dir $READTHEDOCS_OUTPUT/html diff --git a/README.md b/README.md index d321076..71e39fa 100644 --- a/README.md +++ b/README.md @@ -91,14 +91,26 @@ dependency management. To get started, you will need to make sure that uv is installed ([instructions here](https://docs.astral.sh/uv/getting-started/installation/)). -For all of work, we use our `Makefile`. +We use our `Makefile` to provide an easy way to run common developer commands. 
You can read the instructions out and run the commands by hand if you wish, but we generally discourage this because it can be error prone. -In order to create your environment, run `make virtual-environment`. -If you wish to run the test suite, -some input data must be fetched from ESGF. -To do this, you will need to run `make fetch-data`. +The following steps are required to set up a development environment. +This will install the required dependencies and fetch some test data, +as well as set up the configuration for the REF. + +```bash +make virtual-environment +uv run esgpull self install $PWD/.esgpull +uv run ref config list > $PWD/.ref/ref.toml +export REF_CONFIGURATION=$PWD/.ref +make fetch-test-data +uv run ref ingest --source-type cmip6 $PWD/.esgpull/data +``` + +The local `ref.toml` configuration file will make it easier to play around with settings. +By default, the database will be stored in your home directory, +this can be modified by changing the `db.database_url` setting in the `ref.toml` file. The test suite can then be run using `make test`. This will run the test suites for each package and finally the integration test suite. diff --git a/changelog/15.feature.md b/changelog/15.feature.md new file mode 100644 index 0000000..81c6093 --- /dev/null +++ b/changelog/15.feature.md @@ -0,0 +1,7 @@ +Added a `DataRequirement` class to declare the requirements for a metric. 
+ +This provides the ability to: + +* filter a data catalog +* group datasets together to be used in a metric calculation +* declare constraints on the data that is required for a metric calculation diff --git a/docs/explanation.md b/docs/explanation.md index d23c92e..89ab0c2 100644 --- a/docs/explanation.md +++ b/docs/explanation.md @@ -23,7 +23,7 @@ An example implementation of a metric provider is provided in the `ref_metrics_e ### Metrics A metric represents a specific calculation or analysis that can be performed on a dataset -or set of datasets with the aim for benchmarking the performance of different models. +or group of datasets with the aim for benchmarking the performance of different models. These metrics often represent a specific aspects of the Earth system and are compared against observations of the same quantities. @@ -40,6 +40,24 @@ The Earth System Metrics and Diagnostics Standards provide a community standard for reporting outputs. This enables the ability to generate standardised outputs that can be distributed. +## Datasets + +The REF aims to support a variety of input datasets, +including CMIP6, CMIP7+, Obs4MIPs, and other observational datasets. + +When ingesting these datasets into the REF, +the metadata used to uniquely describe the datasets is stored in a database. +This metadata includes information such as: + +* the model that produced the dataset +* the experiment that was run +* the variable and units of the data +* the time period of the data + +The facets (or dimensions) of the metadata depend on the dataset type. +This metadata, in combination with the data requirements from a Metric, +are used to determine which new metric executions are required. + ## Execution Environments The REF aims to support the execution of metrics in a variety of environments. 
diff --git a/docs/how-to-guides/dataset-selection.py b/docs/how-to-guides/dataset-selection.py new file mode 100644 index 0000000..c0faaeb --- /dev/null +++ b/docs/how-to-guides/dataset-selection.py @@ -0,0 +1,216 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.4 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Dataset Selection +# A metric defines the requirements for the data it needs to run. +# The requirements are defined in the `data_requirements` attribute of the metric class. +# +# This notebook provides some examples querying and filtering datasets. + +# %% tags=["hide_code"] +import pandas as pd +from IPython.display import display +from ref_core.datasets import FacetFilter, SourceDatasetType +from ref_core.metrics import DataRequirement + +from ref.cli.config import load_config +from ref.database import Database + +# %% tags=["hide_code"] +config = load_config() +db = Database.from_config(config) + +# %% [markdown] +# +# Each source dataset type has a corresponding adapter that can be used to load the data catalog. +# +# The adapter provides a consistent interface for ingesting +# and querying datasets across different dataset types. +# It contains information such as the columns that are expected. +# %% +from ref.datasets import get_dataset_adapter + +adapter = get_dataset_adapter("cmip6") +adapter + +# %% [markdown] +# Below is an example of a data catalog of the CMIP6 datasets that have already been ingested. +# +# This data catalog contains information about the datasets that are available for use in the metrics. +# The data catalog is a pandas DataFrame that contains information about the datasets, +# such as the variable, source_id, and other metadata. +# +# Each row represents an individual NetCDF file, +# with the rows containing the metadata associated with that file. 
+# There are ~36 different **facets** of metadata for a CMIP6 data file.
+# Each of these facets can be used to refine the datasets that are needed for a given metric execution.
+
+# %%
+data_catalog = adapter.load_catalog(db)
+data_catalog
+
+
+# %% [markdown]
+# A dataset may consist of more than one file. In the case of CMIP6 datasets,
+# the modelling centers who produce the data may chunk a dataset along the time axis.
+# The size of these chunks is at the discretion of the modelling center.
+#
+# Datasets share a common set of metadata (see `adapter.dataset_specific_metadata`)
+# which do not vary for a given dataset,
+# while some facets vary by dataset (`adapter.file_specific_metadata`).
+#
+# Each data catalog will have a facet that can be used to split the catalog into unique datasets
+# (See `adapter.slug_column`).
+
+# %%
+adapter.slug_column
+
+# %%
+for unique_id, dataset_files in data_catalog.groupby(adapter.slug_column):
+    print(unique_id)
+    display(dataset_files)
+    print()
+
+# %% [markdown]
+# Each metric may be run multiple times with different groups of datasets.
+#
+# Determining which metric executions should be performed is a three-step process:
+# 1. Filter the data catalog based on the metric's requirements
+# 2. Group the filtered data catalog using unique metadata fields
+# 3. Apply constraints to the groups to ensure the correct data is available
+#
+# Each group that passes the constraints is a valid group for the metric to be executed.
+#
+# ## Examples
+# Below are some examples showing different data requests
+# and the corresponding groups of datasets that would be executed.
+
+# %%
+from ref.solver import extract_covered_datasets
+
+
+# %% tags=["hide_code"]
+def display_groups(frames):
+    for frame in frames:
+        display(frame[["instance_id", "source_id", "variable_id"]].drop_duplicates())
+
+
+# %% [markdown]
+# The simplest data request is a `FacetFilter`.
+# This filters the data catalog to include only the data required for a given metric run.
+
+# %%
+data_requirement = DataRequirement(
+    source_type=SourceDatasetType.CMIP6,
+    filters=(
+        # Only include "tas" and "rsut"
+        FacetFilter(facets={"variable_id": ("tas", "rsut")}),
+    ),
+    group_by=None,
+)
+
+groups = extract_covered_datasets(data_catalog, data_requirement)
+
+display_groups(groups)
+
+# %% [markdown]
+# The `group_by` field can be used to split the filtered data into multiple groups,
+# each of which has a unique set of values in the specified facets.
+# This results in multiple groups of datasets, each of which would correspond to a metric execution.
+
+# %%
+data_requirement = DataRequirement(
+    source_type=SourceDatasetType.CMIP6,
+    filters=(
+        # Only include "tas" and "rsut"
+        FacetFilter(facets={"variable_id": ("tas", "rsut")}),
+    ),
+    group_by=(
+        "variable_id",
+        "source_id",
+    ),
+)
+
+groups = extract_covered_datasets(data_catalog, data_requirement)
+
+display_groups(groups)
+
+
+# %% [markdown]
+# A data requirement can optionally specify `Constraint`s.
+# These constraints are applied to each group independently to modify a group or ignore it.
+# All constraints must hold for a group to be executed.
+#
+# One type of constraint is a `GroupOperation`.
+# This constraint allows for the manipulation of a given group.
+# This can be used to remove datasets or include additional datasets from the catalog,
+# which is useful to select common datasets for all groups (e.g. cell areas).
+#
+# Below an `IncludeTas` GroupOperation is included which adds the corresponding `tas` dataset to each group.
+ + +# %% +class IncludeTas: + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + # we will probably need to include some helpers + tas = data_catalog[ + (data_catalog["variable_id"] == "tas") + & data_catalog["source_id"].isin(group["source_id"].unique()) + ] + + return pd.concat([group, tas]) + + +data_requirement = DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"frequency": "mon"}),), + group_by=("variable_id", "source_id", "member_id"), + constraints=(IncludeTas(),), +) + +groups = extract_covered_datasets(data_catalog, data_requirement) + +display_groups(groups) + + +# %% [markdown] +# In addition to operations, a `GroupValidator` constraint can be specified. +# This validator is used to determine if a group is valid or not. +# If the validator does not return True, then the group is excluded from the list of groups for execution. + + +# %% +class AtLeast2: + def validate(self, group: pd.DataFrame) -> bool: + return len(group["instance_id"].drop_duplicates()) >= 2 + + +# %% [markdown] +# Here we add a simple validator which ensures that at least 2 unique datasets are present. +# This removes the tas-only group from above. + +# %% +data_requirement = DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"frequency": "mon"}),), + group_by=("variable_id", "source_id", "member_id"), + constraints=(IncludeTas(), AtLeast2()), +) + +groups = extract_covered_datasets(data_catalog, data_requirement) + +display_groups(groups) + +# %% diff --git a/docs/how-to-guides/running-metrics-locally.py b/docs/how-to-guides/running-metrics-locally.py index 6943a5f..ade8195 100644 --- a/docs/how-to-guides/running-metrics-locally.py +++ b/docs/how-to-guides/running-metrics-locally.py @@ -71,7 +71,7 @@ # This can be overridden by specifying the `REF_EXECUTOR` environment variable. 
 # %%
-result = run_metric("example", provider, configuration=configuration, trigger=trigger)
+result = run_metric("global_mean_timeseries", provider, configuration=configuration, trigger=trigger)
 result
 
 # %%
@@ -87,7 +87,7 @@
 # This will not perform and validation/verification of the output results.
 
 # %%
-metric = provider.get("example")
+metric = provider.get("global_mean_timeseries")
 direct_result = metric.run(configuration=configuration, trigger=trigger)
 
 assert direct_result.successful
diff --git a/packages/ref-core/src/ref_core/constraints.py b/packages/ref-core/src/ref_core/constraints.py
new file mode 100644
index 0000000..5b47217
--- /dev/null
+++ b/packages/ref-core/src/ref_core/constraints.py
@@ -0,0 +1,155 @@
+from typing import Protocol, runtime_checkable
+
+import pandas as pd
+from attrs import frozen
+from loguru import logger
+
+from ref_core.exceptions import ConstraintNotSatisfied
+
+
+@runtime_checkable
+class GroupValidator(Protocol):
+    """
+    A constraint that must be satisfied when executing a given metric run.
+
+    All constraints must be satisfied for a given group to be run.
+    """
+
+    def validate(self, group: pd.DataFrame) -> bool:
+        """
+        Validate if the constraint is satisfied by the dataset.
+
+        This is executed after the apply method to determine if the constraint is satisfied.
+        If the constraint is not satisfied, the group will not be executed.
+
+        Parameters
+        ----------
+        group
+            A group of datasets that is being validated.
+
+        Returns
+        -------
+        :
+            Whether the constraint is satisfied
+        """
+        ...
+
+
+@runtime_checkable
+class GroupOperation(Protocol):
+    """
+    An operation to perform on a group of datasets resulting in a new group of datasets.
+
+    !!! warning
+
+        Operations should not mutate the input group, but instead return a new group.
+        Mutating the input group may result in unexpected behaviour.
+ """ + + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + """ + Perform an operation on the group of datasets. + + A new group of datasets should be returned if modifications are required, + and the input group should not be modified. If no modifications are required, + return the input group unchanged. + If this operation fails, a ConstraintNotSatisfied exception should be raised. + + Parameters + ---------- + group + A group of datasets that is being validated. + data_catalog + The data catalog of datasets + + Raises + ------ + ConstraintNotSatisfied + The operation was not successful + + Returns + ------- + : + The updated group of datasets + """ + ... + + +GroupConstraint = GroupOperation | GroupValidator +""" +A constraint that must be satisfied when executing a given metric run. + +This is applied to a group of datasets representing the inputs to a potential metric execution. +The group must satisfy all constraints to be processed. + +This can include operations that are applied to a group of datasets which may modify the group, +but may also include validators that check if the group satisfies a certain condition. 
+""" + + +def apply_constraint( + dataframe: pd.DataFrame, constraint: GroupConstraint, data_catalog: pd.DataFrame +) -> pd.DataFrame | None: + """ + Apply a constraint to a group of datasets + + Parameters + ---------- + constraint + The constraint to apply + data_catalog + The data catalog of datasets + + Returns + ------- + : + The updated group of datasets or None if the constraint was not satisfied + """ + try: + updated_group = ( + constraint.apply(dataframe, data_catalog) if isinstance(constraint, GroupOperation) else dataframe + ) + + valid = constraint.validate(updated_group) if isinstance(constraint, GroupValidator) else True + if not valid: + logger.debug(f"Constraint {constraint} not satisfied for {dataframe}") + raise ConstraintNotSatisfied(f"Constraint {constraint} not satisfied for {dataframe}") + except ConstraintNotSatisfied: + logger.debug(f"Constraint {constraint} not satisfied for {dataframe}") + return None + + return updated_group + + +@frozen +class RequireFacets: + """ + A constraint that requires a dataset to have certain facets. 
+ """ + + dimension: str + required_facets: list[str] + + def validate(self, group: pd.DataFrame) -> bool: + """ + Check that the required facets are present in the group + """ + if self.dimension not in group: + logger.warning(f"Dimension {self.dimension} not present in group {group}") + return False + return all(value in group[self.dimension].values for value in self.required_facets) + + +@frozen +class SelectParentExperiment: + """ + Include a dataset's parent experiment in the selection + """ + + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + """ + Include a dataset's parent experiment in the selection + + Not yet implemented + """ + raise NotImplementedError("This is not implemented yet") # pragma: no cover diff --git a/packages/ref-core/src/ref_core/datasets.py b/packages/ref-core/src/ref_core/datasets.py index 57ee1d7..1105a87 100644 --- a/packages/ref-core/src/ref_core/datasets.py +++ b/packages/ref-core/src/ref_core/datasets.py @@ -1,5 +1,7 @@ import enum +from attrs import field, frozen + class SourceDatasetType(enum.Enum): """ @@ -8,3 +10,41 @@ class SourceDatasetType(enum.Enum): CMIP6 = "cmip6" CMIP7 = "cmip7" + + +def _clean_facets(raw_values: dict[str, str | tuple[str, ...] | list[str]]) -> dict[str, tuple[str, ...]]: + """ + Clean the value of a facet filter to a tuple of strings + """ + result = {} + + for key, value in raw_values.items(): + if isinstance(value, list): + result[key] = tuple(value) + elif isinstance(value, str): + result[key] = (value,) + elif isinstance(value, tuple): + result[key] = value + return result + + +@frozen +class FacetFilter: + """ + A filter to apply to a data catalog of datasets. + """ + + facets: dict[str, tuple[str, ...]] = field(converter=_clean_facets) + """ + Filters to apply to the data catalog. + + The keys are the metadata fields to filter on, and the values are the values to filter on. 
+ The result will only contain datasets where for all fields, + the value of the field is one of the given values. + """ + keep: bool = True + """ + Whether to keep or remove datasets that match the filter. + + If true (default), datasets that match the filter will be kept else they will be removed. + """ diff --git a/packages/ref-core/src/ref_core/exceptions.py b/packages/ref-core/src/ref_core/exceptions.py index 2238386..a3af6e1 100644 --- a/packages/ref-core/src/ref_core/exceptions.py +++ b/packages/ref-core/src/ref_core/exceptions.py @@ -28,3 +28,9 @@ def __init__(self, dataset_path: pathlib.Path, root_path: pathlib.Path) -> None: ) super().__init__(message) + + +class ConstraintNotSatisfied(RefException): + """Exception raised when a constraint is violated""" + + # TODO: implement when we have agreed on using constraints diff --git a/packages/ref-core/src/ref_core/metrics.py b/packages/ref-core/src/ref_core/metrics.py index 2d2bb02..5cc4c51 100644 --- a/packages/ref-core/src/ref_core/metrics.py +++ b/packages/ref-core/src/ref_core/metrics.py @@ -2,7 +2,11 @@ import pathlib from typing import Any, Protocol, runtime_checkable -from attrs import frozen +import pandas as pd +from attrs import field, frozen + +from ref_core.constraints import GroupConstraint +from ref_core.datasets import FacetFilter, SourceDatasetType @frozen @@ -87,6 +91,80 @@ class TriggerInfo: # dataset metadata +@frozen(hash=True) +class DataRequirement: + """ + Definition of the input datasets that a metric requires to run. + + This is used to create groups of datasets. + Each group will result in an execution of the metric + and defines the input data for that execution. + + The data catalog is filtered according to the `filters` field, + then grouped according to the `group_by` field, + and then each group is checked that it satisfies the `constraints`. + Each such group will be processed as a separate execution of the metric. 
+ """ + + source_type: SourceDatasetType + """ + Type of the source dataset (CMIP6, CMIP7 etc) + """ + + filters: tuple[FacetFilter, ...] + """ + Filters to apply to the data catalog of datasets. + + This is used to reduce the set of datasets to only those that are required by the metric. + The filters are applied iteratively to reduce the set of datasets. + """ + + group_by: tuple[str, ...] | None + """ + The fields to group the datasets by. + + This groupby operation is performed after the data catalog is filtered according to `filters`. + Each group will contain a unique combination of values from the metadata fields, + and will result in a separate execution of the metric. + If `group_by=None`, all datasets will be processed together as a single execution. + """ + + constraints: tuple[GroupConstraint, ...] = field(factory=tuple) + """ + Constraints that must be satisfied when executing a given metric run + + All of the constraints must be satisfied for a given group to be run. + Each filter is applied iterative to a set of datasets to reduce the set of datasets. + This is effectively an AND operation. + """ + + def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame: + """ + Apply filters to a DataFrame-based data catalog. + + Parameters + ---------- + data_catalog + DataFrame to filter. + Each column contains a facet + + Returns + ------- + : + Filtered data catalog + """ + for facet_filter in self.filters: + for facet, value in facet_filter.facets.items(): + clean_value = value if isinstance(value, tuple) else (value,) + + mask = data_catalog[facet].isin(clean_value) + if not facet_filter.keep: + mask = ~mask + + data_catalog = data_catalog[mask] + return data_catalog + + @runtime_checkable class Metric(Protocol): """ @@ -97,6 +175,14 @@ class Metric(Protocol): The configuration and output of the metric should follow the Earth System Metrics and Diagnostics Standards formats as much as possible. 
+    A metric can be executed multiple times,
+    each time targeting a different group of input data.
+    The groups are determined by grouping the data catalog according to the `group_by` field
+    in the `DataRequirement` object using one or more metadata fields.
+    Each group must conform with a set of constraints,
+    to ensure that the correct data is available to run the metric.
+    Each group will then be processed as a separate execution of the metric.
+
     See (ref_example.example.ExampleMetric)[] for an example implementation.
     """
 
@@ -108,18 +194,14 @@
     but multiple providers can implement the same metric.
     """
 
-    # input_variable: list[VariableDefinition]
+    data_requirements: tuple[DataRequirement, ...]
     """
-    TODO: implement VariableDefinition
-    Should be extend the configuration defined in EMDS
+    Description of the required datasets for the current metric
 
-    Variables that the metric requires to run
-    Any modifications to the input data will trigger a new metric calculation.
-    """
-    # observation_dataset: list[ObservationDatasetDefinition]
     """
-    TODO: implement ObservationDatasetDefinition
-    Should be extend the configuration defined in EMDS. To check with Bouwe.
+    This information is used to filter a data catalog of both CMIP and/or observation datasets
+    that are required by the metric.
+
+    Any modifications to the input data will trigger a new metric calculation.
""" def run(self, configuration: Configuration, trigger: TriggerInfo | None) -> MetricResult: diff --git a/packages/ref-core/tests/conftest.py b/packages/ref-core/tests/conftest.py index d4f3143..7b6164f 100644 --- a/packages/ref-core/tests/conftest.py +++ b/packages/ref-core/tests/conftest.py @@ -1,11 +1,15 @@ import pytest -from ref_core.metrics import Configuration, MetricResult, TriggerInfo +from ref_core.datasets import SourceDatasetType +from ref_core.metrics import Configuration, DataRequirement, MetricResult, TriggerInfo from ref_core.providers import MetricsProvider class MockMetric: name = "mock" + # This runs on every dataset + data_requirements = (DataRequirement(source_type=SourceDatasetType.CMIP6, filters=(), group_by=None),) + def run(self, configuration: Configuration, trigger: TriggerInfo) -> MetricResult: return MetricResult( output_bundle=configuration.output_directory / "output.json", @@ -16,6 +20,8 @@ def run(self, configuration: Configuration, trigger: TriggerInfo) -> MetricResul class FailedMetric: name = "failed" + data_requirements = (DataRequirement(source_type=SourceDatasetType.CMIP6, filters=(), group_by=None),) + def run(self, configuration: Configuration, trigger: TriggerInfo) -> MetricResult: return MetricResult( successful=False, diff --git a/packages/ref-core/tests/unit/test_constraints.py b/packages/ref-core/tests/unit/test_constraints.py new file mode 100644 index 0000000..e276396 --- /dev/null +++ b/packages/ref-core/tests/unit/test_constraints.py @@ -0,0 +1,140 @@ +import pandas as pd +import pytest +from ref_core.constraints import ( + GroupOperation, + GroupValidator, + RequireFacets, + SelectParentExperiment, + apply_constraint, +) +from ref_core.exceptions import ConstraintNotSatisfied + + +@pytest.fixture +def data_catalog(): + return pd.DataFrame( + { + "variable": ["tas", "pr", "rsut", "tas", "tas"], + "source_id": ["CESM2", "CESM2", "CESM2", "ACCESS", "CAS"], + } + ) + + +class TestRequireFacets: + validator = 
RequireFacets(dimension="variable_id", required_facets=["tas", "pr"]) + + def test_is_group_validator(self): + assert isinstance(self.validator, GroupValidator) + assert not isinstance(self.validator, GroupOperation) + + @pytest.mark.parametrize( + "data, expected", + [ + (pd.DataFrame({}), False), + (pd.DataFrame({"invalid": ["tas", "pr"]}), False), + (pd.DataFrame({"variable_id": ["tas", "pr"]}), True), + (pd.DataFrame({"variable_id": ["tas", "pr"], "extra": ["a", "b"]}), True), + (pd.DataFrame({"variable_id": ["tas"]}), False), + (pd.DataFrame({"variable_id": ["tas"], "extra": ["a"]}), False), + ], + ) + def test_validate(self, data, expected): + assert self.validator.validate(data) == expected + + +class TestSelectParentExperiment: + def test_is_group_validator(self): + validator = SelectParentExperiment() + + assert isinstance(validator, GroupOperation) + assert not isinstance(validator, GroupValidator) + + +def test_apply_constraint_operation(data_catalog): + # operation that appends the "rsut" variable to the group + class ExampleOperation(GroupOperation): + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + return pd.concat([group, data_catalog[data_catalog["variable"] == "rsut"]]) + + result = apply_constraint( + data_catalog[data_catalog["variable"] == "tas"], + ExampleOperation(), + data_catalog, + ) + + pd.testing.assert_frame_equal( + result, + pd.DataFrame( + { + "variable": ["tas", "tas", "tas", "rsut"], + "source_id": ["CESM2", "ACCESS", "CAS", "CESM2"], + }, + index=[0, 3, 4, 2], + ), + ) + + +def test_apply_constraint_operation_mutable(data_catalog): + class MutableOperation(GroupOperation): + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + group["variable"] = "new" + return group + + orig_data_catalog = data_catalog.copy() + result = apply_constraint( + data_catalog, + MutableOperation(), + None, + ) + + assert (result["variable"] == "new").all() + + # Mutating the group impacts 
the original data catalog + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(data_catalog, orig_data_catalog) + + +def test_apply_constraint_operation_raises(): + class RaisesOperation(GroupOperation): + def apply(self, group: pd.DataFrame, data_catalog: pd.DataFrame) -> pd.DataFrame: + raise ConstraintNotSatisfied("Test exception") + + assert ( + apply_constraint( + pd.DataFrame(), + RaisesOperation(), + pd.DataFrame(), + ) + is None + ) + + +def test_apply_constraint_empty(): + assert ( + apply_constraint( + pd.DataFrame(), + RequireFacets(dimension="variable_id", required_facets=["tas", "pr"]), + pd.DataFrame(), + ) + is None + ) + + +def test_apply_constraint_validate(data_catalog): + result = apply_constraint( + data_catalog, + RequireFacets(dimension="variable", required_facets=["tas", "pr"]), + pd.DataFrame(), + ) + pd.testing.assert_frame_equal(result, data_catalog) + + +def test_apply_constraint_validate_invalid(data_catalog): + assert ( + apply_constraint( + data_catalog, + RequireFacets(dimension="variable", required_facets=["missing", "pr"]), + pd.DataFrame(), + ) + is None + ) diff --git a/packages/ref-core/tests/unit/test_metrics.py b/packages/ref-core/tests/unit/test_metrics.py index d08b557..1ff00f4 100644 --- a/packages/ref-core/tests/unit/test_metrics.py +++ b/packages/ref-core/tests/unit/test_metrics.py @@ -1,4 +1,7 @@ -from ref_core.metrics import Configuration, MetricResult +import pandas as pd +import pytest +from ref_core.datasets import FacetFilter, SourceDatasetType +from ref_core.metrics import Configuration, DataRequirement, MetricResult class TestMetricResult: @@ -13,3 +16,107 @@ def test_build(self, tmp_path): assert f.read() == '{"data": "value"}' assert result.output_bundle.is_relative_to(tmp_path) + + +@pytest.fixture +def apply_data_catalog(): + return pd.DataFrame( + { + "variable": ["tas", "pr", "rsut", "tas", "tas"], + "source_id": ["CESM2", "CESM2", "CESM2", "ACCESS", "CAS"], + } + ) + + 
+@pytest.mark.parametrize( + "facet_filter, expected_data, expected_index", + [ + ( + {"variable": "tas"}, + { + "variable": ["tas", "tas", "tas"], + "source_id": [ + "CESM2", + "ACCESS", + "CAS", + ], + }, + [0, 3, 4], + ), + ( + {"variable": "tas", "source_id": ["CESM2", "ACCESS"]}, + { + "variable": ["tas", "tas"], + "source_id": [ + "CESM2", + "ACCESS", + ], + }, + [0, 3], + ), + ], +) +def test_apply_filters_single(apply_data_catalog, facet_filter, expected_data, expected_index): + requirement = DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facet_filter),), + group_by=None, + ) + + filtered = requirement.apply_filters(apply_data_catalog) + + pd.testing.assert_frame_equal( + filtered, + pd.DataFrame( + expected_data, + index=expected_index, + ), + ) + + +def test_apply_filters_multi(apply_data_catalog): + requirement = DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=( + FacetFilter({"variable": "tas"}), + FacetFilter({"source_id": "ACCESS"}, keep=False), + ), + group_by=None, + ) + + filtered = requirement.apply_filters(apply_data_catalog) + + pd.testing.assert_frame_equal( + filtered, + pd.DataFrame( + { + "variable": ["tas", "tas"], + "source_id": ["CESM2", "CAS"], + }, + index=[0, 4], + ), + ) + + +def test_apply_filters_dont_keep(apply_data_catalog): + requirement = DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter({"variable": "tas"}, keep=False),), + group_by=None, + ) + + filtered = requirement.apply_filters(apply_data_catalog) + + pd.testing.assert_frame_equal( + filtered, + pd.DataFrame( + { + "variable": ["pr", "rsut"], + "source_id": [ + "CESM2", + "CESM2", + ], + }, + index=[1, 2], + ), + ) diff --git a/packages/ref-metrics-example/src/ref_metrics_example/__init__.py b/packages/ref-metrics-example/src/ref_metrics_example/__init__.py index f645a72..d0ab9d5 100644 --- a/packages/ref-metrics-example/src/ref_metrics_example/__init__.py +++ 
b/packages/ref-metrics-example/src/ref_metrics_example/__init__.py @@ -6,11 +6,11 @@ from ref_core.providers import MetricsProvider -from ref_metrics_example.example import ExampleMetric +from ref_metrics_example.example import GlobalMeanTimeseries __version__ = importlib.metadata.version("ref_metrics_example") __core_version__ = importlib.metadata.version("ref_core") # Initialise the metrics manager and register the example metric provider = MetricsProvider("example", __version__) -provider.register(ExampleMetric()) +provider.register(GlobalMeanTimeseries()) diff --git a/packages/ref-metrics-example/src/ref_metrics_example/example.py b/packages/ref-metrics-example/src/ref_metrics_example/example.py index 0e378c8..0bc18e8 100644 --- a/packages/ref-metrics-example/src/ref_metrics_example/example.py +++ b/packages/ref-metrics-example/src/ref_metrics_example/example.py @@ -2,7 +2,8 @@ from typing import Any import xarray as xr -from ref_core.metrics import Configuration, MetricResult, TriggerInfo +from ref_core.datasets import FacetFilter, SourceDatasetType +from ref_core.metrics import Configuration, DataRequirement, Metric, MetricResult, TriggerInfo def calculate_annual_mean_timeseries(dataset: Path) -> xr.Dataset: @@ -74,12 +75,27 @@ def format_cmec_output_bundle(dataset: xr.Dataset) -> dict[str, Any]: return cmec_output -class ExampleMetric: +class GlobalMeanTimeseries(Metric): """ Calculate the annual mean global mean timeseries for a dataset """ - name = "example" + name = "global_mean_timeseries" + + data_requirements = ( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=( + FacetFilter(facets={"variable_id": ("tas", "rsut")}), + # Ignore some experiments because they are not relevant + FacetFilter(facets={"experiment_id": ("1pctCO2-*", "hist-*")}, keep=False), + ), + # Add cell areas to the groups + # constraints=(AddCellAreas(),), + # Run the metric on each unique combination of model, variable, experiment, and variant + 
group_by=("model_id", "variable_id", "experiment_id", "variant_label"), + ), + ) def run(self, configuration: Configuration, trigger: TriggerInfo | None) -> MetricResult: """ @@ -105,7 +121,7 @@ def run(self, configuration: Configuration, trigger: TriggerInfo | None) -> Metr successful=False, ) - # This is where one would hook into how ever they want to run + # This is where one would hook into however they want to run # their benchmarking packages. # cmec-driver, python calls, subprocess calls all would work annual_mean_global_mean_timeseries = calculate_annual_mean_timeseries(trigger.dataset) diff --git a/packages/ref-metrics-example/tests/unit/test_metrics.py b/packages/ref-metrics-example/tests/unit/test_metrics.py index d9f245f..aa1b300 100644 --- a/packages/ref-metrics-example/tests/unit/test_metrics.py +++ b/packages/ref-metrics-example/tests/unit/test_metrics.py @@ -2,7 +2,7 @@ import pytest from ref_core.metrics import Configuration, TriggerInfo -from ref_metrics_example.example import ExampleMetric, calculate_annual_mean_timeseries +from ref_metrics_example.example import GlobalMeanTimeseries, calculate_annual_mean_timeseries @pytest.fixture @@ -29,7 +29,7 @@ def test_annual_mean(esgf_data_dir, test_dataset): def test_example_metric(tmp_path, test_dataset): - metric = ExampleMetric() + metric = GlobalMeanTimeseries() configuration = Configuration( output_directory=tmp_path, @@ -44,7 +44,7 @@ def test_example_metric(tmp_path, test_dataset): def test_example_metric_no_trigger(tmp_path, test_dataset): - metric = ExampleMetric() + metric = GlobalMeanTimeseries() configuration = Configuration( output_directory=tmp_path, diff --git a/packages/ref/src/ref/cli/config.py b/packages/ref/src/ref/cli/config.py index 489f2ed..71e6386 100644 --- a/packages/ref/src/ref/cli/config.py +++ b/packages/ref/src/ref/cli/config.py @@ -12,7 +12,7 @@ app = typer.Typer(help=__doc__) -def load_config(configuration_directory: Path | None) -> Config: +def 
load_config(configuration_directory: Path | None = None) -> Config: """ Load the configuration from the specified directory diff --git a/packages/ref/src/ref/cli/ingest.py b/packages/ref/src/ref/cli/ingest.py index a3026da..4f1606a 100644 --- a/packages/ref/src/ref/cli/ingest.py +++ b/packages/ref/src/ref/cli/ingest.py @@ -76,7 +76,7 @@ def ingest( This will register a dataset in the database to be used for metrics calculations. """ config = load_config(configuration_directory) - db = Database(config.db.database_url) + db = Database.from_config(config) logger.info(f"ingesting {file_or_directory}") diff --git a/packages/ref/src/ref/cli/solve.py b/packages/ref/src/ref/cli/solve.py index 98dbbed..0c75cbb 100644 --- a/packages/ref/src/ref/cli/solve.py +++ b/packages/ref/src/ref/cli/solve.py @@ -21,6 +21,6 @@ def solve( since the last solve. """ config = load_config(configuration_directory) - db = Database(config.db.database_url) + db = Database.from_config(config) solve_metrics(db, dry_run=dry_run) diff --git a/packages/ref/src/ref/config.py b/packages/ref/src/ref/config.py index 4b8e685..e92760e 100644 --- a/packages/ref/src/ref/config.py +++ b/packages/ref/src/ref/config.py @@ -38,7 +38,9 @@ class Paths: log: Path = field(converter=Path) tmp: Path = field(converter=Path) - allow_out_of_tree_datasets: bool = field(default=False) + # TODO: this should probably default to False, + # but we don't have an easy way to update config + allow_out_of_tree_datasets: bool = field(default=True) @data.default def _data_factory(self) -> Path: diff --git a/packages/ref/src/ref/datasets/__init__.py b/packages/ref/src/ref/datasets/__init__.py index 4aaec6d..e90fd00 100644 --- a/packages/ref/src/ref/datasets/__init__.py +++ b/packages/ref/src/ref/datasets/__init__.py @@ -24,7 +24,7 @@ def get_dataset_adapter(source_type: str) -> "DatasetAdapter": : DatasetAdapter instance """ - if source_type == SourceDatasetType.CMIP6.value: + if source_type.lower() == SourceDatasetType.CMIP6.value: 
from ref.datasets.cmip6 import CMIP6DatasetAdapter return CMIP6DatasetAdapter() diff --git a/packages/ref/src/ref/solver.py b/packages/ref/src/ref/solver.py index 0393290..ccc3d1b 100644 --- a/packages/ref/src/ref/solver.py +++ b/packages/ref/src/ref/solver.py @@ -11,14 +11,54 @@ import pandas as pd from attrs import define from loguru import logger +from ref_core.constraints import apply_constraint from ref_core.datasets import SourceDatasetType -from ref_core.metrics import Metric +from ref_core.metrics import DataRequirement, Metric from ref_core.providers import MetricsProvider from ref.database import Database from ref.provider_registry import ProviderRegistry +def extract_covered_datasets(data_catalog: pd.DataFrame, requirement: DataRequirement) -> list[pd.DataFrame]: + """ + Determine the different metric executions that should be performed with the current data catalog + """ + subset = requirement.apply_filters(data_catalog) + + if len(subset) == 0: + logger.debug(f"No datasets found for requirement {requirement}") + return [] + + if requirement.group_by is None: + # Use a single group + groups = [(None, subset)] + else: + groups = subset.groupby(list(requirement.group_by)) # type: ignore + + results = [] + + for name, group in groups: + constrained_group = _process_group_constraints(data_catalog, group, requirement) + + if constrained_group is not None: + results.append(constrained_group) + + return results + + +def _process_group_constraints( + data_catalog: pd.DataFrame, group: pd.DataFrame, requirement: DataRequirement +) -> pd.DataFrame | None: + for constraint in requirement.constraints or []: + constrained_group = apply_constraint(group, constraint, data_catalog) + if constrained_group is None: + return None + + group = constrained_group + return group + + @define class MetricSolver: """ @@ -51,6 +91,8 @@ def _can_solve(self, metric: Metric) -> bool: This should probably be passed via DI """ + # TODO: Implement this method + # TODO: wrap the result 
in a class representing a metric run return True def _find_solvable(self) -> typing.Generator[tuple[MetricsProvider, Metric], None, None]: @@ -65,7 +107,7 @@ def _find_solvable(self) -> typing.Generator[tuple[MetricsProvider, Metric], Non for provider in self.provider_registry.providers: for metric in provider.metrics(): if self._can_solve(metric): - yield provider, metric + yield (provider, metric) def solve(self, dry_run: bool = False, max_iterations: int = 10) -> None: """ diff --git a/packages/ref/tests/unit/cli/test_root.py b/packages/ref/tests/unit/cli/test_root.py index 2e5b693..77ced33 100644 --- a/packages/ref/tests/unit/cli/test_root.py +++ b/packages/ref/tests/unit/cli/test_root.py @@ -23,7 +23,7 @@ def test_version(): def test_verbose(): - exp_log = "| DEBUG | ref.config:default:176 - Loading default configuration from" + exp_log = "| DEBUG | ref.config:default:178 - Loading default configuration from" result = runner.invoke( app, ["--verbose", "config", "list"], diff --git a/packages/ref/tests/unit/test_config.py b/packages/ref/tests/unit/test_config.py index 593cb38..a664d51 100644 --- a/packages/ref/tests/unit/test_config.py +++ b/packages/ref/tests/unit/test_config.py @@ -96,7 +96,7 @@ def test_defaults(self, monkeypatch): "data": "test/data", "log": "test/log", "tmp": "test/tmp", - "allow_out_of_tree_datasets": False, + "allow_out_of_tree_datasets": True, }, "db": {"database_url": "sqlite:///test/db/ref.db", "run_migrations": True}, } diff --git a/packages/ref/tests/unit/test_solver.py b/packages/ref/tests/unit/test_solver.py index 23a7ab1..249bf2b 100644 --- a/packages/ref/tests/unit/test_solver.py +++ b/packages/ref/tests/unit/test_solver.py @@ -1,7 +1,11 @@ +import pandas as pd import pytest +from ref_core.constraints import RequireFacets, SelectParentExperiment +from ref_core.datasets import SourceDatasetType +from ref_core.metrics import DataRequirement, FacetFilter from ref.provider_registry import ProviderRegistry -from ref.solver import 
MetricSolver +from ref.solver import MetricSolver, extract_covered_datasets @pytest.fixture @@ -24,3 +28,146 @@ def test_solver_solve_empty(self, solver): solver.solve() # TODO: Check that nothing was solved + + +@pytest.mark.parametrize( + "requirement,data_catalog,expected", + [ + pytest.param( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"variable_id": "missing"}),), + group_by=("variable_id", "experiment_id"), + ), + pd.DataFrame( + { + "variable_id": ["tas", "tas", "pr"], + "experiment_id": ["ssp119", "ssp126", "ssp119"], + "variant_label": ["r1i1p1f1", "r1i1p1f1", "r1i1p1f1"], + } + ), + [], + id="empty", + ), + pytest.param( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"variable_id": "tas"}),), + group_by=("variable_id", "experiment_id"), + ), + pd.DataFrame( + { + "variable_id": ["tas", "tas", "pr"], + "experiment_id": ["ssp119", "ssp126", "ssp119"], + "variant_label": ["r1i1p1f1", "r1i1p1f1", "r1i1p1f1"], + } + ), + [ + pd.DataFrame( + { + "variable_id": ["tas"], + "experiment_id": ["ssp119"], + "variant_label": ["r1i1p1f1"], + }, + index=[0], + ), + pd.DataFrame( + { + "variable_id": ["tas"], + "experiment_id": ["ssp126"], + "variant_label": ["r1i1p1f1"], + }, + index=[1], + ), + ], + id="simple-filter", + ), + pytest.param( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"variable_id": ("tas", "pr")}),), + group_by=("experiment_id",), + ), + pd.DataFrame( + { + "variable_id": ["tas", "tas", "pr"], + "experiment_id": ["ssp119", "ssp126", "ssp119"], + } + ), + [ + pd.DataFrame( + { + "variable_id": ["tas", "pr"], + "experiment_id": ["ssp119", "ssp119"], + }, + index=[0, 2], + ), + pd.DataFrame( + { + "variable_id": ["tas"], + "experiment_id": ["ssp126"], + }, + index=[1], + ), + ], + id="simple-or", + ), + pytest.param( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"variable_id": ("tas", 
"pr")}),), + constraints=(SelectParentExperiment(),), + group_by=("variable_id", "experiment_id"), + ), + pd.DataFrame( + { + "variable_id": ["tas", "tas"], + "experiment_id": ["ssp119", "historical"], + "parent_experiment_id": ["historical", "none"], + } + ), + [ + pd.DataFrame( + { + "variable_id": ["tas", "tas"], + "experiment_id": ["historical", "ssp119"], + }, + # The order of the rows is not guaranteed + index=[1, 0], + ), + ], + marks=[pytest.mark.xfail(reason="Parent experiment not implemented")], + id="parent", + ), + pytest.param( + DataRequirement( + source_type=SourceDatasetType.CMIP6, + filters=(FacetFilter(facets={"variable_id": ("tas", "pr")}),), + constraints=(RequireFacets(dimension="variable_id", required_facets=["tas", "pr"]),), + group_by=("experiment_id",), + ), + pd.DataFrame( + { + "variable_id": ["tas", "tas", "pr"], + "experiment_id": ["ssp119", "ssp126", "ssp119"], + } + ), + [ + pd.DataFrame( + { + "variable_id": ["tas", "pr"], + "experiment_id": ["ssp119", "ssp119"], + }, + index=[0, 2], + ), + ], + id="simple-validation", + ), + ], +) +def test_data_coverage(requirement, data_catalog, expected): + result = extract_covered_datasets(data_catalog, requirement) + + for res, exp in zip(result, expected): + pd.testing.assert_frame_equal(res, exp) + assert len(result) == len(expected) diff --git a/pyproject.toml b/pyproject.toml index b2b7548..47c7f9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,8 @@ show_missing = true exclude_also = [ # Don't complain about missing type checking code: "if TYPE_CHECKING", + # Exclude ... literals + "\\.\\.\\." 
] [tool.mypy] diff --git a/ruff.toml b/ruff.toml index 1fa58ba..9268df3 100644 --- a/ruff.toml +++ b/ruff.toml @@ -37,16 +37,10 @@ ignore = [ "PLR2004" # Magic value used in comparison ] "docs/*" = [ - "D100", # Missing docstring at the top of file - "E402", # Module level import not at top of file - "S101", # Use of `assert` detected -] -"notebooks/*" = [ - "D100", # Missing docstring at the top of file - "D103", # Missing docstring in public function + "D", "E402", # Module level import not at top of file "S101", # Use of `assert` detected - "PD901", # `df` is a bad variable name. + "PLR2004", # Magic value used in comparison. ] "*/alembic/versions/*" = [ "D103", # Missing docstring in public function