diff --git a/packages/ref-metrics-example/src/ref_metrics_example/__init__.py b/packages/ref-metrics-example/src/ref_metrics_example/__init__.py index 76b43ba..1e6590a 100644 --- a/packages/ref-metrics-example/src/ref_metrics_example/__init__.py +++ b/packages/ref-metrics-example/src/ref_metrics_example/__init__.py @@ -12,5 +12,5 @@ __core_version__ = importlib.metadata.version("ref_core") # Initialise the metrics manager and register the example metric -provider = MetricsProvider("Example", "example", __version__) +provider = MetricsProvider("Example", __version__) provider.register(GlobalMeanTimeseries()) diff --git a/packages/ref-metrics-example/src/ref_metrics_example/example.py b/packages/ref-metrics-example/src/ref_metrics_example/example.py index 3429bc2..64f20f6 100644 --- a/packages/ref-metrics-example/src/ref_metrics_example/example.py +++ b/packages/ref-metrics-example/src/ref_metrics_example/example.py @@ -78,7 +78,8 @@ class GlobalMeanTimeseries(Metric): Calculate the annual mean global mean timeseries for a dataset """ - name = "global_mean_timeseries" + name = "Global Mean Timeseries" + slug = "global-mean-timeseries" data_requirements = ( DataRequirement( diff --git a/packages/ref-metrics-example/tests/unit/test_metrics.py b/packages/ref-metrics-example/tests/unit/test_metrics.py index 8e2df05..2e0f862 100644 --- a/packages/ref-metrics-example/tests/unit/test_metrics.py +++ b/packages/ref-metrics-example/tests/unit/test_metrics.py @@ -20,7 +20,7 @@ def metric_dataset(cmip6_data_catalog) -> MetricDataset: def test_annual_mean(esgf_data_dir, metric_dataset): - annual_mean = calculate_annual_mean_timeseries(metric_dataset["cmip6"].path[0]) + annual_mean = calculate_annual_mean_timeseries(metric_dataset["cmip6"].path.to_list()) assert annual_mean.time.size == 286 diff --git a/packages/ref-metrics-example/tests/unit/test_provider.py b/packages/ref-metrics-example/tests/unit/test_provider.py index 5de9e74..fd164a0 100644 --- a/packages/ref-metrics-example/tests/unit/test_provider.py +++ b/packages/ref-metrics-example/tests/unit/test_provider.py @@ -8,7 +8,8 @@ def test_version(): def test_provider(): - assert provider.name == "example" + assert provider.name == "Example" + assert provider.slug == "example" assert provider.version == __version__ assert len(provider) == 1 diff --git a/packages/ref/src/ref/solver.py b/packages/ref/src/ref/solver.py index f051625..6138a1a 100644 --- a/packages/ref/src/ref/solver.py +++ b/packages/ref/src/ref/solver.py @@ -14,13 +14,14 @@ from attrs import define, frozen from loguru import logger from ref_core.constraints import apply_constraint -from ref_core.datasets import MetricDataset, SourceDatasetType +from ref_core.datasets import DatasetCollection, MetricDataset, SourceDatasetType from ref_core.exceptions import InvalidMetricException from ref_core.executor import get_executor from ref_core.metrics import DataRequirement, Metric, MetricExecutionDefinition from ref_core.providers import MetricsProvider from ref.database import Database +from ref.datasets import get_dataset_adapter from ref.datasets.cmip6 import CMIP6DatasetAdapter from ref.env import env from ref.provider_registry import ProviderRegistry @@ -40,7 +41,8 @@ def build_metric_execution_info(self) -> MetricExecutionDefinition: """ Build the metric execution info for the current metric execution """ - slug = f"{self.provider.slug}-{self.metric.slug}-{self.metric_dataset.slug}" + # TODO: We might want to pretty print the dataset slug + slug = "_".join([self.provider.slug, self.metric.slug, self.metric_dataset.slug]) return MetricExecutionDefinition( output_fragment=pathlib.Path(self.provider.slug) / self.metric.slug / self.metric_dataset.slug, @@ -154,7 +156,12 @@ def solve(self) -> typing.Generator[MetricExecution, None, None]: provider=provider, metric=metric, metric_dataset=MetricDataset( - {key: value for key, value in zip(dataset_groups.keys(), items)} + { + key: DatasetCollection( + datasets=value, slug_column=get_dataset_adapter(key.value).slug_column + ) + for key, value in zip(dataset_groups.keys(), items) + } ), ) diff --git a/packages/ref/tests/unit/test_solver.py b/packages/ref/tests/unit/test_solver.py index 40efad8..25046e4 100644 --- a/packages/ref/tests/unit/test_solver.py +++ b/packages/ref/tests/unit/test_solver.py @@ -25,10 +25,9 @@ class TestMetricSolver: def test_solver_build_from_db(self, solver): assert isinstance(solver, MetricSolver) assert isinstance(solver.provider_registry, ProviderRegistry) - assert solver.data_catalog == {} - - def test_solver_solve_empty(self, solver): - assert len(list(solver.solve())) == 0 + assert SourceDatasetType.CMIP6 in solver.data_catalog + assert isinstance(solver.data_catalog[SourceDatasetType.CMIP6], pd.DataFrame) + assert len(solver.data_catalog[SourceDatasetType.CMIP6]) @pytest.mark.parametrize( @@ -185,7 +184,7 @@ def test_solve_metrics_default_solver(mock_executor, db_seeded, solver): def test_solve_metrics(mock_executor, db_seeded, solver): solve_metrics(db_seeded, dry_run=False, solver=solver) - assert mock_executor.return_value.run_metric.call_count == 1 + assert mock_executor.return_value.run_metric.call_count == 4 def test_solve_metrics_dry_run(db_seeded):