From 2c1499e83d5f0904e01e6d2f75107c7eea158eea Mon Sep 17 00:00:00 2001 From: Lars Buntemeyer Date: Sat, 25 Jan 2025 17:21:02 +0000 Subject: [PATCH] update catalog selection --- evaltools/source.py | 12 +++++++++--- evaltools/utils.py | 10 ++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/evaltools/source.py b/evaltools/source.py index 4df1dbe..dac3b49 100644 --- a/evaltools/source.py +++ b/evaltools/source.py @@ -26,7 +26,12 @@ def open_catalog(url=None): def get_source_collection( - variable_id, frequency, driving_source_id="ERA5", add_fx=None, catalog=None + variable_id, + frequency, + driving_source_id="ERA5", + add_fx=None, + catalog=None, + **kwargs, ): """ Search the catalog for datasets matching the specified variable_id, frequency, and driving_source_id. @@ -50,15 +55,16 @@ def get_source_collection( frequency=frequency, driving_source_id=driving_source_id, require_all_on=["source_id"], + **kwargs, ) source_ids = list(subset.df.source_id.unique()) print(f"Found: {source_ids} for variables: {variable_id}") if add_fx: if add_fx is True: - fx = catalog.search(source_id=source_ids, frequency="fx") + fx = catalog.search(source_id=source_ids, frequency="fx", **kwargs) else: fx = catalog.search( - source_id=source_ids, frequency="fx", variable_id=add_fx + source_id=source_ids, frequency="fx", variable_id=add_fx, **kwargs ) if fx.df.empty: warn(f"static variables not found: {variable_id}") diff --git a/evaltools/utils.py b/evaltools/utils.py index a06b7a5..1bb6adb 100644 --- a/evaltools/utils.py +++ b/evaltools/utils.py @@ -33,7 +33,7 @@ def iid_to_dict(iid, attrs=None): return dict(zip(attrs, values)) -def dict_to_iid(attrs, drop=None): +def dict_to_iid(attrs, drop=None, delimiter="."): """ Convert a dictionary of dataset attributes to a dataset ID. @@ -45,10 +45,10 @@ def dict_to_iid(attrs, drop=None): """ if drop is None: drop = [] - return ".".join(v for k, v in attrs.items() if k not in drop) + return delimiter.join(v for k, v in attrs.items() if k not in drop) -def short_iid(iid, attrs=None): +def short_iid(iid, attrs=None, delimiter="."): """ Convert a dataset ID to a short ID. @@ -61,7 +61,9 @@ def short_iid(iid, attrs=None): """ if attrs is None: attrs = ["institution_id", "source_id", "driving_source_id", "experiment_id"] - return dict_to_iid({k: v for k, v in iid_to_dict(iid).items() if k in attrs}) + return dict_to_iid( + {k: v for k, v in iid_to_dict(iid).items() if k in attrs}, delimiter=delimiter + ) def sort_by_grid_mapping(dsets):