Skip to content

Commit

Permalink
test: Fix up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjared committed Dec 3, 2024
1 parent 9c8a115 commit 922c91a
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
__core_version__ = importlib.metadata.version("ref_core")

# Initialise the metrics manager and register the example metric
provider = MetricsProvider("Example", "example", __version__)
provider = MetricsProvider("Example", __version__)
provider.register(GlobalMeanTimeseries())
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class GlobalMeanTimeseries(Metric):
Calculate the annual mean global mean timeseries for a dataset
"""

name = "global_mean_timeseries"
name = "Global Mean Timeseries"
slug = "global-mean-timeseries"

data_requirements = (
DataRequirement(
Expand Down
2 changes: 1 addition & 1 deletion packages/ref-metrics-example/tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def metric_dataset(cmip6_data_catalog) -> MetricDataset:


def test_annual_mean(esgf_data_dir, metric_dataset):
annual_mean = calculate_annual_mean_timeseries(metric_dataset["cmip6"].path[0])
annual_mean = calculate_annual_mean_timeseries(metric_dataset["cmip6"].path.to_list())

assert annual_mean.time.size == 286

Expand Down
3 changes: 2 additions & 1 deletion packages/ref-metrics-example/tests/unit/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ def test_version():


def test_provider():
assert provider.name == "example"
assert provider.name == "Example"
assert provider.slug == "example"
assert provider.version == __version__

assert len(provider) == 1
13 changes: 10 additions & 3 deletions packages/ref/src/ref/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from attrs import define, frozen
from loguru import logger
from ref_core.constraints import apply_constraint
from ref_core.datasets import MetricDataset, SourceDatasetType
from ref_core.datasets import DatasetCollection, MetricDataset, SourceDatasetType
from ref_core.exceptions import InvalidMetricException
from ref_core.executor import get_executor
from ref_core.metrics import DataRequirement, Metric, MetricExecutionDefinition
from ref_core.providers import MetricsProvider

from ref.database import Database
from ref.datasets import get_dataset_adapter
from ref.datasets.cmip6 import CMIP6DatasetAdapter
from ref.env import env
from ref.provider_registry import ProviderRegistry
Expand All @@ -40,7 +41,8 @@ def build_metric_execution_info(self) -> MetricExecutionDefinition:
"""
Build the metric execution info for the current metric execution
"""
slug = f"{self.provider.slug}-{self.metric.slug}-{self.metric_dataset.slug}"
# TODO: We might want to pretty print the dataset slug
slug = "_".join([self.provider.slug, self.metric.slug, self.metric_dataset.slug])

return MetricExecutionDefinition(
output_fragment=pathlib.Path(self.provider.slug) / self.metric.slug / self.metric_dataset.slug,
Expand Down Expand Up @@ -154,7 +156,12 @@ def solve(self) -> typing.Generator[MetricExecution, None, None]:
provider=provider,
metric=metric,
metric_dataset=MetricDataset(
{key: value for key, value in zip(dataset_groups.keys(), items)}
{
key: DatasetCollection(
datasets=value, slug_column=get_dataset_adapter(key.value).slug_column
)
for key, value in zip(dataset_groups.keys(), items)
}
),
)

Expand Down
9 changes: 4 additions & 5 deletions packages/ref/tests/unit/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ class TestMetricSolver:
def test_solver_build_from_db(self, solver):
assert isinstance(solver, MetricSolver)
assert isinstance(solver.provider_registry, ProviderRegistry)
assert solver.data_catalog == {}

def test_solver_solve_empty(self, solver):
assert len(list(solver.solve())) == 0
assert SourceDatasetType.CMIP6 in solver.data_catalog
assert isinstance(solver.data_catalog[SourceDatasetType.CMIP6], pd.DataFrame)
assert len(solver.data_catalog[SourceDatasetType.CMIP6])


@pytest.mark.parametrize(
Expand Down Expand Up @@ -185,7 +184,7 @@ def test_solve_metrics_default_solver(mock_executor, db_seeded, solver):
def test_solve_metrics(mock_executor, db_seeded, solver):
solve_metrics(db_seeded, dry_run=False, solver=solver)

assert mock_executor.return_value.run_metric.call_count == 1
assert mock_executor.return_value.run_metric.call_count == 4


def test_solve_metrics_dry_run(db_seeded):
Expand Down

0 comments on commit 922c91a

Please sign in to comment.