Skip to content

Commit

Permalink
feat: Make a more real-world example
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjared committed Oct 31, 2024
1 parent 55b108c commit d734028
Show file tree
Hide file tree
Showing 7 changed files with 460 additions and 117 deletions.
2 changes: 1 addition & 1 deletion packages/ref-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Topic :: Scientific/Engineering",
]
dependencies = [
"pydantic>=2.0"
"attrs"
]

[tool.uv]
Expand Down
41 changes: 36 additions & 5 deletions packages/ref-core/src/ref_core/metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import pathlib
from typing import Protocol, runtime_checkable

from pydantic import BaseModel
from attrs import frozen


class Configuration(BaseModel):
@frozen
class Configuration:
"""
Configuration that describes the input data sources
"""
Expand All @@ -17,7 +19,8 @@ class Configuration(BaseModel):
# TODO: Add more configuration options here


class MetricResult(BaseModel):
@frozen
class MetricResult:
"""
The result of running a metric.
Expand All @@ -27,7 +30,7 @@ class MetricResult(BaseModel):

# Do we want to load a serialised version of the output bundle here or just a file path?

output_bundle: pathlib.Path
output_bundle: pathlib.Path | None
"""
Path to the output bundle file.
Expand All @@ -40,8 +43,36 @@ class MetricResult(BaseModel):
"""
# Log info is in the output bundle file already, but is definitely useful

@staticmethod
def build(configuration: Configuration, cmec_output_bundle: dict) -> "MetricResult":
    """
    Build a MetricResult from a CMEC output bundle.

    The bundle is serialised to ``output.json`` inside the configured
    output directory and the path of that file is recorded on the result.

    Parameters
    ----------
    configuration
        The configuration used to run the metric.
    cmec_output_bundle
        An output bundle in the CMEC format.

        TODO: This needs a better type hint

    Returns
    -------
    :
        A prepared MetricResult object.
        The output bundle will be written to the output directory.
    """
    # Compute the target path once so the file that is written and the
    # path recorded on the result can never drift apart.
    bundle_path = configuration.output_directory / "output.json"
    # JSON is defined as UTF-8; don't rely on the platform default encoding.
    with open(bundle_path, "w", encoding="utf-8") as file_handle:
        json.dump(cmec_output_bundle, file_handle)
    return MetricResult(
        output_bundle=bundle_path,
        successful=True,
    )


class TriggerInfo(BaseModel):
@frozen
class TriggerInfo:
"""
The reason why the metric was run.
"""
Expand Down
5 changes: 4 additions & 1 deletion packages/ref-metrics-example/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ classifiers = [
"Topic :: Scientific/Engineering",
]
dependencies = [
"ref-core"
"ref-core",
"xarray >= 2022",
"netcdf4",
"dask>=2024.10.0",
]

[tool.uv]
Expand Down
93 changes: 83 additions & 10 deletions packages/ref-metrics-example/src/ref_metrics_example/example.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,110 @@
import json
from pathlib import Path

import xarray as xr
from ref_core.metrics import Configuration, MetricResult, TriggerInfo


def calculate_annual_mean_timeseries(dataset: Path) -> xr.Dataset:
    """
    Calculate the annual mean timeseries for a dataset.

    While this function is implemented here,
    in most cases the metric calculation will be in the underlying benchmarking package.
    How the metric is calculated is up to the provider.

    Parameters
    ----------
    dataset
        A path to a CMIP6 dataset.

        This dataset may consist of multiple data files.

    Returns
    -------
    :
        The annual mean timeseries of the dataset

    Raises
    ------
    FileNotFoundError
        If ``dataset`` contains no ``*.nc`` files.
    """
    # Sort for a deterministic file order; glob order is filesystem-dependent
    # and the combine step should not depend on it.
    input_files = sorted(dataset.glob("*.nc"))
    if not input_files:
        # FileNotFoundError is a subclass of the OSError that
        # open_mfdataset would otherwise raise, but far clearer.
        raise FileNotFoundError(f"No netCDF files found in {dataset}")

    # Keep the path parameter and the opened dataset as distinct names
    # (the original rebound ``dataset`` to a different type mid-function).
    ds = xr.open_mfdataset(input_files, combine="by_coords", chunks=None)

    # "YS" resamples to the start of each calendar year.
    annual_mean = ds.resample(time="YS").mean()
    return annual_mean.mean(dim=["lat", "lon"], keep_attrs=True)


def format_cmec_output_bundle(dataset: "xr.Dataset") -> dict:
    """
    Create a simple CMEC output bundle for the dataset.

    Parameters
    ----------
    dataset
        Processed dataset

    Returns
    -------
        A CMEC output bundle ready to be written to disk
    """
    source_id = dataset.attrs["source_id"]

    dimensions = {
        "dimensions": {
            "source_id": {source_id: {}},
            "region": {"global": {}},
            "variable": {"tas": {}},
        },
        "json_structure": [
            "model",
            "region",
            "statistic",
        ],
    }

    # Is the schema tracked?
    schema = {
        "name": "CMEC-REF",
        "package": "example",
        "version": "v1",
    }

    results = {
        source_id: {"global": {"tas": ""}},
    }

    return {
        "DIMENSIONS": dimensions,
        "SCHEMA": schema,
        "RESULTS": results,
    }


class ExampleMetric:
    """
    Example metric that calculates the annual mean, global mean timeseries
    of a dataset.
    """

    name = "example"

    def __init__(self) -> None:
        # Number of times this metric has been run; kept for debugging.
        self._count = 0

    def run(self, configuration: Configuration, trigger: TriggerInfo | None) -> MetricResult:
        """
        Run a metric

        Parameters
        ----------
        configuration
            Configuration object
        trigger
            Trigger for what caused the metric to be executed.

        Returns
        -------
        :
            The result of running the metric.
        """
        self._count += 1

        if trigger is None:
            # TODO: This should probably raise an exception
            return MetricResult(
                output_bundle=configuration.output_directory / "output.json",
                successful=False,
            )

        annual_mean_global_mean_timeseries = calculate_annual_mean_timeseries(trigger.dataset)

        # Serialisation of the bundle is delegated to MetricResult.build,
        # which writes output.json into the configured output directory.
        return MetricResult.build(
            configuration, format_cmec_output_bundle(annual_mean_global_mean_timeseries)
        )
49 changes: 45 additions & 4 deletions packages/ref-metrics-example/tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,56 @@
from ref_core.metrics import Configuration
from ref_metrics_example.example import ExampleMetric
from pathlib import Path

import pytest
from ref_core.metrics import Configuration, TriggerInfo
from ref_metrics_example.example import ExampleMetric, calculate_annual_mean_timeseries

@pytest.fixture
def test_dataset(esgf_data_dir) -> Path:
    """Path to a sample CMIP6 ``tas`` dataset used by the example metric tests."""
    return (
        esgf_data_dir
        / "CMIP6"
        / "ScenarioMIP"
        / "CSIRO"
        / "ACCESS-ESM1-5"
        / "ssp245"
        / "r1i1p1f1"
        / "Amon"
        / "tas"
        / "gn"
        / "v20191115"
    )


def test_annual_mean(test_dataset):
    """The sample dataset spans 86 years, so we expect 86 annual means."""
    # calculate_annual_mean_timeseries takes the dataset *directory* and
    # globs it itself -- passing a pre-globbed list would fail.
    annual_mean = calculate_annual_mean_timeseries(test_dataset)

    assert annual_mean.time.size == 86


def test_example_metric(tmp_path, test_dataset):
    """Running the example metric writes an output bundle and reports success."""
    metric = ExampleMetric()

    configuration = Configuration(
        output_directory=tmp_path,
    )

    result = metric.run(configuration, trigger=TriggerInfo(dataset=test_dataset))

    assert result.successful
    assert result.output_bundle.exists()
Expand Down
3 changes: 1 addition & 2 deletions scripts/fetch_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def queue_esgf_download( # noqa: PLR0913
]
)

res = subprocess.run(
subprocess.run(
[
"esgpull",
"update",
Expand All @@ -78,7 +78,6 @@ def queue_esgf_download( # noqa: PLR0913
input=b"y",
check=False,
)
res.check_returncode()

return search_tag

Expand Down
Loading

0 comments on commit d734028

Please sign in to comment.