diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4b63675..db176c76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,8 +37,8 @@ repos: rev: v0.4.6 hooks: - id: ruff - # Next line if for documenation cod snippets - exclude: '^[^_].*_\.py$' + # Next line is to exclude for documentation code snippets + exclude: 'docs/(.*/)?[a-z]\w+_.py$' args: - --line-length=120 - --fix diff --git a/docs/building/sources.rst b/docs/building/sources.rst index 3e5e8aee..15c955b9 100644 --- a/docs/building/sources.rst +++ b/docs/building/sources.rst @@ -23,6 +23,7 @@ The following `sources` are currently available: sources/mars sources/grib sources/netcdf + sources/xarray sources/opendap sources/forcings sources/accumulations diff --git a/docs/building/sources/xarray.rst b/docs/building/sources/xarray.rst new file mode 100644 index 00000000..4eba74e0 --- /dev/null +++ b/docs/building/sources/xarray.rst @@ -0,0 +1,6 @@ +######## + xarray +######## + +.. literalinclude:: xarray.yaml + :language: yaml diff --git a/docs/building/sources/xarray.yaml b/docs/building/sources/xarray.yaml new file mode 100644 index 00000000..0d2cd449 --- /dev/null +++ b/docs/building/sources/xarray.yaml @@ -0,0 +1,3 @@ +input: + xarray: + url: https://... diff --git a/pyproject.toml b/pyproject.toml index b5b59afb..02586bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,22 +50,24 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-utils[provenance]>=0.3.5", + "anemoi-utils[provenance]>=0.3.11", "numpy", "pyyaml", "semantic-version", "tqdm", - "zarr<=2.17", - + "zarr", ] optional-dependencies.all = [ + "aiohttp", "boto3", "earthkit-data[mars]>=0.9", "earthkit-geo>=0.2", "earthkit-meteo", "ecmwflibs>=0.6.3", "entrypoints", + "gcsfs", + "kerchunk", "pyproj", "requests", "s3fs", @@ -81,12 +83,15 @@ optional-dependencies.create = [ ] optional-dependencies.dev = [ + "aiohttp", "boto3", "earthkit-data[mars]>=0.9", "earthkit-geo>=0.2", "earthkit-meteo", "ecmwflibs>=0.6.3", "entrypoints", + "gcsfs", + "kerchunk", "nbsphinx", "pandoc", "pyproj", @@ -108,7 +113,14 @@ optional-dependencies.docs = [ "sphinx-rtd-theme", ] +optional-dependencies.kerchunk = [ + "gcsfs", + "kerchunk", + "s3fs", +] + optional-dependencies.remote = [ + "aiohttp", "boto3", "requests", "s3fs", diff --git a/src/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py index 135a88fb..b46a64ce 100644 --- a/src/anemoi/datasets/commands/compare.py +++ b/src/anemoi/datasets/commands/compare.py @@ -8,6 +8,10 @@ # nor does it submit to any jurisdiction. # +import numpy as np +import tqdm +import zarr + from anemoi.datasets import open_dataset from . 
import Command @@ -19,6 +23,8 @@ class Compare(Command): def add_arguments(self, command_parser): command_parser.add_argument("dataset1") command_parser.add_argument("dataset2") + command_parser.add_argument("--data", action="store_true", help="Compare the data.") + command_parser.add_argument("--statistics", action="store_true", help="Compare the statistics.") def run(self, args): ds1 = open_dataset(args.dataset1) @@ -42,5 +48,58 @@ def run(self, args): f"{ds2.statistics['mean'][ds2.name_to_index[v]]:14g}", ) + if args.data: + print() + print("Data:") + print("-----") + print() + + diff = 0 + for a, b in tqdm.tqdm(zip(ds1, ds2)): + if not np.array_equal(a, b, equal_nan=True): + diff += 1 + + print(f"Number of different rows: {diff}/{len(ds1)}") + + if args.data: + print() + print("Data 2:") + print("-----") + print() + + ds1 = zarr.open(args.dataset1, mode="r") + ds2 = zarr.open(args.dataset2, mode="r") + + for name in ( + "data", + "count", + "sums", + "squares", + "mean", + "stdev", + "minimum", + "maximum", + "latitudes", + "longitudes", + ): + a1 = ds1[name] + a2 = ds2[name] + + if len(a1) != len(a2): + print(f"{name}: lengths mismatch {len(a1)} != {len(a2)}") + continue + + diff = 0 + for a, b in tqdm.tqdm(zip(a1, a2), leave=False): + if not np.array_equal(a, b, equal_nan=True): + if diff == 0: + print(f"\n{name}: first different row:") + print(a[a != b]) + print(b[a != b]) + + diff += 1 + + print(f"{name}: {diff} different rows out of {len(a1)}") + command = Compare diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index eb9bc738..1fa6292d 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -1,7 +1,39 @@ -from anemoi.datasets.create import Creator +import datetime +import logging +import time +from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import as_completed + +import tqdm +from anemoi.utils.humanize import seconds_to_human + +from anemoi.datasets.create.trace import enable_trace from . import Command +LOG = logging.getLogger(__name__) + + +def task(what, options, *args, **kwargs): + """ + Make sure `import Creator` is done in the sub-processes, and not in the main one. 
+ """ + + now = datetime.datetime.now() + LOG.debug(f"Task {what}({args},{kwargs}) starting") + + from anemoi.datasets.create import Creator + + if "trace" in options: + enable_trace(options["trace"]) + + c = Creator(**options) + result = getattr(c, what)(*args, **kwargs) + + LOG.debug(f"Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") + return result + class Create(Command): """Create a dataset.""" @@ -22,12 +54,61 @@ def add_arguments(self, command_parser): ) command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the dataset.") command_parser.add_argument("path", help="Path to store the created data.") + group = command_parser.add_mutually_exclusive_group() + group.add_argument("--threads", help="Use `n` parallel thread workers.", type=int, default=0) + group.add_argument("--processes", help="Use `n` parallel process workers.", type=int, default=0) + command_parser.add_argument("--trace", action="store_true") def run(self, args): - kwargs = vars(args) + now = time.time() + if args.threads + args.processes: + self.parallel_create(args) + else: + self.serial_create(args) + LOG.info(f"Create completed in {seconds_to_human(time.time()-now)}") - c = Creator(**kwargs) + def serial_create(self, args): + from anemoi.datasets.create import Creator + + options = vars(args) + c = Creator(**options) c.create() + def parallel_create(self, args): + """Some modules, like fsspec do not work well with fork() + Other modules may not be thread safe. So we implement + parallel loadining using multiprocessing before any + of the modules are imported. + """ + + options = vars(args) + parallel = args.threads + args.processes + args.use_threads = args.threads > 0 + + if args.use_threads: + ExecutorClass = ThreadPoolExecutor + else: + ExecutorClass = ProcessPoolExecutor + + with ExecutorClass(max_workers=1) as executor: + total = executor.submit(task, "init", options).result() + + futures = [] + + with ExecutorClass(max_workers=parallel) as executor: + for n in range(total): + futures.append(executor.submit(task, "load", options, parts=f"{n+1}/{total}")) + + for future in tqdm.tqdm( + as_completed(futures), desc="Loading", total=len(futures), colour="green", position=parallel + 1 + ): + future.result() + + with ExecutorClass(max_workers=1) as executor: + executor.submit(task, "statistics", options).result() + executor.submit(task, "additions", options).result() + executor.submit(task, "cleanup", options).result() + executor.submit(task, "verify", options).result() + command = Create diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 566780c1..fddb1d0a 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -217,7 +217,7 @@ def print_sizes(self, size): if total_size is not None: print(f"πŸ’½ Size : {bytes(total_size)} ({bytes_to_human(total_size)})") if n is not None: - print(f"πŸ“ Files : {n}") + print(f"πŸ“ Files : {n:,}") @property def statistics(self): diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index c2f0effc..e69043d3 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -7,8 +7,15 @@ # nor does it submit to any jurisdiction. 
# +import logging import os +LOG = logging.getLogger(__name__) + + +def _ignore(*args, **kwargs): + pass + class Creator: def __init__( @@ -16,19 +23,21 @@ def __init__( path, config=None, cache=None, - print=print, + use_threads=False, statistics_tmp=None, overwrite=False, test=None, + progress=None, **kwargs, ): self.path = path # Output path self.config = config self.cache = cache - self.print = print + self.use_threads = use_threads self.statistics_tmp = statistics_tmp self.overwrite = overwrite self.test = test + self.progress = progress if progress is not None else _ignore def init(self, check_name=False): # check path @@ -44,10 +53,11 @@ def init(self, check_name=False): path=self.path, config=self.config, statistics_tmp=self.statistics_tmp, - print=self.print, + use_threads=self.use_threads, + progress=self.progress, test=self.test, ) - obj.initialise(check_name=check_name) + return obj.initialise(check_name=check_name) def load(self, parts=None): from .loaders import ContentLoader @@ -56,7 +66,8 @@ def load(self, parts=None): loader = ContentLoader.from_dataset_config( path=self.path, statistics_tmp=self.statistics_tmp, - print=self.print, + use_threads=self.use_threads, + progress=self.progress, parts=parts, ) loader.load() @@ -66,7 +77,8 @@ def statistics(self, force=False, output=None, start=None, end=None): loader = StatisticsAdder.from_dataset( path=self.path, - print=self.print, + use_threads=self.use_threads, + progress=self.progress, statistics_tmp=self.statistics_tmp, statistics_output=output, recompute=False, @@ -74,20 +86,22 @@ def statistics(self, force=False, output=None, start=None, end=None): statistics_end=end, ) loader.run() + assert loader.ready() def size(self): from .loaders import DatasetHandler from .size import compute_directory_sizes metadata = compute_directory_sizes(self.path) - handle = DatasetHandler.from_dataset(path=self.path, print=self.print) + handle = DatasetHandler.from_dataset(path=self.path, use_threads=self.use_threads) handle.update_metadata(**metadata) + assert handle.ready() def cleanup(self): from .loaders import DatasetHandlerWithStatistics cleaner = DatasetHandlerWithStatistics.from_dataset( - path=self.path, print=self.print, statistics_tmp=self.statistics_tmp + path=self.path, use_threads=self.use_threads, progress=self.progress, statistics_tmp=self.statistics_tmp ) cleaner.tmp_statistics.delete() cleaner.registry.clean() @@ -103,15 +117,17 @@ def init_additions(self, delta=[1, 3, 6, 12, 24], statistics=True): from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency if statistics: - a = StatisticsAddition.from_dataset(path=self.path, print=self.print) + a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads) a.initialise() for d in delta: try: - a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a = TendenciesStatisticsAddition.from_dataset( + path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d + ) a.initialise() except TendenciesStatisticsDeltaNotMultipleOfFrequency: - self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + LOG.info(f"Skipping delta={d} as it is not a multiple of the frequency.") def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24], statistics=True): from .loaders import StatisticsAddition @@ -119,15 +135,17 @@ def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24], statistics=True): from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency if statistics: - a = 
StatisticsAddition.from_dataset(path=self.path, print=self.print) + a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads) a.run(parts) for d in delta: try: - a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a = TendenciesStatisticsAddition.from_dataset( + path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d + ) a.run(parts) except TendenciesStatisticsDeltaNotMultipleOfFrequency: - self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + LOG.debug(f"Skipping delta={d} as it is not a multiple of the frequency.") def finalise_additions(self, delta=[1, 3, 6, 12, 24], statistics=True): from .loaders import StatisticsAddition @@ -135,15 +153,17 @@ def finalise_additions(self, delta=[1, 3, 6, 12, 24], statistics=True): from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency if statistics: - a = StatisticsAddition.from_dataset(path=self.path, print=self.print) + a = StatisticsAddition.from_dataset(path=self.path, use_threads=self.use_threads) a.finalise() for d in delta: try: - a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a = TendenciesStatisticsAddition.from_dataset( + path=self.path, use_threads=self.use_threads, progress=self.progress, delta=d + ) a.finalise() except TendenciesStatisticsDeltaNotMultipleOfFrequency: - self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + LOG.debug(f"Skipping delta={d} as it is not a multiple of the frequency.") def finalise(self, **kwargs): self.statistics(**kwargs) @@ -174,3 +194,10 @@ def _path_readable(self): return True except zarr.errors.PathNotFoundError: return False + + def verify(self): + from .loaders import DatasetVerifier + + handle = DatasetVerifier.from_dataset(path=self.path, use_threads=self.use_threads) + + handle.verify() diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py index e24c7e02..0c262113 100644 --- a/src/anemoi/datasets/create/check.py +++ b/src/anemoi/datasets/create/check.py @@ -56,7 +56,7 @@ def raise_if_not_valid(self, print=print): raise ValueError(self.error_message) def _parse(self, name): - pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?(.*)$" + pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?([a-zA-Z0-9-]+)$" match = re.match(pattern, name) assert match, (name, pattern) @@ -136,18 +136,19 @@ class StatisticsValueError(ValueError): pass -def check_data_values(arr, *, name: str, log=[], allow_nan=False): - if allow_nan is False: - allow_nan = lambda x: False # noqa: E731 +def check_data_values(arr, *, name: str, log=[], allow_nans=False): - if allow_nan(name): + if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans: arr = arr[~np.isnan(arr)] + assert arr.size > 0, (name, *log) + min, max = arr.min(), arr.max() assert not (np.isnan(arr).any()), (name, min, max, *log) if min == 9999.0: warnings.warn(f"Min value 9999 for {name}") + if max == 9999.0: warnings.warn(f"Max value 9999 for {name}") diff --git a/src/anemoi/datasets/create/chunks.py b/src/anemoi/datasets/create/chunks.py index 4dc988f6..3ab70b8c 100644 --- a/src/anemoi/datasets/create/chunks.py +++ b/src/anemoi/datasets/create/chunks.py @@ -57,7 +57,7 @@ def __init__(self, *, parts, total): if not parts: warnings.warn(f"Nothing to do for chunk {i}/{n}.") - LOG.info(f"Running parts: {parts}") + LOG.debug(f"Running parts: {parts}") self.allowed = 
parts diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 352ff848..605ded96 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -12,10 +12,10 @@ from copy import deepcopy import yaml +from anemoi.utils.config import DotDict +from anemoi.utils.config import load_any_dict_format from earthkit.data.core.order import normalize_order_by -from .utils import load_json_or_yaml - LOG = logging.getLogger(__name__) @@ -43,31 +43,10 @@ def check_dict_value_and_set(dic, key, value): if dic[key] == value: return raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") - print(f"Setting {key}={value} in config") + LOG.info(f"Setting {key}={value} in config") dic[key] = value -class DictObj(dict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - for key, value in self.items(): - if isinstance(value, dict): - self[key] = DictObj(value) - continue - if isinstance(value, list): - self[key] = [DictObj(item) if isinstance(item, dict) else item for item in value] - continue - - def __getattr__(self, attr): - try: - return self[attr] - except KeyError: - raise AttributeError(attr) - - def __setattr__(self, attr, value): - self[attr] = value - - def resolve_includes(config): if isinstance(config, list): return [resolve_includes(c) for c in config] @@ -79,11 +58,11 @@ def resolve_includes(config): return config -class Config(DictObj): +class Config(DotDict): def __init__(self, config=None, **kwargs): if isinstance(config, str): self.config_path = os.path.realpath(config) - config = load_json_or_yaml(config) + config = load_any_dict_format(config) else: config = deepcopy(config if config is not None else {}) config = resolve_includes(config) diff --git a/src/anemoi/datasets/create/functions/filters/rename.py b/src/anemoi/datasets/create/functions/filters/rename.py index cd6300a0..071a4578 100644 --- a/src/anemoi/datasets/create/functions/filters/rename.py +++ b/src/anemoi/datasets/create/functions/filters/rename.py @@ -26,15 +26,23 @@ def __init__(self, field, what, renaming): self.what = what self.renaming = renaming - def metadata(self, key, **kwargs): + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + value = self.field.metadata(key, **kwargs) if key == self.what: return self.renaming.get(value, value) + return value def __getattr__(self, name): return getattr(self.field, name) + def __repr__(self) -> str: + return repr(self.field) + return f"{self.field} -> {self.what} -> {self.renaming}" + class RenamedFieldFormat: """Rename a field based on a format string. 
diff --git a/src/anemoi/datasets/create/functions/filters/rotate_winds.py b/src/anemoi/datasets/create/functions/filters/rotate_winds.py index b18970b9..a644ff22 100644 --- a/src/anemoi/datasets/create/functions/filters/rotate_winds.py +++ b/src/anemoi/datasets/create/functions/filters/rotate_winds.py @@ -9,6 +9,8 @@ from collections import defaultdict +import tqdm +from anemoi.utils.humanize import plural from earthkit.data.indexing.fieldlist import FieldArray from earthkit.geo.rotate import rotate_vector @@ -24,6 +26,9 @@ def to_numpy(self, *args, **kwargs): def __getattr__(self, name): return getattr(self.field, name) + def __repr__(self) -> str: + return repr(self.field) + def execute( context, @@ -35,6 +40,8 @@ def execute( ): from pyproj import CRS + context.trace("πŸ”„", "Rotating winds (extracting winds from ", plural(len(input), "field")) + result = FieldArray() wind_params = (x_wind, y_wind) @@ -55,7 +62,9 @@ def execute( wind_pairs[key][param] = f - for _, pairs in wind_pairs.items(): + context.trace("πŸ”„", "Rotating", plural(len(wind_pairs), "wind"), "(speed will likely include data download)") + + for _, pairs in tqdm.tqdm(list(wind_pairs.items())): if len(pairs) != 2: raise ValueError("Missing wind component") diff --git a/src/anemoi/datasets/create/functions/sources/__init__.py b/src/anemoi/datasets/create/functions/sources/__init__.py index 33d7fa0a..6192702a 100644 --- a/src/anemoi/datasets/create/functions/sources/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/__init__.py @@ -6,3 +6,42 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # + +import glob +import logging + +from earthkit.data.utils.patterns import Pattern + +LOG = logging.getLogger(__name__) + + +def _expand(paths): + for path in paths: + if path.startswith("file://"): + path = path[7:] + + if path.startswith("http://"): + yield path + continue + + if path.startswith("https://"): + yield path + continue + + cnt = 0 + for p in glob.glob(path): + yield p + cnt += 1 + if cnt == 0: + yield path + + +def iterate_patterns(path, dates, **kwargs): + given_paths = path if isinstance(path, list) else [path] + + dates = [d.isoformat() for d in dates] + + for path in given_paths: + paths = Pattern(path, ignore_missing_keys=True).substitute(date=dates, **kwargs) + for path in _expand(paths): + yield path, dates diff --git a/src/anemoi/datasets/create/functions/sources/accumulations.py b/src/anemoi/datasets/create/functions/sources/accumulations.py index 63ae2a89..33e0875b 100644 --- a/src/anemoi/datasets/create/functions/sources/accumulations.py +++ b/src/anemoi/datasets/create/functions/sources/accumulations.py @@ -24,7 +24,7 @@ LOG = logging.getLogger(__name__) -def member(field): +def _member(field): # Bug in eccodes has number=0 randomly number = field.metadata("number", default=0) if number is None: @@ -68,7 +68,7 @@ def check(self, field): self.time, field.metadata("time"), ) - assert self.number == member(field), (self.number, member(field)) + assert self.number == _member(field), (self.number, _member(field)) return @@ -241,17 +241,17 @@ def _mars_date_time_step(cls, base_date, step1, step2, add_step, frequency): ) -def identity(x): +def _identity(x): return x -def compute_accumulations( +def _compute_accumulations( context, dates, request, user_accumulation_period=6, data_accumulation_period=None, - patch=identity, + patch=_identity, base_times=None, ): adjust_step = isinstance(user_accumulation_period, int) @@ -340,7 +340,7 
@@ def compute_accumulations( field.metadata("date"), field.metadata("time"), field.metadata("step"), - member(field), + _member(field), ) values = field.values # optimisation assert accumulations[key], key @@ -365,43 +365,13 @@ def compute_accumulations( return ds -def to_list(x): +def _to_list(x): if isinstance(x, (list, tuple)): return x return [x] -def normalise_time_to_hours(r): - r = deepcopy(r) - if "time" not in r: - return r - - times = [] - for t in to_list(r["time"]): - assert len(t) == 4, r - assert t.endswith("00"), r - times.append(int(t) // 100) - r["time"] = tuple(times) - return r - - -def normalise_number(r): - if "number" not in r: - return r - number = r["number"] - number = to_list(number) - - if len(number) > 4 and (number[1] == "to" and number[3] == "by"): - return list(range(int(number[0]), int(number[2]) + 1, int(number[4]))) - - if len(number) > 2 and number[1] == "to": - return list(range(int(number[0]), int(number[2]) + 1)) - - r["number"] = number - return r - - -def scda(request): +def _scda(request): if request["time"] in (6, 18, 600, 1800): request["stream"] = "scda" else: @@ -410,14 +380,14 @@ def scda(request): def accumulations(context, dates, **request): - to_list(request["param"]) + _to_list(request["param"]) class_ = request.get("class", "od") stream = request.get("stream", "oper") user_accumulation_period = request.pop("accumulation_period", 6) KWARGS = { - ("od", "oper"): dict(patch=scda), + ("od", "oper"): dict(patch=_scda), ("od", "elda"): dict(base_times=(6, 18)), ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), @@ -427,7 +397,7 @@ def accumulations(context, dates, **request): context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") - return compute_accumulations( + return _compute_accumulations( context, dates, request, diff --git a/src/anemoi/datasets/create/functions/sources/constants.py b/src/anemoi/datasets/create/functions/sources/constants.py index 97490d20..ba716579 100644 --- a/src/anemoi/datasets/create/functions/sources/constants.py +++ b/src/anemoi/datasets/create/functions/sources/constants.py @@ -18,6 +18,9 @@ def constants(context, dates, template, param): stacklevel=2, ) context.trace("βœ…", f"from_source(constants, {template}, {param}") + if len(template) == 0: + raise ValueError("Forcings template is empty.") + return from_source("forcings", source_or_dataset=template, date=dates, param=param) diff --git a/src/anemoi/datasets/create/functions/sources/grib.py b/src/anemoi/datasets/create/functions/sources/grib.py index d4f3c07c..1ddca353 100644 --- a/src/anemoi/datasets/create/functions/sources/grib.py +++ b/src/anemoi/datasets/create/functions/sources/grib.py @@ -26,8 +26,12 @@ def check(ds, paths, **kwargs): def _expand(paths): for path in paths: + cnt = 0 for p in glob.glob(path): yield p + cnt += 1 + if cnt == 0: + yield path def execute(context, dates, path, *args, **kwargs): diff --git a/src/anemoi/datasets/create/functions/sources/hindcasts.py b/src/anemoi/datasets/create/functions/sources/hindcasts.py index 2df6f281..0bbe4fb4 100644 --- a/src/anemoi/datasets/create/functions/sources/hindcasts.py +++ b/src/anemoi/datasets/create/functions/sources/hindcasts.py @@ -7,21 +7,13 @@ # nor does it submit to any jurisdiction. 
# import datetime -import warnings -from copy import deepcopy - -import earthkit.data as ekd -import numpy as np -from earthkit.data.core.temporary import temp_file -from earthkit.data.readers.grib.output import new_grib_output -from earthkit.utils.availability import Availability from anemoi.datasets.create.functions.sources.mars import mars DEBUG = True -def member(field): +def _member(field): # Bug in eccodes has number=0 randomly number = field.metadata("number") if number is None: @@ -29,368 +21,12 @@ def member(field): return number -class Accumulation: - def __init__(self, out, /, param, date, time, number, step, frequency, **kwargs): - self.out = out - self.param = param - self.date = date - self.time = time - self.steps = step - self.number = number - self.values = None - self.seen = set() - self.startStep = None - self.endStep = None - self.done = False - self.frequency = frequency - self._check = None - - @property - def key(self): - return (self.param, self.date, self.time, self.steps, self.number) - - def check(self, field): - if self._check is None: - self._check = field.as_mars() - - assert self.param == field.metadata("param"), ( - self.param, - field.metadata("param"), - ) - assert self.date == field.metadata("date"), ( - self.date, - field.metadata("date"), - ) - assert self.time == field.metadata("time"), ( - self.time, - field.metadata("time"), - ) - assert self.number == member(field), (self.number, member(field)) - - return - - mars = field.as_mars() - keys1 = sorted(self._check.keys()) - keys2 = sorted(mars.keys()) - - assert keys1 == keys2, (keys1, keys2) - - for k in keys1: - if k not in ("step",): - assert self._check[k] == mars[k], (k, self._check[k], mars[k]) - - def write(self, template): - - assert self.startStep != self.endStep, (self.startStep, self.endStep) - assert np.all(self.values >= 0), (np.amin(self.values), np.amax(self.values)) - - self.out.write( - self.values, - template=template, - stepType="accum", - startStep=self.startStep, - endStep=self.endStep, - ) - self.values = None - self.done = True - - def add(self, field, values): - - self.check(field) - - step = field.metadata("step") - if step not in self.steps: - return - - if not np.all(values >= 0): - warnings.warn(f"Negative values for {field}: {np.amin(values)} {np.amax(values)}") - - assert not self.done, (self.key, step) - assert step not in self.seen, (self.key, step) - - startStep = field.metadata("startStep") - endStep = field.metadata("endStep") - - if self.buggy_steps and startStep == endStep: - startStep = 0 - - assert step == endStep, (startStep, endStep, step) - - self.compute(values, startStep, endStep) - - self.seen.add(step) - - if len(self.seen) == len(self.steps): - self.write(template=field) - - @classmethod - def mars_date_time_steps(cls, dates, step1, step2, frequency, base_times, adjust_step): - - # assert step1 > 0, (step1, step2, frequency) - - for valid_date in dates: - base_date = valid_date - datetime.timedelta(hours=step2) - add_step = 0 - if base_date.hour not in base_times: - if not adjust_step: - raise ValueError( - f"Cannot find a base time in {base_times} that validates on {valid_date.isoformat()} for step={step2}" - ) - - while base_date.hour not in base_times: - # print(f'{base_date=}, {base_times=}, {add_step=} {frequency=}') - base_date -= datetime.timedelta(hours=1) - add_step += 1 - - yield cls._mars_date_time_step(base_date, step1, step2, add_step, frequency) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.key})" - - -class 
AccumulationFromStart(Accumulation): - buggy_steps = True - - def compute(self, values, startStep, endStep): - - assert startStep == 0, startStep - - if self.values is None: - - self.values = np.copy(values) - self.startStep = 0 - self.endStep = endStep - - else: - assert endStep != self.endStep, (self.endStep, endStep) - - if endStep > self.endStep: - # assert endStep - self.endStep == self.stepping, (self.endStep, endStep, self.stepping) - self.values = values - self.values - self.startStep = self.endStep - self.endStep = endStep - else: - # assert self.endStep - endStep == self.stepping, (self.endStep, endStep, self.stepping) - self.values = self.values - values - self.startStep = endStep - - if not np.all(self.values >= 0): - warnings.warn(f"Negative values for {self.param}: {np.amin(self.values)} {np.amax(self.values)}") - self.values = np.maximum(self.values, 0) - - @classmethod - def _mars_date_time_step(cls, base_date, step1, step2, add_step, frequency): - assert not frequency, frequency - - steps = (step1 + add_step, step2 + add_step) - if steps[0] == 0: - steps = (steps[1],) - - return ( - base_date.year * 10000 + base_date.month * 100 + base_date.day, - base_date.hour * 100 + base_date.minute, - steps, - ) - - -class AccumulationFromLastStep(Accumulation): - buggy_steps = False - - def compute(self, values, startStep, endStep): - - assert endStep - startStep == self.frequency, ( - startStep, - endStep, - self.frequency, - ) - - if self.startStep is None: - self.startStep = startStep - else: - self.startStep = min(self.startStep, startStep) - - if self.endStep is None: - self.endStep = endStep - else: - self.endStep = max(self.endStep, endStep) - - if self.values is None: - self.values = np.zeros_like(values) - - self.values += values - - @classmethod - def _mars_date_time_step(cls, base_date, step1, step2, add_step, frequency): - assert frequency > 0, frequency - # assert step1 > 0, (step1, step2, frequency, add_step, base_date) - - steps = [] - for step in range(step1 + frequency, step2 + frequency, frequency): - steps.append(step + add_step) - return ( - base_date.year * 10000 + base_date.month * 100 + base_date.day, - base_date.hour * 100 + base_date.minute, - tuple(steps), - ) - - -def identity(x): - return x - - -def compute_accumulations( - dates, - request, - user_accumulation_period=6, - data_accumulation_period=None, - patch=identity, - base_times=None, -): - adjust_step = isinstance(user_accumulation_period, int) - - if not isinstance(user_accumulation_period, (list, tuple)): - user_accumulation_period = (0, user_accumulation_period) - - assert len(user_accumulation_period) == 2, user_accumulation_period - step1, step2 = user_accumulation_period - assert step1 < step2, user_accumulation_period - - if base_times is None: - base_times = [0, 6, 12, 18] - - base_times = [t // 100 if t > 100 else t for t in base_times] - - AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep - - mars_date_time_steps = AccumulationClass.mars_date_time_steps( - dates, - step1, - step2, - data_accumulation_period, - base_times, - adjust_step, - ) - - request = deepcopy(request) - - param = request["param"] - if not isinstance(param, (list, tuple)): - param = [param] - - number = request.get("number", [0]) - assert isinstance(number, (list, tuple)) - - frequency = data_accumulation_period - - type_ = request.get("type", "an") - if type_ == "an": - type_ = "fc" - - request.update({"type": type_, "levtype": "sfc"}) - - tmp = temp_file() 
- path = tmp.path - out = new_grib_output(path) - - requests = [] - - accumulations = {} - - for date, time, steps in mars_date_time_steps: - for p in param: - for n in number: - requests.append( - patch( - { - "param": p, - "date": date, - "time": time, - "step": sorted(steps), - "number": n, - } - ) - ) - - compressed = Availability(requests) - ds = ekd.from_source("empty") - for r in compressed.iterate(): - request.update(r) - print("🌧️", request) - ds = ds + ekd.from_source("mars", **request) - - accumulations = {} - for a in [AccumulationClass(out, frequency=frequency, **r) for r in requests]: - for s in a.steps: - key = (a.param, a.date, a.time, s, a.number) - accumulations.setdefault(key, []).append(a) - - for field in ds: - key = ( - field.metadata("param"), - field.metadata("date"), - field.metadata("time"), - field.metadata("step"), - member(field), - ) - values = field.values # optimisation - assert accumulations[key], key - for a in accumulations[key]: - a.add(field, values) - - for acc in accumulations.values(): - for a in acc: - assert a.done, (a.key, a.seen, a.steps) - - out.close() - - ds = ekd.from_source("file", path) - - assert len(ds) / len(param) / len(number) == len(dates), ( - len(ds), - len(param), - len(dates), - ) - ds._tmp = tmp - - return ds - - -def to_list(x): +def _to_list(x): if isinstance(x, (list, tuple)): return x return [x] -def normalise_time_to_hours(r): - r = deepcopy(r) - if "time" not in r: - return r - - times = [] - for t in to_list(r["time"]): - assert len(t) == 4, r - assert t.endswith("00"), r - times.append(int(t) // 100) - r["time"] = tuple(times) - return r - - -def normalise_number(r): - if "number" not in r: - return r - number = r["number"] - number = to_list(number) - - if len(number) > 4 and (number[1] == "to" and number[3] == "by"): - return list(range(int(number[0]), int(number[2]) + 1, int(number[4]))) - - if len(number) > 2 and number[1] == "to": - return list(range(int(number[0]), int(number[2]) + 1)) - - r["number"] = number - return r - - class HindcastCompute: def __init__(self, base_times, available_steps, request): self.base_times = base_times @@ -398,22 +34,34 @@ def __init__(self, base_times, available_steps, request): self.request = request def compute_hindcast(self, date): - for step in self.available_steps: + result = [] + for step in sorted(self.available_steps): # Use the shortest step start_date = date - datetime.timedelta(hours=step) hours = start_date.hour if hours in self.base_times: - r = deepcopy(self.request) + r = self.request.copy() r["date"] = start_date r["time"] = f"{start_date.hour:02d}00" r["step"] = step - return r - raise ValueError( - f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, available_steps={self.available_steps})" - ) + result.append(r) + + if not result: + raise ValueError( + f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, " + f"available_steps={self.available_steps})" + ) + + if len(result) > 1: + raise ValueError( + f"Multiple requests for {self.request} for {date} (base_times={self.base_times}, " + f"available_steps={self.available_steps})" + ) + + return result[0] def use_reference_year(reference_year, request): - request = deepcopy(request) + request = request.copy() hdate = request.pop("date") date = datetime.datetime(reference_year, hdate.month, hdate.day) request.update(date=date.strftime("%Y-%m-%d"), hdate=hdate.strftime("%Y-%m-%d")) @@ -421,15 +69,15 @@ def use_reference_year(reference_year, request): def 
hindcasts(context, dates, **request): - request["param"] = to_list(request["param"]) - request["step"] = to_list(request["step"]) + request["param"] = _to_list(request["param"]) + request["step"] = _to_list(request["step"]) request["step"] = [int(_) for _ in request["step"]] if request.get("stream") == "enfh" and "base_times" not in request: request["base_times"] = [0] available_steps = request.pop("step") - available_steps = to_list(available_steps) + available_steps = _to_list(available_steps) base_times = request.pop("base_times") @@ -444,7 +92,14 @@ def hindcasts(context, dates, **request): req = use_reference_year(reference_year, req) requests.append(req) - return mars(context, dates, *requests, date_key="hdate") + + return mars( + context, + dates, + *requests, + date_key="hdate", + request_already_using_valid_datetime=True, + ) execute = hindcasts diff --git a/src/anemoi/datasets/create/functions/sources/mars.py b/src/anemoi/datasets/create/functions/sources/mars.py index f0417deb..a36ccf21 100644 --- a/src/anemoi/datasets/create/functions/sources/mars.py +++ b/src/anemoi/datasets/create/functions/sources/mars.py @@ -7,7 +7,6 @@ # nor does it submit to any jurisdiction. # import datetime -from copy import deepcopy from anemoi.utils.humanize import did_you_mean from earthkit.data import from_source @@ -43,25 +42,27 @@ def normalise_time_delta(t): return t -def _expand_mars_request(request, date, date_key="date"): +def _expand_mars_request(request, date, request_already_using_valid_datetime=False, date_key="date"): requests = [] step = to_list(request.get("step", [0])) for s in step: - r = deepcopy(request) - - if isinstance(s, str) and "-" in s: - assert s.count("-") == 1, s - # this takes care of the cases where the step is a period such as 0-24 or 12-24 - hours = int(str(s).split("-")[-1]) - - base = date - datetime.timedelta(hours=hours) - r.update( - { - date_key: base.strftime("%Y%m%d"), - "time": base.strftime("%H%M"), - "step": s, - } - ) + r = request.copy() + + if not request_already_using_valid_datetime: + + if isinstance(s, str) and "-" in s: + assert s.count("-") == 1, s + # this takes care of the cases where the step is a period such as 0-24 or 12-24 + hours = int(str(s).split("-")[-1]) + + base = date - datetime.timedelta(hours=hours) + r.update( + { + date_key: base.strftime("%Y%m%d"), + "time": base.strftime("%H%M"), + "step": s, + } + ) for pproc in ("grid", "rotation", "frame", "area", "bitmap", "resol"): if pproc in r: @@ -73,13 +74,18 @@ def _expand_mars_request(request, date, date_key="date"): return requests -def factorise_requests(dates, *requests, date_key="date"): +def factorise_requests(dates, *requests, request_already_using_valid_datetime=False, date_key="date"): updates = [] for req in requests: # req = normalise_request(req) for d in dates: - updates += _expand_mars_request(req, date=d, date_key=date_key) + updates += _expand_mars_request( + req, + date=d, + request_already_using_valid_datetime=request_already_using_valid_datetime, + date_key=date_key, + ) compressed = Availability(updates) for r in compressed.iterate(): @@ -171,7 +177,7 @@ def use_grib_paramid(r): ] -def mars(context, dates, *requests, date_key="date", **kwargs): +def mars(context, dates, *requests, request_already_using_valid_datetime=False, date_key="date", **kwargs): if not requests: requests = [kwargs] @@ -191,7 +197,12 @@ def mars(context, dates, *requests, date_key="date", **kwargs): "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
) - requests = factorise_requests(dates, *requests, date_key=date_key) + requests = factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=request_already_using_valid_datetime, + date_key=date_key, + ) ds = from_source("empty") for r in requests: r = {k: v for k, v in r.items() if v != ("-",)} @@ -207,7 +218,11 @@ def mars(context, dates, *requests, date_key="date", **kwargs): raise ValueError( f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" ) - ds = ds + from_source("mars", **r) + try: + ds = ds + from_source("mars", **r) + except Exception as e: + if "File is empty:" not in str(e): + raise return ds diff --git a/src/anemoi/datasets/create/functions/sources/netcdf.py b/src/anemoi/datasets/create/functions/sources/netcdf.py index 9b0ebcbf..9870374b 100644 --- a/src/anemoi/datasets/create/functions/sources/netcdf.py +++ b/src/anemoi/datasets/create/functions/sources/netcdf.py @@ -7,66 +7,8 @@ # nor does it submit to any jurisdiction. # -import glob - -from earthkit.data import from_source -from earthkit.data.utils.patterns import Pattern - - -def _expand(paths): - for path in paths: - if path.startswith("file://"): - path = path[7:] - - if path.startswith("http://"): - yield path - continue - - if path.startswith("https://"): - yield path - continue - - for p in glob.glob(path): - yield p - - -def check(what, ds, paths, **kwargs): - count = 1 - for k, v in kwargs.items(): - if isinstance(v, (tuple, list)): - count *= len(v) - - if len(ds) != count: - raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})") - - -def load_netcdfs(emoji, what, context, dates, path, *args, **kwargs): - given_paths = path if isinstance(path, list) else [path] - - dates = [d.isoformat() for d in dates] - ds = from_source("empty") - - for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(*args, date=dates, **kwargs) - - levels = kwargs.get("level", kwargs.get("levelist")) - - for path in _expand(paths): - context.trace(emoji, what.upper(), path) - s = from_source("opendap", path) - s = s.sel( - valid_datetime=dates, - param=kwargs["param"], - step=kwargs.get("step", 0), - ) - if levels: - s = s.sel(levelist=levels) - ds = ds + s - - check(what, ds, given_paths, valid_datetime=dates, **kwargs) - - return ds +from .xarray import load_many def execute(context, dates, path, *args, **kwargs): - return load_netcdfs("πŸ“", "path", context, dates, path, *args, **kwargs) + return load_many("πŸ“", context, dates, path, *args, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/opendap.py b/src/anemoi/datasets/create/functions/sources/opendap.py index ffbfc3e8..d25d52f8 100644 --- a/src/anemoi/datasets/create/functions/sources/opendap.py +++ b/src/anemoi/datasets/create/functions/sources/opendap.py @@ -7,8 +7,9 @@ # nor does it submit to any jurisdiction. # -from .netcdf import load_netcdfs + +from .xarray import load_many def execute(context, dates, url, *args, **kwargs): - return load_netcdfs("🌐", "url", context, dates, url, *args, **kwargs) + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py new file mode 100644 index 00000000..468b7a35 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py @@ -0,0 +1,73 @@ +# (C) Copyright 2024 ECMWF. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging + +from earthkit.data.core.fieldlist import MultiFieldList + +from anemoi.datasets.data.stores import name_to_zarr_store + +from .. import iterate_patterns +from .fieldlist import XarrayFieldList + +LOG = logging.getLogger(__name__) + + +def check(what, ds, paths, **kwargs): + count = 1 + for k, v in kwargs.items(): + if isinstance(v, (tuple, list)): + count *= len(v) + + if len(ds) != count: + raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})") + + +def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs): + import xarray as xr + + """ + We manage the S3 client ourselves, bypassing the fsspec and s3fs layers, because sometimes something on the stack + zarr/fsspec/s3fs/boto3 (?) seems to flag files as missing when they actually are not (maybe when S3 reports some sort of + connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs. + See https://github.com/pydata/xarray/issues/8842 + + We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`. + """ + + context.trace(emoji, dataset, options) + + if isinstance(dataset, str) and ".zarr" in dataset: + data = xr.open_zarr(name_to_zarr_store(dataset), **options) + else: + data = xr.open_dataset(dataset, **options) + + fs = XarrayFieldList.from_xarray(data, flavour) + result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) + + if len(result) == 0: + LOG.warning(f"No data found for {dataset} and dates {dates}") + LOG.warning(f"Options: {options}") + LOG.warning(data) + + return result + + +def load_many(emoji, context, dates, pattern, **kwargs): + + result = [] + + for path, dates in iterate_patterns(pattern, dates, **kwargs): + result.append(load_one(emoji, context, dates, path, **kwargs)) + + return MultiFieldList(result) + + +def execute(context, dates, url, *args, **kwargs): + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py b/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py new file mode 100644 index 00000000..8e92b0e6 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py @@ -0,0 +1,234 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction.
+# + +import datetime +import logging + +import numpy as np +from earthkit.data.utils.dates import to_datetime + +LOG = logging.getLogger(__name__) + + +def is_scalar(variable): + shape = variable.shape + if shape == (1,): + return True + if len(shape) == 0: + return True + return False + + +def extract_single_value(variable): + shape = variable.shape + if np.issubdtype(variable.values.dtype, np.datetime64): + if len(shape) == 0: + return to_datetime(variable.values) # Convert to python datetime + assert False, (shape, variable.values) + + if np.issubdtype(variable.values.dtype, np.timedelta64): + if len(shape) == 0: + # Convert to python timedelta64 + return datetime.timedelta(seconds=variable.values.astype("timedelta64[s]").astype(int).item()) + assert False, (shape, variable.values) + + if shape == (1,): + return variable.values[0] + + if len(shape) == 0: + return variable.values.item() + + assert False, (shape, variable.values) + + +class Coordinate: + is_grid = False + is_dim = True + is_lat = False + is_lon = False + is_time = False + is_step = False + is_date = False + + def __init__(self, variable): + self.variable = variable + self.scalar = is_scalar(variable) + self.kwargs = {} # Used when creating a new coordinate (reduced method) + + def __len__(self): + return 1 if self.scalar else len(self.variable) + + def __repr__(self): + return "%s[name=%s,values=%s]" % ( + self.__class__.__name__, + self.variable.name, + self.variable.values if self.scalar else len(self), + ) + + def reduced(self, i): + """Create a new coordinate with a single value + + Parameters + ---------- + i : int + the index of the value to select + + Returns + ------- + Coordinate + the new coordinate + """ + return self.__class__( + self.variable.isel({self.variable.dims[0]: i}), + **self.kwargs, + ) + + def index(self, value): + """Return the index of the value in the coordinate + + Parameters + ---------- + value : Any + The value to search for + + Returns + ------- + int or None + The index of the value in the coordinate or None if not found + """ + + if isinstance(value, (list, tuple)): + if len(value) == 1: + return self._index_single(value) + else: + return self._index_multiple(value) + return self._index_single(value) + + def _index_single(self, value): + + values = self.variable.values + + # Assume the array is sorted + index = np.searchsorted(values, value) + + if index < len(values) and values[index] == value: + return index + + # If not found, we need to check if the value is in the array + + index = np.where(values == value)[0] + if len(index) > 0: + return index[0] + + return None + + def _index_multiple(self, value): + + values = self.variable.values + + # Assume the array is sorted + + index = np.searchsorted(values, value) + index = index[index < len(values)] + + if np.all(values[index] == value): + return index + + # If not found, we need to check if the value is in the array + + index = np.where(np.isin(values, value))[0] + + # We could also return incomplete matches + if len(index) == len(value): + return index + + return None + + @property + def name(self): + return self.variable.name + + def normalise(self, value): + # Subclasses to format values that will be added to the field metadata + return value + + @property + def single_value(self): + return extract_single_value(self.variable) + + +class TimeCoordinate(Coordinate): + is_time = True + mars_names = ("valid_datetime",) + + def index(self, time): + return super().index(np.datetime64(time)) + + +class DateCoordinate(Coordinate): + is_date 
= True + mars_names = ("date",) + + def index(self, date): + return super().index(np.datetime64(date)) + + +class StepCoordinate(Coordinate): + is_step = True + mars_names = ("step",) + + +class LevelCoordinate(Coordinate): + mars_names = ("level", "levelist") + + def __init__(self, variable, levtype): + super().__init__(variable) + self.levtype = levtype + # kwargs is used when creating a new coordinate (reduced method) + self.kwargs = {"levtype": levtype} + + def normalise(self, value): + # Some netCDF files have pressure levels as floats + if int(value) == value: + return int(value) + return value + + +class EnsembleCoordinate(Coordinate): + mars_names = ("number",) + + +class LongitudeCoordinate(Coordinate): + is_grid = True + is_lon = True + mars_names = ("longitude",) + + +class LatitudeCoordinate(Coordinate): + is_grid = True + is_lat = True + mars_names = ("latitude",) + + +class XCoordinate(Coordinate): + is_grid = True + mars_names = ("x",) + + +class YCoordinate(Coordinate): + is_grid = True + mars_names = ("y",) + + +class ScalarCoordinate(Coordinate): + is_grid = False + + @property + def mars_names(self): + return (self.variable.name,) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/field.py b/src/anemoi/datasets/create/functions/sources/xarray/field.py new file mode 100644 index 00000000..d464df78 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/field.py @@ -0,0 +1,109 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging + +from earthkit.data.core.fieldlist import Field +from earthkit.data.core.fieldlist import math + +from .coordinates import extract_single_value +from .coordinates import is_scalar +from .metadata import XArrayMetadata + +LOG = logging.getLogger(__name__) + + +class EmptyFieldList: + def __len__(self): + return 0 + + def __getitem__(self, i): + raise IndexError(i) + + def __repr__(self) -> str: + return "EmptyFieldList()" + + +class XArrayField(Field): + + def __init__(self, owner, selection): + """Create a new XArrayField object. + + Parameters + ---------- + owner : Variable + The variable that owns this field. + selection : XArrayDataArray + A 2D sub-selection of the variable's underlying array. + This is actually an nD object, but the first dimensions are always 1. + The other two dimensions are latitude and longitude.
+ """ + super().__init__(owner.array_backend) + + self.owner = owner + self.selection = selection + + # Copy the metadata from the owner + self._md = owner._metadata.copy() + + for coord_name, coord_value in self.selection.coords.items(): + if is_scalar(coord_value): + # Extract the single value from the scalar dimension + # and store it in the metadata + coordinate = owner.by_name[coord_name] + self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value)) + + # print(values.ndim, values.shape, selection.dims) + # By now, the only dimensions should be latitude and longitude + self._shape = tuple(list(self.selection.shape)[-2:]) + if math.prod(self._shape) != math.prod(self.selection.shape): + print(self.selection.ndim, self.selection.shape) + print(self.selection) + raise ValueError("Invalid shape for selection") + + @property + def shape(self): + return self._shape + + def to_numpy(self, flatten=False, dtype=None): + values = self.selection.values + + assert dtype is None + if flatten: + return values.flatten() + return values.reshape(self.shape) + + def _make_metadata(self): + return XArrayMetadata(self, self.owner.mapping) + + def grid_points(self): + return self.owner.grid_points() + + @property + def resolution(self): + return None + + @property + def grid_mapping(self): + return self.owner.grid_mapping + + @property + def latitudes(self): + return self.owner.latitudes + + @property + def longitudes(self): + return self.owner.longitudes + + @property + def forecast_reference_time(self): + return self.owner.forecast_reference_time + + def __repr__(self): + return repr(self._metadata) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py new file mode 100644 index 00000000..90ab44aa --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py @@ -0,0 +1,171 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import json +import logging + +import yaml +from earthkit.data.core.fieldlist import FieldList + +from .coordinates import is_scalar as is_scalar +from .field import EmptyFieldList +from .flavour import CoordinateGuesser +from .metadata import XArrayMetadata as XArrayMetadata +from .time import Time +from .variable import FilteredVariable +from .variable import Variable + +LOG = logging.getLogger(__name__) + + +class XarrayFieldList(FieldList): + def __init__(self, ds, variables): + self.ds = ds + self.variables = variables.copy() + self.total_length = sum(v.length for v in variables) + + def __repr__(self): + return f"XarrayFieldList({self.total_length})" + + def __len__(self): + return self.total_length + + def __getitem__(self, i): + k = i + + if i < 0: + i = self.total_length + i + + for v in self.variables: + if i < v.length: + return v[i] + i -= v.length + + raise IndexError(k) + + @classmethod + def from_xarray(cls, ds, flavour=None): + variables = [] + + if isinstance(flavour, str): + with open(flavour) as f: + if flavour.endswith(".yaml") or flavour.endswith(".yml"): + flavour = yaml.safe_load(f) + else: + flavour = json.load(f) + + guess = CoordinateGuesser.from_flavour(ds, flavour) + + skip = set() + + def _skip_attr(v, attr_name): + attr_val = getattr(v, attr_name, "") + if isinstance(attr_val, str): + skip.update(attr_val.split(" ")) + + for name in ds.data_vars: + v = ds[name] + _skip_attr(v, "coordinates") + _skip_attr(v, "bounds") + _skip_attr(v, "grid_mapping") + + # Select only geographical variables + for name in ds.data_vars: + + if name in skip: + continue + + v = ds[name] + coordinates = [] + + for coord in v.coords: + + c = guess.guess(ds[coord], coord) + assert c, f"Could not guess coordinate for {coord}" + if coord not in v.dims: + c.is_dim = False + coordinates.append(c) + + grid_coords = sum(1 for c in coordinates if c.is_grid and c.is_dim) + assert grid_coords <= 2 + + if grid_coords < 2: + continue + + variables.append( + Variable( + ds=ds, + var=v, + coordinates=coordinates, + grid=guess.grid(coordinates), + time=Time.from_coordinates(coordinates), + metadata={}, + ) + ) + + return cls(ds, variables) + + def sel(self, **kwargs): + """Override the FieldList's sel method + + Returns + ------- + FieldList + The new FieldList + + The algorithm is as follows: + 1 - Use the kwargs to select the variables that match the selection (`param` or `variable`) + 2 - For each variable, use the remaining kwargs to select the coordinates (`level`, `number`, ...) + 3 - Some MARS-like keys, such as `date`, `time` and `step`, are not found in the coordinates, + but are added to the metadata of the selected fields. An example is `step`, which is added to the + metadata of the field. Step 2 may return a variable that contains all the fields that + are valid at the same `valid_datetime`, with different base `date` and `time` and a different `step`. + So we get an extra chance to filter the fields by the metadata. + """ + + variables = [] + count = 0 + + for v in self.variables: + + v.update_metadata_mapping(kwargs) + + # First, select matching variables + # This will consume 'param' or 'variable' from kwargs + # and return the rest + match, rest = v.match(**kwargs) + + if match: + count += 1 + missing = {} + + # Select from the variable's coordinates (time, level, number, ....) + # This may return a new variable with an isel() slice of the selection + # or None if the selection is not possible.
In this case, missing is updated + # with the values of kwargs (rest) that are not relevant for this variable + v = v.sel(missing, **rest) + if missing: + if v is not None: + # The remaining kwargs are passed used to create a FilteredVariable + # that will select 2D slices based on their metadata + v = FilteredVariable(v, **missing) + else: + LOG.warning(f"Variable {v} has missing coordinates: {missing}") + + if v is not None: + variables.append(v) + + if count == 0: + LOG.warning("No variable found for %s", kwargs) + LOG.warning("Variables: %s", sorted([v.name for v in self.variables])) + + if not variables: + return EmptyFieldList() + + return self.__class__(self.ds, variables) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/flavour.py b/src/anemoi/datasets/create/functions/sources/xarray/flavour.py new file mode 100644 index 00000000..373cd963 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/flavour.py @@ -0,0 +1,330 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from .coordinates import DateCoordinate +from .coordinates import LatitudeCoordinate +from .coordinates import LevelCoordinate +from .coordinates import LongitudeCoordinate +from .coordinates import ScalarCoordinate +from .coordinates import StepCoordinate +from .coordinates import TimeCoordinate +from .coordinates import XCoordinate +from .coordinates import YCoordinate +from .grid import MeshedGrid +from .grid import UnstructuredGrid + + +class CoordinateGuesser: + + def __init__(self, ds): + self.ds = ds + self._cache = {} + + @classmethod + def from_flavour(cls, ds, flavour): + if flavour is None: + return DefaultCoordinateGuesser(ds) + else: + return FlavourCoordinateGuesser(ds, flavour) + + def guess(self, c, coord): + if coord not in self._cache: + self._cache[coord] = self._guess(c, coord) + return self._cache[coord] + + def _guess(self, c, coord): + + name = c.name + standard_name = getattr(c, "standard_name", "").lower() + axis = getattr(c, "axis", "") + long_name = getattr(c, "long_name", "").lower() + units = getattr(c, "units", "") + + d = self._is_longitude( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_latitude( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_x( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_y( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_time( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_step( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = self._is_date( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + d = 
self._is_level( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + + if c.shape in ((1,), tuple()): + return ScalarCoordinate(c) + + raise NotImplementedError( + f"Coordinate {coord} not supported\n{axis=}, {name=}," + f" {long_name=}, {standard_name=}, units\n\n{c}\n\n{type(c.values)} {c.shape}" + ) + + def grid(self, coordinates): + lat = [c for c in coordinates if c.is_lat] + lon = [c for c in coordinates if c.is_lon] + + if len(lat) != 1: + raise NotImplementedError(f"Expected 1 latitude coordinate, got {len(lat)}") + + if len(lon) != 1: + raise NotImplementedError(f"Expected 1 longitude coordinate, got {len(lon)}") + + lat = lat[0] + lon = lon[0] + + if (lat.name, lon.name) in self._cache: + return self._cache[(lat.name, lon.name)] + + assert len(lat.variable.shape) == len(lon.variable.shape), (lat.variable.shape, lon.variable.shape) + if len(lat.variable.shape) == 1: + grid = MeshedGrid(lat, lon) + else: + grid = UnstructuredGrid(lat, lon) + + self._cache[(lat.name, lon.name)] = grid + return grid + + +class DefaultCoordinateGuesser(CoordinateGuesser): + def __init__(self, ds): + super().__init__(ds) + + def _is_longitude(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "longitude": + return LongitudeCoordinate(c) + + if long_name == "longitude" and units == "degrees_east": + return LongitudeCoordinate(c) + + if name == "longitude": # WeatherBench + return LongitudeCoordinate(c) + + def _is_latitude(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "latitude": + return LatitudeCoordinate(c) + + if long_name == "latitude" and units == "degrees_north": + return LatitudeCoordinate(c) + + if name == "latitude": # WeatherBench + return LatitudeCoordinate(c) + + def _is_x(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "projection_x_coordinate": + return XCoordinate(c) + + if name == "x": + return XCoordinate(c) + + def _is_y(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "projection_y_coordinate": + return YCoordinate(c) + + if name == "y": + return YCoordinate(c) + + def _is_time(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "time": + return TimeCoordinate(c) + + if name == "time": + return TimeCoordinate(c) + + def _is_date(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "forecast_reference_time": + return DateCoordinate(c) + if name == "forecast_reference_time": + return DateCoordinate(c) + + def _is_step(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "forecast_period": + return StepCoordinate(c) + + if long_name == "time elapsed since the start of the forecast": + return StepCoordinate(c) + + if name == "prediction_timedelta": # WeatherBench + return StepCoordinate(c) + + def _is_level(self, c, *, axis, name, long_name, standard_name, units): + if standard_name == "atmosphere_hybrid_sigma_pressure_coordinate": + return LevelCoordinate(c, "ml") + + if long_name == "height" and units == "m": + return LevelCoordinate(c, "height") + + if standard_name == "air_pressure" and units == "hPa": + return LevelCoordinate(c, "pl") + + if name == "level": + return LevelCoordinate(c, "pl") + + if name == "vertical" and units == "hPa": + return LevelCoordinate(c, "pl") + + if standard_name == "depth": + return LevelCoordinate(c, "depth") + + if name == "pressure": + return LevelCoordinate(c, 
"pl") + + +class FlavourCoordinateGuesser(CoordinateGuesser): + def __init__(self, ds, flavour): + super().__init__(ds) + self.flavour = flavour + + def _match(self, c, key, values): + + if key not in self.flavour["rules"]: + return None + + rules = self.flavour["rules"][key] + + if not isinstance(rules, list): + rules = [rules] + + for rule in rules: + ok = True + for k, v in rule.items(): + if isinstance(v, str) and values.get(k) != v: + ok = False + if ok: + return rule + + return None + + def _is_longitude(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "longitude", locals()): + return LongitudeCoordinate(c) + + def _is_latitude(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "latitude", locals()): + return LatitudeCoordinate(c) + + def _is_x(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "x", locals()): + return XCoordinate(c) + + def _is_y(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "y", locals()): + return YCoordinate(c) + + def _is_time(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "time", locals()): + return TimeCoordinate(c) + + def _is_step(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "step", locals()): + return StepCoordinate(c) + + def _is_date(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "date", locals()): + return DateCoordinate(c) + + def _is_level(self, c, *, axis, name, long_name, standard_name, units): + + rule = self._match(c, "level", locals()) + if rule: + # assert False, rule + return LevelCoordinate( + c, + self._levtype( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ), + ) + + def _levtype(self, c, *, axis, name, long_name, standard_name, units): + if "levtype" in self.flavour: + return self.flavour["levtype"] + + raise NotImplementedError(f"levtype for {c=}") diff --git a/src/anemoi/datasets/create/functions/sources/xarray/grid.py b/src/anemoi/datasets/create/functions/sources/xarray/grid.py new file mode 100644 index 00000000..8a6349ed --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/grid.py @@ -0,0 +1,46 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + + +import numpy as np + + +class Grid: + def __init__(self, lat, lon): + self.lat = lat + self.lon = lon + + @property + def latitudes(self): + return self.grid_points()[0] + + @property + def longitudes(self): + return self.grid_points()[1] + + +class MeshedGrid(Grid): + _cache = None + + def grid_points(self): + if self._cache is not None: + return self._cache + lat = self.lat.variable.values + lon = self.lon.variable.values + + lat, lon = np.meshgrid(lat, lon) + self._cache = (lat.flatten(), lon.flatten()) + return self._cache + + +class UnstructuredGrid(Grid): + def grid_points(self): + lat = self.lat.variable.values.flatten() + lon = self.lon.variable.values.flatten() + return lat, lon diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py new file mode 100644 index 00000000..471fc6bc --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -0,0 +1,159 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from functools import cached_property + +from earthkit.data.core.geography import Geography +from earthkit.data.core.metadata import RawMetadata +from earthkit.data.utils.dates import to_datetime +from earthkit.data.utils.projections import Projection + +LOG = logging.getLogger(__name__) + + +class MDMapping: + + def __init__(self, mapping): + self.user_to_internal = mapping + + def from_user(self, kwargs): + if isinstance(kwargs, str): + return self.user_to_internal.get(kwargs, kwargs) + return {self.user_to_internal.get(k, k): v for k, v in kwargs.items()} + + def __len__(self): + return len(self.user_to_internal) + + def __repr__(self): + return f"MDMapping({self.user_to_internal})" + + +class XArrayMetadata(RawMetadata): + LS_KEYS = ["variable", "level", "valid_datetime", "units"] + NAMESPACES = ["default", "mars"] + MARS_KEYS = ["param", "step", "levelist", "levtype", "number", "date", "time"] + + def __init__(self, field, mapping): + self._field = field + md = field._md.copy() + + self._mapping = mapping + if mapping is None: + time_coord = [c for c in field.owner.coordinates if c.is_time] + if len(time_coord) == 1: + time_key = time_coord[0].name + else: + time_key = "time" + else: + time_key = mapping.from_user("valid_datetime") + self._time = to_datetime(md.pop(time_key)) + self._field.owner.time.fill_time_metadata(self._time, md) + md["valid_datetime"] = self._time.isoformat() + + super().__init__(md) + + @cached_property + def geography(self): + return XArrayFieldGeography(self._field, self._field.owner.grid) + + def as_namespace(self, namespace=None): + if not isinstance(namespace, str) and namespace is not None: + raise TypeError("namespace must be a str or None") + + if namespace == "default" or namespace == "" or namespace is None: + return dict(self) + + elif namespace == "mars": + return self._as_mars() + + def _as_mars(self): + return dict( + param=self["variable"], + step=self["step"], + levelist=self["level"], + levtype=self["levtype"], + number=self["number"], + date=self["date"], + time=self["time"], + ) + + def _base_datetime(self): + return self._field.forecast_reference_time + + def 
_valid_datetime(self): + return self._time + + def _get(self, key, **kwargs): + + if key.startswith("mars."): + key = key[5:] + if key not in self.MARS_KEYS: + if kwargs.get("raise_on_missing", False): + raise KeyError(f"Invalid key '{key}' in namespace='mars'") + else: + return kwargs.get("default", None) + + key = self._mapping.from_user(key) + return super()._get(key, **kwargs) + + +class XArrayFieldGeography(Geography): + def __init__(self, field, grid): + self._field = field + self._grid = grid + + def _unique_grid_id(self): + raise NotImplementedError() + + def bounding_box(self): + raise NotImplementedError() + # return BoundingBox(north=self.north, south=self.south, east=self.east, west=self.west) + + def gridspec(self): + raise NotImplementedError() + + def latitudes(self, dtype=None): + result = self._grid.latitudes + if dtype is not None: + return result.astype(dtype) + return result + + def longitudes(self, dtype=None): + result = self._grid.longitudes + if dtype is not None: + return result.astype(dtype) + return result + + def resolution(self): + # TODO: implement resolution + return None + + @property + def mars_grid(self): + # TODO: implement mars_grid + return None + + @property + def mars_area(self): + # TODO: code me + # return [self.north, self.west, self.south, self.east] + return None + + def x(self, dtype=None): + raise NotImplementedError() + + def y(self, dtype=None): + raise NotImplementedError() + + def shape(self): + return self._field.shape + + def projection(self): + return Projection.from_cf_grid_mapping(**self._field.grid_mapping) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/time.py b/src/anemoi/datasets/create/functions/sources/xarray/time.py new file mode 100644 index 00000000..eb2e2eaa --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/time.py @@ -0,0 +1,98 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+#
+
+
+import datetime
+
+
+class Time:
+    @classmethod
+    def from_coordinates(cls, coordinates):
+        time_coordinate = [c for c in coordinates if c.is_time]
+        step_coordinate = [c for c in coordinates if c.is_step]
+        date_coordinate = [c for c in coordinates if c.is_date]
+
+        if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 1:
+            return ForecastFromValidTimeAndStep(step_coordinate[0])
+
+        if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 0:
+            return Analysis()
+
+        if len(date_coordinate) == 0 and len(time_coordinate) == 0 and len(step_coordinate) == 0:
+            return Constant()
+
+        if len(date_coordinate) == 1 and len(time_coordinate) == 1 and len(step_coordinate) == 0:
+            return ForecastFromValidTimeAndBaseTime(date_coordinate[0])
+
+        if len(date_coordinate) == 1 and len(time_coordinate) == 0 and len(step_coordinate) == 1:
+            return ForecastFromBaseTimeAndDate(date_coordinate[0], step_coordinate[0])
+
+        raise NotImplementedError(f"{date_coordinate=} {time_coordinate=} {step_coordinate=}")
+
+
+class Constant(Time):
+
+    def fill_time_metadata(self, time, metadata):
+        metadata["date"] = time.strftime("%Y%m%d")
+        metadata["time"] = time.strftime("%H%M")
+        metadata["step"] = 0
+
+
+class Analysis(Time):
+
+    def fill_time_metadata(self, time, metadata):
+        metadata["date"] = time.strftime("%Y%m%d")
+        metadata["time"] = time.strftime("%H%M")
+        metadata["step"] = 0
+
+
+class ForecastFromValidTimeAndStep(Time):
+    def __init__(self, step_coordinate):
+        self.step_name = step_coordinate.variable.name
+
+    def fill_time_metadata(self, time, metadata):
+        step = metadata.pop(self.step_name)
+        assert isinstance(step, datetime.timedelta)
+        base = time - step
+
+        hours = step.total_seconds() / 3600
+        assert int(hours) == hours
+
+        metadata["date"] = base.strftime("%Y%m%d")
+        metadata["time"] = base.strftime("%H%M")
+        metadata["step"] = int(hours)
+
+
+class ForecastFromValidTimeAndBaseTime(Time):
+    def __init__(self, date_coordinate):
+        self.date_coordinate = date_coordinate
+
+    def fill_time_metadata(self, time, metadata):
+
+        step = time - self.date_coordinate.single_value
+
+        hours = step.total_seconds() / 3600
+        assert int(hours) == hours
+
+        metadata["date"] = self.date_coordinate.single_value.strftime("%Y%m%d")
+        metadata["time"] = self.date_coordinate.single_value.strftime("%H%M")
+        metadata["step"] = int(hours)
+
+
+class ForecastFromBaseTimeAndDate(Time):
+    def __init__(self, date_coordinate, step_coordinate):
+        self.date_coordinate = date_coordinate
+        self.step_coordinate = step_coordinate
+
+    def fill_time_metadata(self, time, metadata):
+        metadata["date"] = self.date_coordinate.single_value.strftime("%Y%m%d")
+        metadata["time"] = self.date_coordinate.single_value.strftime("%H%M")
+        hours = self.step_coordinate.total_seconds() / 3600
+        assert int(hours) == hours
+        metadata["step"] = int(hours)
diff --git a/src/anemoi/datasets/create/functions/sources/xarray/variable.py b/src/anemoi/datasets/create/functions/sources/xarray/variable.py
new file mode 100644
index 00000000..e1f0225a
--- /dev/null
+++ b/src/anemoi/datasets/create/functions/sources/xarray/variable.py
@@ -0,0 +1,198 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import math +from functools import cached_property + +import numpy as np +from earthkit.data.utils.array import ensure_backend + +from anemoi.datasets.create.functions.sources.xarray.metadata import MDMapping + +from .field import XArrayField + +LOG = logging.getLogger(__name__) + + +class Variable: + def __init__(self, *, ds, var, coordinates, grid, time, metadata, mapping=None, array_backend=None): + self.ds = ds + self.var = var + + self.grid = grid + self.coordinates = coordinates + + # print("Variable", var.name) + # for c in coordinates: + # print(" ", c) + + self._metadata = metadata.copy() + # self._metadata.update(var.attrs) + self._metadata.update({"variable": var.name}) + + # self._metadata.setdefault("level", None) + # self._metadata.setdefault("number", 0) + # self._metadata.setdefault("levtype", "sfc") + self._mapping = mapping + + self.time = time + + self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid) + self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} + self.by_name = {c.variable.name: c for c in coordinates} + + self.length = math.prod(self.shape) + self.array_backend = ensure_backend(array_backend) + + def update_metadata_mapping(self, kwargs): + + result = {} + + for k, v in kwargs.items(): + if k == "param": + result[k] = "variable" + continue + + for c in self.coordinates: + if k in c.mars_names: + for v in c.mars_names: + result[v] = c.variable.name + break + + self._mapping = MDMapping(result) + + @property + def name(self): + return self.var.name + + def __len__(self): + return self.length + + @property + def grid_mapping(self): + grid_mapping = self.var.attrs.get("grid_mapping", None) + if grid_mapping is None: + return None + return self.ds[grid_mapping].attrs + + def grid_points(self): + return self.grid.grid_points() + + @property + def latitudes(self): + return self.grid.latitudes + + @property + def longitudes(self): + return self.grid.longitudes + + def __repr__(self): + return "Variable[name=%s,coordinates=%s,metadata=%s]" % ( + self.var.name, + self.coordinates, + self._metadata, + ) + + def __getitem__(self, i): + """ + Get a 2D field from the variable + """ + if i >= self.length: + raise IndexError(i) + + coords = np.unravel_index(i, self.shape) + kwargs = {k: v for k, v in zip(self.names, coords)} + return XArrayField(self, self.var.isel(kwargs)) + + @property + def mapping(self): + return self._mapping + + def sel(self, missing, **kwargs): + + if not kwargs: + return self + + kwargs = self._mapping.from_user(kwargs) + + k, v = kwargs.popitem() + + c = self.by_name.get(k) + + if c is None: + missing[k] = v + return self.sel(missing, **kwargs) + + i = c.index(v) + if i is None: + LOG.warning(f"Could not find {k}={v} in {c}") + return None + + coordinates = [x.reduced(i) if c is x else x for x in self.coordinates] + + metadata = self._metadata.copy() + metadata.update({k: v}) + + variable = Variable( + ds=self.ds, + var=self.var.isel({k: i}), + coordinates=coordinates, + grid=self.grid, + time=self.time, + metadata=metadata, + mapping=self.mapping, + ) + + return variable.sel(missing, **kwargs) + + def match(self, **kwargs): + kwargs = self._mapping.from_user(kwargs) + + if "variable" in kwargs: + name = kwargs.pop("variable") + if 
not isinstance(name, (list, tuple)): + name = [name] + if self.var.name not in name: + return False, None + return True, kwargs + return True, kwargs + + +class FilteredVariable: + def __init__(self, variable, **kwargs): + self.variable = variable + self.kwargs = kwargs + + @cached_property + def fields(self): + """Filter the fields of a variable based on metadata. + + Returns + ------- + list + A list of fields that match the metadata. + """ + return [ + field + for field in self.variable + if all(field.metadata(k, default=None) == v for k, v in self.kwargs.items()) + ] + + @property + def length(self): + return len(self.fields) + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i >= self.length: + raise IndexError(i) + return self.fields[i] diff --git a/src/anemoi/datasets/create/functions/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/functions/sources/xarray_kerchunk.py new file mode 100644 index 00000000..a13da6a5 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray_kerchunk.py @@ -0,0 +1,42 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from earthkit.data.core.fieldlist import MultiFieldList + +from . import iterate_patterns +from .xarray import load_one + + +def load_many(emoji, context, dates, pattern, options, **kwargs): + + result = [] + options = options.copy() if options is not None else {} + + options.setdefault("engine", "zarr") + options.setdefault("backend_kwargs", {}) + + backend_kwargs = options["backend_kwargs"] + backend_kwargs.setdefault("consolidated", False) + backend_kwargs.setdefault("storage_options", {}) + + storage_options = backend_kwargs["storage_options"] + storage_options.setdefault("remote_protocol", "s3") + storage_options.setdefault("remote_options", {"anon": True}) + + for path, dates in iterate_patterns(pattern, dates, **kwargs): + storage_options["fo"] = path + + result.append(load_one(emoji, context, dates, "reference://", options=options, **kwargs)) + + return MultiFieldList(result) + + +def execute(context, dates, json, options=None, **kwargs): + return load_many("🧱", context, dates, json, options, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/xarray_zarr.py b/src/anemoi/datasets/create/functions/sources/xarray_zarr.py new file mode 100644 index 00000000..49a9add3 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray_zarr.py @@ -0,0 +1,15 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + + +from .xarray import load_many + + +def execute(context, dates, url, *args, **kwargs): + return load_many("πŸ‡Ώ", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/zenodo.py b/src/anemoi/datasets/create/functions/sources/zenodo.py new file mode 100644 index 00000000..edb83b0f --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/zenodo.py @@ -0,0 +1,40 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +from earthkit.data.core.fieldlist import MultiFieldList +from earthkit.data.sources.url import download_and_cache + +from . import iterate_patterns +from .xarray import load_one + + +def execute(context, dates, record_id, file_key, *args, **kwargs): + import requests + + result = [] + + URLPATTERN = "https://zenodo.org/api/records/{record_id}" + url = URLPATTERN.format(record_id=record_id) + r = requests.get(url) + r.raise_for_status() + record = r.json() + + urls = {} + for file in record["files"]: + urls[file["key"]] = file["links"]["self"] + + for url, dates in iterate_patterns(file_key, dates, **kwargs): + if url not in urls: + continue + + path = download_and_cache(urls[url]) + result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) + + return MultiFieldList(result) diff --git a/src/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py index 0b639cca..7527e731 100644 --- a/src/anemoi/datasets/create/input.py +++ b/src/anemoi/datasets/create/input.py @@ -7,7 +7,9 @@ # nor does it submit to any jurisdiction. 
# import datetime +import itertools import logging +import math import time from collections import defaultdict from copy import deepcopy @@ -15,7 +17,10 @@ from functools import wraps import numpy as np +from anemoi.utils.humanize import seconds_to_human +from anemoi.utils.humanize import shorten_list from earthkit.data.core.fieldlist import FieldList +from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.core.order import build_remapping from anemoi.datasets.dates import Dates @@ -25,29 +30,33 @@ from .template import notify_result from .template import resolve from .template import substitute -from .template import trace -from .template import trace_datasource -from .template import trace_select -from .utils import seconds +from .trace import trace +from .trace import trace_datasource +from .trace import trace_select LOG = logging.getLogger(__name__) def parse_function_name(name): - if "-" in name: - name, delta = name.split("-") - sign = -1 - elif "+" in name: - name, delta = name.split("+") - sign = 1 + if name.endswith("h") and name[:-1].isdigit(): - else: - return name, None + if "-" in name: + name, delta = name.split("-") + sign = -1 - assert delta[-1] == "h", (name, delta) - delta = sign * int(delta[:-1]) - return name, delta + elif "+" in name: + name, delta = name.split("+") + sign = 1 + + else: + return name, None + + assert delta[-1] == "h", (name, delta) + delta = sign * int(delta[:-1]) + return name, delta + + return name, None def time_delta_to_string(delta): @@ -134,141 +143,6 @@ def sort(old_dic): return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Coords: - def __init__(self, owner): - self.owner = owner - - @cached_property - def _build_coords(self): - from_data = self.owner.get_cube().user_coords - from_config = self.owner.context.order_by - - keys_from_config = list(from_config.keys()) - keys_from_data = list(from_data.keys()) - assert ( - keys_from_data == keys_from_config - ), f"Critical error: {keys_from_data=} != {keys_from_config=}. 
{self.owner=}" - - variables_key = list(from_config.keys())[1] - ensembles_key = list(from_config.keys())[2] - - if isinstance(from_config[variables_key], (list, tuple)): - assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), ( - from_data[variables_key], - from_config[variables_key], - ) - - self._variables = from_data[variables_key] # "param_level" - self._ensembles = from_data[ensembles_key] # "number" - - first_field = self.owner.datasource[0] - grid_points = first_field.grid_points() - - lats, lons = grid_points - north = np.amax(lats) - south = np.amin(lats) - east = np.amax(lons) - west = np.amin(lons) - - assert -90 <= south <= north <= 90, (south, north, first_field) - assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), ( - west, - east, - first_field, - ) - - grid_values = list(range(len(grid_points[0]))) - - self._grid_points = grid_points - self._resolution = first_field.resolution - self._grid_values = grid_values - self._field_shape = first_field.shape - self._proj_string = first_field.proj_string if hasattr(first_field, "proj_string") else None - - @cached_property - def variables(self): - self._build_coords - return self._variables - - @cached_property - def ensembles(self): - self._build_coords - return self._ensembles - - @cached_property - def resolution(self): - self._build_coords - return self._resolution - - @cached_property - def grid_values(self): - self._build_coords - return self._grid_values - - @cached_property - def grid_points(self): - self._build_coords - return self._grid_points - - @cached_property - def field_shape(self): - self._build_coords - return self._field_shape - - @cached_property - def proj_string(self): - self._build_coords - return self._proj_string - - -class HasCoordsMixin: - @cached_property - def variables(self): - return self._coords.variables - - @cached_property - def ensembles(self): - return self._coords.ensembles - - @cached_property - def resolution(self): - return self._coords.resolution - - @cached_property - def grid_values(self): - return self._coords.grid_values - - @cached_property - def grid_points(self): - return self._coords.grid_points - - @cached_property - def field_shape(self): - return self._coords.field_shape - - @cached_property - def proj_string(self): - return self._coords.proj_string - - @cached_property - def shape(self): - return [ - len(self.dates), - len(self.variables), - len(self.ensembles), - len(self.grid_values), - ] - - @cached_property - def coords(self): - return { - "dates": self.dates, - "variables": self.variables, - "ensembles": self.ensembles, - "values": self.grid_values, - } - - class Action: def __init__(self, context, action_path, /, *args, **kwargs): if "args" in kwargs and "kwargs" in kwargs: @@ -323,15 +197,15 @@ def shorten(dates): return dates -class Result(HasCoordsMixin): +class Result: empty = False + _coords_already_built = False def __init__(self, context, action_path, dates): assert isinstance(context, ActionContext), type(context) assert isinstance(action_path, list), action_path self.context = context - self._coords = Coords(self) self.dates = dates self.action_path = action_path @@ -353,19 +227,142 @@ def get_cube(self): order_by = self.context.order_by flatten_grid = self.context.flatten_grid start = time.time() - LOG.info("Sorting dataset %s %s", order_by, remapping) + LOG.debug("Sorting dataset %s %s", dict(order_by), remapping) assert order_by, order_by - cube = ds.cube( - order_by, - remapping=remapping, - 
flatten_values=flatten_grid, - patches={"number": {None: 0}}, - ) - cube = cube.squeeze() - LOG.info(f"Sorting done in {seconds(time.time()-start)}.") + + patches = {"number": {None: 0}} + + try: + cube = ds.cube( + order_by, + remapping=remapping, + flatten_values=flatten_grid, + patches=patches, + ) + cube = cube.squeeze() + LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.") + except ValueError: + self.explain(ds, order_by, remapping=remapping, patches=patches) + # raise ValueError(f"Error in {self}") + exit(1) + + if LOG.isEnabledFor(logging.DEBUG): + LOG.debug("Cube shape: %s", cube) + for k, v in cube.user_coords.items(): + LOG.debug(" %s %s", k, shorten_list(v, max_length=10)) return cube + def explain(self, ds, *args, remapping, patches): + + METADATA = ( + "date", + "time", + "step", + "hdate", + "valid_datetime", + "levtype", + "levelist", + "number", + "level", + "shortName", + "paramId", + "variable", + ) + + # We redo the logic here + print() + print("❌" * 40) + print() + if len(args) == 1 and isinstance(args[0], (list, tuple)): + args = args[0] + + names = [] + for a in args: + if isinstance(a, str): + names.append(a) + elif isinstance(a, dict): + names += list(a.keys()) + + print(f"Building a {len(names)}D hypercube using", names) + + ds = ds.order_by(*args, remapping=remapping, patches=patches) + user_coords = ds.unique_values(*names, remapping=remapping, patches=patches) + + print() + print("Number of unique values found for each coordinate:") + for k, v in user_coords.items(): + print(f" {k:20}:", len(v)) + print() + user_shape = tuple(len(v) for k, v in user_coords.items()) + print("Shape of the hypercube :", user_shape) + print( + "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape]) + ) + print("Number of fields in the dataset :", len(ds)) + print("Difference :", abs(len(ds) - math.prod(user_shape))) + print() + + remapping = build_remapping(remapping, patches) + expected = set(itertools.product(*user_coords.values())) + + if math.prod(user_shape) > len(ds): + print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.") + + for f in ds: + metadata = remapping(f.metadata) + expected.remove(tuple(metadata(n) for n in names)) + + print("Missing fields:") + print() + for i, f in enumerate(sorted(expected)): + print(" ", f) + if i >= 9 and len(expected) > 10: + print("...", len(expected) - i - 1, "more") + break + + print() + print("To solve this issue, you can:") + print( + " - Provide a better selection, like 'step: 0' or 'level: 1000' to " + "reduce the number of selected fields." + ) + print( + " - Split the 'input' part in smaller sections using 'join', " + "making sure that each section represent a full hypercube." + ) + + else: + print(f"More fields in dataset that expected for {names}. 
" "This means that some fields are duplicated.") + duplicated = defaultdict(list) + for f in ds: + # print(f.metadata(namespace="default")) + metadata = remapping(f.metadata) + key = tuple(metadata(n, default=None) for n in names) + duplicated[key].append(f) + + print("Duplicated fields:") + print() + duplicated = {k: v for k, v in duplicated.items() if len(v) > 1} + for i, (k, v) in enumerate(sorted(duplicated.items())): + print(" ", k) + for f in v: + x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None} + print(" ", f, x) + if i >= 9 and len(duplicated) > 10: + print("...", len(duplicated) - i - 1, "more") + break + + print() + print("To solve this issue, you can:") + print(" - Provide a better selection, like 'step: 0' or 'level: 1000'") + print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.") + + print() + print("❌" * 40) + print() + exit(1) + def __repr__(self, *args, _indent_="\n", **kwargs): more = ",".join([str(a)[:5000] for a in args]) more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) @@ -391,6 +388,109 @@ def _raise_not_implemented(self): def _trace_datasource(self, *args, **kwargs): return f"{self.__class__.__name__}({shorten(self.dates)})" + def build_coords(self): + if self._coords_already_built: + return + from_data = self.get_cube().user_coords + from_config = self.context.order_by + + keys_from_config = list(from_config.keys()) + keys_from_data = list(from_data.keys()) + assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}" + + variables_key = list(from_config.keys())[1] + ensembles_key = list(from_config.keys())[2] + + if isinstance(from_config[variables_key], (list, tuple)): + assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), ( + from_data[variables_key], + from_config[variables_key], + ) + + self._variables = from_data[variables_key] # "param_level" + self._ensembles = from_data[ensembles_key] # "number" + + first_field = self.datasource[0] + grid_points = first_field.grid_points() + + lats, lons = grid_points + + assert len(lats) == len(lons), (len(lats), len(lons), first_field) + assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field) + + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + assert -90 <= south <= north <= 90, (south, north, first_field) + assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), ( + west, + east, + first_field, + ) + + grid_values = list(range(len(grid_points[0]))) + + self._grid_points = grid_points + self._resolution = first_field.resolution + self._grid_values = grid_values + self._field_shape = first_field.shape + self._proj_string = first_field.proj_string if hasattr(first_field, "proj_string") else None + + @property + def variables(self): + self.build_coords() + return self._variables + + @property + def ensembles(self): + self.build_coords() + return self._ensembles + + @property + def resolution(self): + self.build_coords() + return self._resolution + + @property + def grid_values(self): + self.build_coords() + return self._grid_values + + @property + def grid_points(self): + self.build_coords() + return self._grid_points + + @property + def field_shape(self): + self.build_coords() + return self._field_shape + + @property + def proj_string(self): + self.build_coords() + return self._proj_string + + @cached_property + def shape(self): + 
return [ + len(self.dates), + len(self.variables), + len(self.ensembles), + len(self.grid_values), + ] + + @cached_property + def coords(self): + return { + "dates": self.dates, + "variables": self.variables, + "ensembles": self.ensembles, + "values": self.grid_values, + } + class EmptyResult(Result): empty = True @@ -411,6 +511,22 @@ def variables(self): return [] +def _flatten(ds): + if isinstance(ds, MultiFieldList): + return [_tidy(f) for s in ds._indexes for f in _flatten(s)] + return [ds] + + +def _tidy(ds, indent=0): + if isinstance(ds, MultiFieldList): + + sources = [s for s in _flatten(ds) if len(s) > 0] + if len(sources) == 1: + return sources[0] + return MultiFieldList(sources) + return ds + + class FunctionResult(Result): def __init__(self, context, action_path, dates, action): super().__init__(context, action_path, dates) @@ -430,7 +546,7 @@ def datasource(self): args, kwargs = resolve(self.context, (self.args, self.kwargs)) try: - return self.action.function(FunctionContext(self), self.dates, *args, **kwargs) + return _tidy(self.action.function(FunctionContext(self), self.dates, *args, **kwargs)) except Exception: LOG.error(f"Error in {self.action.function.__name__}", exc_info=True) raise @@ -459,7 +575,7 @@ def datasource(self): ds = EmptyResult(self.context, self.action_path, self.dates).datasource for i in self.results: ds += i.datasource - return ds + return _tidy(ds) def __repr__(self): content = "\n".join([str(i) for i in self.results]) @@ -533,7 +649,7 @@ def __getattr__(self, name): ds = self.result.datasource ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds]) - return ds + return _tidy(ds) class FunctionAction(Action): @@ -620,11 +736,13 @@ class StepFunctionResult(StepResult): @trace_datasource def datasource(self): try: - return self.action.function( - FunctionContext(self), - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, + return _tidy( + self.action.function( + FunctionContext(self), + self.upstream_result.datasource, + *self.action.args[1:], + **self.action.kwargs, + ) ) except Exception: @@ -643,7 +761,7 @@ class FilterStepResult(StepResult): def datasource(self): ds = self.upstream_result.datasource ds = ds.sel(**self.action.kwargs) - return ds + return _tidy(ds) class FilterStepAction(StepAction): @@ -672,7 +790,7 @@ def datasource(self): ds = EmptyResult(self.context, self.action_path, self.dates).datasource for i in self.results: ds += i.datasource - return ds + return _tidy(ds) @property def variables(self): @@ -708,7 +826,7 @@ def datasource(self): self.context.notify_result(i.action_path[:-1], i.datasource) # then return the input result # which can use the datasources of the included results - return self.input_result.datasource + return _tidy(self.input_result.datasource) class DataSourcesAction(Action): diff --git a/src/anemoi/datasets/create/loaders.py b/src/anemoi/datasets/create/loaders.py index 62499820..b90e07f6 100644 --- a/src/anemoi/datasets/create/loaders.py +++ b/src/anemoi/datasets/create/loaders.py @@ -5,6 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
import datetime +import json import logging import os import time @@ -13,7 +14,10 @@ from functools import cached_property import numpy as np +import tqdm import zarr +from anemoi.utils.config import DotDict +from anemoi.utils.humanize import seconds_to_human from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset @@ -25,7 +29,6 @@ from .check import DatasetName from .check import check_data_values from .chunks import ChunkFilter -from .config import DictObj from .config import build_output from .config import loader_config from .input import build_input @@ -35,8 +38,6 @@ from .statistics import compute_statistics from .statistics import default_statistics_dates from .utils import normalize_and_check_dates -from .utils import progress_bar -from .utils import seconds from .writer import ViewCacheArray from .zarr import ZarrBuiltRegistry from .zarr import add_zarr_dataset @@ -65,7 +66,7 @@ def set_element_to_test(obj): for v in obj: set_element_to_test(v) return - if isinstance(obj, (dict, DictObj)): + if isinstance(obj, (dict, DotDict)): if "grid" in obj: previous = obj["grid"] obj["grid"] = "20./20." @@ -77,12 +78,16 @@ def set_element_to_test(obj): LOG.warn(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") for k, v in obj.items(): set_element_to_test(v) + if "constants" in obj: + constants = obj["constants"] + if "param" in constants and isinstance(constants["param"], list): + constants["param"] = ["cos_latitude"] set_element_to_test(cfg) class GenericDatasetHandler: - def __init__(self, *, path, print=print, **kwargs): + def __init__(self, *, path, use_threads=False, **kwargs): # Catch all floating point errors, including overflow, sqrt(<0), etc np.seterr(all="raise", under="warn") @@ -91,33 +96,33 @@ def __init__(self, *, path, print=print, **kwargs): self.path = path self.kwargs = kwargs - self.print = print + self.use_threads = use_threads if "test" in kwargs: self.test = kwargs["test"] @classmethod - def from_config(cls, *, config, path, print=print, **kwargs): + def from_config(cls, *, config, path, use_threads=False, **kwargs): """Config is the path to the config file or a dict with the config""" assert isinstance(config, dict) or isinstance(config, str), config - return cls(config=config, path=path, print=print, **kwargs) + return cls(config=config, path=path, use_threads=use_threads, **kwargs) @classmethod - def from_dataset_config(cls, *, path, print=print, **kwargs): + def from_dataset_config(cls, *, path, use_threads=False, **kwargs): """Read the config saved inside the zarr dataset and instantiate the class for this config.""" assert os.path.exists(path), f"Path {path} does not exist." z = zarr.open(path, mode="r") config = z.attrs["_create_yaml_config"] - LOG.info(f"Config loaded from zarr config: {config}") - return cls.from_config(config=config, path=path, print=print, **kwargs) + LOG.debug("Config loaded from zarr config:\n%s", json.dumps(config, indent=4, sort_keys=True, default=str)) + return cls.from_config(config=config, path=path, use_threads=use_threads, **kwargs) @classmethod - def from_dataset(cls, *, path, **kwargs): + def from_dataset(cls, *, path, use_threads=False, **kwargs): """Instanciate the class from the path to the zarr dataset, without config.""" assert os.path.exists(path), f"Path {path} does not exist." 
- return cls(path=path, **kwargs) + return cls(path=path, use_threads=use_threads, **kwargs) def read_dataset_metadata(self): ds = open_dataset(self.path) @@ -131,14 +136,22 @@ def read_dataset_metadata(self): z = zarr.open(self.path, "r") missing_dates = z.attrs.get("missing_dates", []) missing_dates = sorted([np.datetime64(d) for d in missing_dates]) - assert missing_dates == self.missing_dates, (missing_dates, self.missing_dates) + + if missing_dates != self.missing_dates: + LOG.warn("Missing dates given in recipe do not match the actual missing dates in the dataset.") + LOG.warn(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") + LOG.warn(f"Missing dates in dataset: {sorted(str(x) for x in self.missing_dates)}") + raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") @cached_property def registry(self): - return ZarrBuiltRegistry(self.path) + return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) + + def ready(self): + return all(self.registry.get_flags()) def update_metadata(self, **kwargs): - LOG.info(f"Updating metadata {kwargs}") + LOG.debug(f"Updating metadata {kwargs}") z = zarr.open(self.path, mode="w+") for k, v in kwargs.items(): if isinstance(v, np.datetime64): @@ -170,7 +183,7 @@ class DatasetHandler(GenericDatasetHandler): class DatasetHandlerWithStatistics(GenericDatasetHandler): def __init__(self, statistics_tmp=None, **kwargs): super().__init__(**kwargs) - statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".tmp_data", "statistics") + statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".storage_for_statistics.tmp") self.tmp_statistics = TmpStatistics(statistics_tmp) @@ -186,12 +199,16 @@ def build_input(self): remapping=build_remapping(self.output.remapping), use_grib_paramid=self.main_config.build.use_grib_paramid, ) - LOG.info("βœ… INPUT_BUILDER") - LOG.info(builder) + LOG.debug("βœ… INPUT_BUILDER") + LOG.debug(builder) return builder - def allow_nan(self, name): - return name in self.main_config.statistics.get("allow_nans", []) + @property + def allow_nans(self): + if "allow_nans" in self.main_config.build: + return self.main_config.build.allow_nans + + return self.main_config.statistics.get("allow_nans", []) class InitialiserLoader(Loader): @@ -202,7 +219,7 @@ def __init__(self, config, **kwargs): if self.test: set_to_test_mode(self.main_config) - LOG.info(self.main_config.dates) + LOG.info(dict(self.main_config.dates)) self.tmp_statistics.delete() @@ -255,26 +272,25 @@ def initialise(self, check_name=True): Read a small part of the data to get the shape of the data and the resolution and more metadata. 
""" - self.print("Config loaded ok:") - LOG.info(self.main_config) + LOG.info("Config loaded ok:") + # LOG.info(self.main_config) dates = self.groups.dates frequency = dates.frequency assert isinstance(frequency, int), frequency - self.print(f"Found {len(dates)} datetimes.") + LOG.info(f"Found {len(dates)} datetimes.") LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") LOG.info(f"Missing dates: {len(dates.missing)}") - lengths = [len(g) for g in self.groups] - self.print(f"Found {len(dates)} datetimes {'+'.join([str(_) for _ in lengths])}.") + lengths = tuple(len(g) for g in self.groups) variables = self.minimal_input.variables - self.print(f"Found {len(variables)} variables : {','.join(variables)}.") + LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") variables_with_nans = self.main_config.statistics.get("allow_nans", []) ensembles = self.minimal_input.ensembles - self.print(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") + LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") grid_points = self.minimal_input.grid_points LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") @@ -286,13 +302,13 @@ def initialise(self, check_name=True): coords["dates"] = dates total_shape = self.minimal_input.shape total_shape[0] = len(dates) - self.print(f"total_shape = {total_shape}") + LOG.info(f"total_shape = {total_shape}") chunks = self.output.get_chunking(coords) LOG.info(f"{chunks=}") dtype = self.output.dtype - self.print(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") + LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") metadata = {} metadata["uuid"] = str(uuid.uuid4()) @@ -312,6 +328,7 @@ def initialise(self, check_name=True): metadata["ensemble_dimension"] = len(ensembles) metadata["variables"] = variables metadata["variables_with_nans"] = variables_with_nans + metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) metadata["resolution"] = resolution metadata["data_request"] = self.minimal_input.data_request @@ -328,7 +345,7 @@ def initialise(self, check_name=True): if check_name: basename, ext = os.path.splitext(os.path.basename(self.path)) # noqa: F841 ds_name = DatasetName(basename, resolution, dates[0], dates[-1], frequency) - ds_name.raise_if_not_valid(print=self.print) + ds_name.raise_if_not_valid() if len(dates) != total_shape[0]: raise ValueError( @@ -348,10 +365,16 @@ def initialise(self, check_name=True): self.update_metadata(**metadata) - self._add_dataset(name="data", chunks=chunks, dtype=dtype, shape=total_shape) - self._add_dataset(name="dates", array=dates) - self._add_dataset(name="latitudes", array=grid_points[0]) - self._add_dataset(name="longitudes", array=grid_points[1]) + self._add_dataset( + name="data", + chunks=chunks, + dtype=dtype, + shape=total_shape, + dimensions=("time", "variable", "ensemble", "cell"), + ) + self._add_dataset(name="dates", array=dates, dimensions=("time",)) + self._add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) + self._add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) self.registry.create(lengths=lengths) self.tmp_statistics.create(exist_ok=False) @@ -368,6 +391,9 @@ def initialise(self, check_name=True): assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks()) + # Return the number of groups to process, so we can show a nice progress bar + return len(lengths) + class 
ContentLoader(Loader): def __init__(self, config, parts, **kwargs): @@ -387,35 +413,29 @@ def __init__(self, config, parts, **kwargs): self.n_groups = len(self.groups) def load(self): - self.registry.add_to_history("loading_data_start", parts=self.parts) - for igroup, group in enumerate(self.groups): if not self.chunk_filter(igroup): continue if self.registry.get_flag(igroup): LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") continue - # self.print(f" -> Processing {igroup} total={len(self.groups)}") - # print("========", group) + assert isinstance(group[0], datetime.datetime), group result = self.input.select(dates=group) assert result.dates == group, (len(result.dates), len(group)) - msg = f"Building data for group {igroup}/{self.n_groups}" - LOG.info(msg) - self.print(msg) + LOG.debug(f"Building data for group {igroup}/{self.n_groups}") # There are several groups. # There is one result to load for each group. self.load_result(result) self.registry.set_flag(igroup) - self.registry.add_to_history("loading_data_end", parts=self.parts) self.registry.add_provenance(name="provenance_load") self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - self.print_info() + # self.print_info() def load_result(self, result): # There is one cube to load for each result. @@ -430,7 +450,7 @@ def load_result(self, result): shape = cube.extended_user_shape dates_in_data = cube.user_coords["valid_datetime"] - LOG.info(f"Loading {shape=} in {self.data_array.shape=}") + LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") def check_dates_in_data(lst, lst2): lst2 = [np.datetime64(_) for _ in lst2] @@ -450,7 +470,7 @@ def dates_to_indexes(dates, all_dates): array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) self.load_cube(cube, array) - stats = compute_statistics(array.cache, self.variables_names, allow_nan=self.allow_nan) + stats = compute_statistics(array.cache, self.variables_names, allow_nans=self.allow_nans) self.tmp_statistics.write(indexes, stats, dates=dates_in_data) array.flush() @@ -463,11 +483,19 @@ def load_cube(self, cube, array): reading_chunks = None total = cube.count(reading_chunks) - self.print(f"Loading datacube: {cube}") - bar = progress_bar( + LOG.debug(f"Loading datacube: {cube}") + + def position(x): + if isinstance(x, str) and "/" in x: + x = x.split("/") + return int(x[0]) + return None + + bar = tqdm.tqdm( iterable=cube.iterate_cubelets(reading_chunks), total=total, desc=f"Loading datacube {cube}", + position=position(self.parts), ) for i, cubelet in enumerate(bar): bar.set_description(f"Loading {i}/{total}") @@ -482,7 +510,7 @@ def load_cube(self, cube, array): data[:], name=name, log=[i, data.shape, local_indexes], - allow_nan=self.allow_nan, + allow_nans=self.allow_nans, ) now = time.time() @@ -491,10 +519,11 @@ def load_cube(self, cube, array): now = time.time() save += time.time() - now - LOG.info("Written.") - msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}." - self.print(msg) - LOG.info(msg) + LOG.debug( + f"Elapsed: {seconds_to_human(time.time() - start)}, " + f"load time: {seconds_to_human(load)}, " + f"write time: {seconds_to_human(save)}." 
+ ) class StatisticsAdder(DatasetHandlerWithStatistics): @@ -518,12 +547,16 @@ def __init__( self.read_dataset_metadata() - def allow_nan(self, name): + @cached_property + def allow_nans(self): z = zarr.open(self.path, mode="r") + if "allow_nans" in z.attrs: + return z.attrs["allow_nans"] + if "variables_with_nans" in z.attrs: - return name in z.attrs["variables_with_nans"] + return z.attrs["variables_with_nans"] - warnings.warn(f"Cannot find 'variables_with_nans' in {self.path}. Assuming nans allowed for {name}.") + warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") return True def _get_statistics_dates(self): @@ -562,7 +595,7 @@ def assert_dtype(d): def run(self): dates = self._get_statistics_dates() - stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nan) + stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nans) self.output_writer(stats) def write_stats_to_file(self, stats): @@ -591,7 +624,7 @@ def write_stats_to_dataset(self, stats): "count", "has_nans", ]: - self._add_dataset(name=k, array=stats[k]) + self._add_dataset(name=k, array=stats[k], dimensions=("variable",)) self.registry.add_to_history("compute_statistics_end") LOG.info(f"Wrote statistics in {self.path}") @@ -625,6 +658,7 @@ def run(self, parts): raise NotImplementedError() def finalise(self): + shape = (len(self.dates), len(self.variables)) agg = dict( minimum=np.full(shape, np.nan, dtype=np.float64), @@ -634,7 +668,7 @@ def finalise(self): count=np.full(shape, -1, dtype=np.int64), has_nans=np.full(shape, False, dtype=np.bool_), ) - LOG.info(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") + LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") found = set() ifound = set() @@ -730,9 +764,9 @@ def _write(self, summary): "has_nans", ]: name = self.final_storage_name(k) - self._add_dataset(name=name, array=summary[k]) + self._add_dataset(name=name, array=summary[k], dimensions=("variable",)) self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.info(f"Wrote additions in {self.path} ({self.final_storage_name('*')})") + LOG.debug(f"Wrote additions in {self.path} ({self.final_storage_name('*')})") def check_statistics(self): pass @@ -744,10 +778,19 @@ def _variables_with_nans(self): return z.attrs["variables_with_nans"] return None - def allow_nan(self, name): + @cached_property + def _allow_nans(self): + z = zarr.open(self.path, mode="r") + return z.attrs.get("allow_nans", False) + + def allow_nans(self): + + if self._allow_nans: + return True + if self._variables_with_nans is not None: - return name in self._variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, Assuming nans allowed for {name}.") + return self._variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") return True @@ -768,7 +811,7 @@ def __init__(self, **kwargs): @property def tmp_storage_path(self): - return f"{self.path}.tmp_storage_statistics" + return f"{self.path}.storage_statistics.tmp" def final_storage_name(self, k): return k @@ -781,12 +824,12 @@ def run(self, parts): date = self.dates[i] try: arr = self.ds[i : i + 1, ...] 
- stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan) + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) self.tmp_storage.add([date, i, stats], key=date) except MissingDateError: self.tmp_storage.add([date, i, "missing"], key=date) self.tmp_storage.flush() - LOG.info(f"Dataset {self.path} additions run.") + LOG.debug(f"Dataset {self.path} additions run.") def check_statistics(self): ds = open_dataset(self.path) @@ -846,7 +889,7 @@ def __init__(self, path, delta=None, **kwargs): @property def tmp_storage_path(self): - return f"{self.path}.tmp_storage_statistics_{self.delta}h" + return f"{self.path}.storage_statistics_{self.delta}h.tmp" def final_storage_name(self, k): return self.final_storage_name_from_delta(k, delta=self.delta) @@ -867,9 +910,15 @@ def run(self, parts): date = self.dates[i] try: arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan) + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) self.tmp_storage.add([date, i, stats], key=date) except MissingDateError: self.tmp_storage.add([date, i, "missing"], key=date) self.tmp_storage.flush() - LOG.info(f"Dataset {self.path} additions run.") + LOG.debug(f"Dataset {self.path} additions run.") + + +class DatasetVerifier(GenericDatasetHandler): + + def verify(self): + pass diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py index 8240d00e..20014c0d 100755 --- a/src/anemoi/datasets/create/patch.py +++ b/src/anemoi/datasets/create/patch.py @@ -1,9 +1,12 @@ #!/usr/bin/env python3 import json +import logging import os import zarr +LOG = logging.getLogger(__name__) + def fix_order_by(order_by): if isinstance(order_by, list): @@ -48,7 +51,7 @@ def fix_provenance(provenance): provenance["module_versions"][k] = os.path.join("...", os.path.basename(v)) for k, v in list(provenance["git_versions"].items()): - print(k, v) + LOG.debug(f"{k} {v}") modified_files = v["git"].get("modified_files", []) untracked_files = v["git"].get("untracked_files", []) if not isinstance(modified_files, int): @@ -63,21 +66,21 @@ def fix_provenance(provenance): } ) - print(json.dumps(provenance, indent=2)) + LOG.debug(json.dumps(provenance, indent=2)) # assert False return provenance def apply_patch(path, verbose=True, dry_run=False): - print("====================") - print(f"Patching {path}") - print("====================") + LOG.debug("====================") + LOG.debug(f"Patching {path}") + LOG.debug("====================") try: attrs = zarr.open(path, mode="r").attrs.asdict() except zarr.errors.PathNotFoundError as e: - print(f"Failed to open {path}") - print(e) + LOG.error(f"Failed to open {path}") + LOG.error(e) exit(0) FIXES = { @@ -94,23 +97,23 @@ def apply_patch(path, verbose=True, dry_run=False): for k, v in attrs.items(): v = attrs[k] if k in REMOVE: - print(f"✅ Remove {k}") + LOG.info(f"✅ Remove {k}") continue if k not in FIXES: assert not k.startswith("provenance"), f"[{k}]" - print(f"✅ Don't fix {k}") + LOG.debug(f"✅ Don't fix {k}") fixed_attrs[k] = v continue new_v = FIXES[k](v) if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True): - print(f"✅ Fix {k}") + LOG.info(f"✅ Fix {k}") if verbose: - print(f" Before : {k}= {v}") - print(f" After : {k}= {new_v}") + LOG.info(f" Before : {k}= {v}") + LOG.info(f" After : {k}= {new_v}") else: - print(f"✅ Unchanged {k}") + LOG.debug(f"✅ Unchanged {k}") fixed_attrs[k] = new_v if dry_run: @@ -125,6 +128,6 @@ def apply_patch(path, verbose=True,
dry_run=False): after = json.dumps(z.attrs.asdict(), sort_keys=True) if before != after: - print("CHANGED") + LOG.info("Dataset changed by patch") assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True) diff --git a/src/anemoi/datasets/create/persistent.py b/src/anemoi/datasets/create/persistent.py index 51963f7f..207553e7 100644 --- a/src/anemoi/datasets/create/persistent.py +++ b/src/anemoi/datasets/create/persistent.py @@ -49,7 +49,7 @@ def __str__(self): def items(self): # use glob to read all pickles files = glob.glob(self.dirname + "/*.pickle") - LOG.info(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") + LOG.debug(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") assert len(files) > 0, f"No files found in {self.dirname}" for f in files: with open(f, "rb") as f: diff --git a/src/anemoi/datasets/create/size.py b/src/anemoi/datasets/create/size.py index 35c20994..2191a08f 100644 --- a/src/anemoi/datasets/create/size.py +++ b/src/anemoi/datasets/create/size.py @@ -10,9 +10,8 @@ import logging import os -from anemoi.utils.humanize import bytes - -from anemoi.datasets.create.utils import progress_bar +import tqdm +from anemoi.utils.humanize import bytes_to_human LOG = logging.getLogger(__name__) @@ -22,14 +21,14 @@ def compute_directory_sizes(path): return None size, n = 0, 0 - bar = progress_bar(iterable=os.walk(path), desc=f"Computing size of {path}") + bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}") for dirpath, _, filenames in bar: for filename in filenames: file_path = os.path.join(dirpath, filename) size += os.path.getsize(file_path) n += 1 - LOG.info(f"Total size: {bytes(size)}") + LOG.info(f"Total size: {bytes_to_human(size)}") LOG.info(f"Total number of files: {n}") return dict(total_size=size, total_number_of_files=n) diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index 5f78d80c..568bc410 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -89,20 +89,23 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa continue print("---") print(f"❗ Negative variance for {name=}, variance={y}") - print(f" max={maximum[i]} min={minimum[i]} mean={mean[i]} count={count[i]} sum={sums[i]} square={squares[i]}") + print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}") print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}") print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}") print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}") + print( + f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}" + ) raise ValueError("Negative variance") -def compute_statistics(array, check_variables_names=None, allow_nan=False): +def compute_statistics(array, check_variables_names=None, allow_nans=False): """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary.""" nvars = array.shape[1] - LOG.info(f"Stats {nvars}, {array.shape}, {check_variables_names}") + LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}") if check_variables_names: assert nvars == len(check_variables_names), (nvars, check_variables_names) 
stats_shape = (array.shape[0], nvars) @@ -118,7 +121,7 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False): values = chunk.reshape((nvars, -1)) for j, name in enumerate(check_variables_names): - check_data_values(values[j, :], name=name, allow_nan=allow_nan) + check_data_values(values[j, :], name=name, allow_nans=allow_nans) if np.isnan(values[j, :]).all(): # LOG.warning(f"All NaN values for {name} ({j}) for date {i}") raise ValueError(f"All NaN values for {name} ({j}) for date {i}") @@ -179,12 +182,12 @@ def write(self, key, data, dates): pickle.dump((key, dates, data), f) shutil.move(tmp_path, path) - LOG.info(f"Written statistics data for {len(dates)} dates in {path} ({dates})") + LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})") def _gather_data(self): # use glob to read all pickles files = glob.glob(self.dirname + "/*.npz") - LOG.info(f"Reading stats data, found {len(files)} files in {self.dirname}") + LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}") assert len(files) > 0, f"No files found in {self.dirname}" for f in files: with open(f, "rb") as f: @@ -211,17 +214,17 @@ def normalise_dates(dates): class StatAggregator: NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"] - def __init__(self, owner, dates, variables_names, allow_nan): + def __init__(self, owner, dates, variables_names, allow_nans): dates = sorted(dates) dates = to_datetimes(dates) assert dates, "No dates selected" self.owner = owner self.dates = dates self.variables_names = variables_names - self.allow_nan = allow_nan + self.allow_nans = allow_nans self.shape = (len(self.dates), len(self.variables_names)) - LOG.info(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}") + LOG.debug(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}") self.minimum = np.full(self.shape, np.nan, dtype=np.float64) self.maximum = np.full(self.shape, np.nan, dtype=np.float64) @@ -284,7 +287,7 @@ def check_type(a, b): assert d in found, f"Statistics for date {d} not precomputed." assert len(self.dates) == len(found), "Not all dates found in precomputed statistics" assert len(self.dates) == offset, "Not all dates found in precomputed statistics." 
- LOG.info(f"Statistics for {len(found)} dates found.") + LOG.debug(f"Statistics for {len(found)} dates found.") def aggregate(self): minimum = np.nanmin(self.minimum, axis=0) @@ -298,13 +301,43 @@ def aggregate(self): assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape x = squares / count - mean * mean - # remove negative variance due to numerical errors - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares) - stdev = np.sqrt(x) - for j, name in enumerate(self.variables_names): - check_data_values(np.array([mean[j]]), name=name, allow_nan=False) + # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares): + # assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape + # assert x.shape == (1,) + # x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0] + # if x >= 0: + # return x + # + # order = np.sqrt((squares / count + mean * mean)/2) + # range = maximum - minimum + # LOG.warning(f"Negative variance for {name=}, variance={x}") + # LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}") + # LOG.warning(f"Variable order of magnitude is {order}.") + # LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).") + # LOG.warning(f"Count is {count}.") + # if abs(x) < order * 1e-6 and abs(x) < range * 1e-6: + # LOG.warning(f"Variance is negative but very small, setting to 0.") + # return x*0 + # return x + + for i, name in enumerate(self.variables_names): + # remove negative variance due to numerical errors + # Not needed for now, fix_variance is disabled + # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1]) + check_variance( + x[i : i + 1], + [name], + minimum[i : i + 1], + maximum[i : i + 1], + mean[i : i + 1], + count[i : i + 1], + sums[i : i + 1], + squares[i : i + 1], + ) + check_data_values(np.array([mean[i]]), name=name, allow_nans=False) + + stdev = np.sqrt(x) return Summary( minimum=minimum, diff --git a/src/anemoi/datasets/create/template.py b/src/anemoi/datasets/create/template.py index 8952cacd..ec1c1a41 100644 --- a/src/anemoi/datasets/create/template.py +++ b/src/anemoi/datasets/create/template.py @@ -8,72 +8,16 @@ # import logging -import os import re import textwrap from functools import wraps -LOG = logging.getLogger(__name__) - -TRACE_INDENT = 0 - - -def step(action_path): - return f"[{'.'.join(action_path)}]" +from anemoi.utils.humanize import plural +from .trace import step +from .trace import trace -def trace(emoji, *args): - if os.environ.get("ANEMOI_DATASET_TRACE_CREATE") is None: - return - print(emoji, " " * TRACE_INDENT, *args) - - -def trace_datasource(method): - @wraps(method) - def wrapper(self, *args, **kwargs): - global TRACE_INDENT - trace( - "🌍", - "=>", - step(self.action_path), - self._trace_datasource(*args, **kwargs), - ) - TRACE_INDENT += 1 - result = method(self, *args, **kwargs) - TRACE_INDENT -= 1 - trace( - "🍎", - "<=", - step(self.action_path), - textwrap.shorten(repr(result), 256), - ) - return result - - return wrapper - - -def trace_select(method): - @wraps(method) - def wrapper(self, *args, **kwargs): - global TRACE_INDENT - trace( - "πŸ‘“", - "=>", - ".".join(self.action_path), - self._trace_select(*args, **kwargs), - ) - TRACE_INDENT += 1 
- result = method(self, *args, **kwargs) - TRACE_INDENT -= 1 - trace( - "🍍", - "<=", - ".".join(self.action_path), - textwrap.shorten(repr(result), 256), - ) - return result - - return wrapper +LOG = logging.getLogger(__name__) def notify_result(method): @@ -99,7 +43,13 @@ def will_need_reference(self, key): self.used_references.add(key) def notify_result(self, key, result): - trace("🎯", step(key), "notify result", result) + trace( + "🎯", + step(key), + "notify result", + textwrap.shorten(repr(result).replace(",", ", "), width=40), + plural(len(result), "field"), + ) assert isinstance(key, (list, tuple)), key key = tuple(key) if key in self.used_references: diff --git a/src/anemoi/datasets/create/trace.py b/src/anemoi/datasets/create/trace.py new file mode 100644 index 00000000..d3dc6178 --- /dev/null +++ b/src/anemoi/datasets/create/trace.py @@ -0,0 +1,91 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import textwrap +import threading +from functools import wraps + +LOG = logging.getLogger(__name__) + + +thread_local = threading.local() +TRACE = 0 + + +def enable_trace(on_off): + global TRACE + TRACE = on_off + + +def step(action_path): + return f"[{'.'.join(action_path)}]" + + +def trace(emoji, *args): + + if not TRACE: + return + + if not hasattr(thread_local, "TRACE_INDENT"): + thread_local.TRACE_INDENT = 0 + + print(emoji, " " * thread_local.TRACE_INDENT, *args) + + +def trace_datasource(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + + if not hasattr(thread_local, "TRACE_INDENT"): + thread_local.TRACE_INDENT = 0 + + trace( + "🌍", + "=>", + step(self.action_path), + self._trace_datasource(*args, **kwargs), + ) + thread_local.TRACE_INDENT += 1 + result = method(self, *args, **kwargs) + thread_local.TRACE_INDENT -= 1 + trace( + "🍎", + "<=", + step(self.action_path), + textwrap.shorten(repr(result), 256), + ) + return result + + return wrapper + + +def trace_select(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + if not hasattr(thread_local, "TRACE_INDENT"): + thread_local.TRACE_INDENT = 0 + trace( + "πŸ‘“", + "=>", + ".".join(self.action_path), + self._trace_select(*args, **kwargs), + ) + thread_local.TRACE_INDENT += 1 + result = method(self, *args, **kwargs) + thread_local.TRACE_INDENT -= 1 + trace( + "🍍", + "<=", + ".".join(self.action_path), + textwrap.shorten(repr(result), 256), + ) + return result + + return wrapper diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py index 377ab292..e4629abd 100644 --- a/src/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -7,15 +7,11 @@ # nor does it submit to any jurisdiction. 
# -import json import os from contextlib import contextmanager import numpy as np -import yaml from earthkit.data import settings -from earthkit.data.utils.humanize import seconds # noqa: F401 -from tqdm.auto import tqdm def cache_context(dirname): @@ -31,26 +27,6 @@ def no_cache_context(): return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname}) -def bytes(n): - """>>> bytes(4096) - '4 KiB' - >>> bytes(4000) - '3.9 KiB' - """ - if n < 0: - sign = "-" - n -= 0 - else: - sign = "" - - u = ["", " KiB", " MiB", " GiB", " TiB", " PiB", " EiB", " ZiB", " YiB"] - i = 0 - while n >= 1024: - n /= 1024.0 - i += 1 - return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i]) - - def to_datetime_list(*args, **kwargs): from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_ @@ -63,15 +39,6 @@ def to_datetime(*args, **kwargs): return to_datetime_(*args, **kwargs) -def load_json_or_yaml(path): - with open(path, "r") as f: - if path.endswith(".json"): - return json.load(f) - if path.endswith(".yaml") or path.endswith(".yml"): - return yaml.safe_load(f) - raise ValueError(f"Cannot read file {path}. Need json or yaml with appropriate extension.") - - def make_list_int(value): if isinstance(value, str): if "/" not in value: @@ -118,18 +85,3 @@ def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s] assert d1 == d2, (i, d1, d2) return dates_ - - -def progress_bar(*, iterable=None, total=None, initial=0, desc=None): - return tqdm( - iterable=iterable, - total=total, - initial=initial, - unit_scale=True, - unit_divisor=1024, - unit="B", - disable=False, - leave=False, - desc=desc, - # dynamic_ncols=True, # make this the default? - ) diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/zarr.py index d2639f35..c6b270a3 100644 --- a/src/anemoi/datasets/create/zarr.py +++ b/src/anemoi/datasets/create/zarr.py @@ -24,8 +24,12 @@ def add_zarr_dataset( shape=None, array=None, overwrite=True, + dimensions=None, **kwargs, ): + assert dimensions is not None, "Please pass dimensions to add_zarr_dataset." + assert isinstance(dimensions, (tuple, list)) + if dtype is None: assert array is not None, (name, shape, array, dtype, zarr_root) dtype = array.dtype @@ -44,6 +48,7 @@ def add_zarr_dataset( **kwargs, ) a[...] 
= array + a.attrs["_ARRAY_DIMENSIONS"] = dimensions return a if "fill_value" not in kwargs: @@ -69,6 +74,7 @@ def add_zarr_dataset( overwrite=overwrite, **kwargs, ) + a.attrs["_ARRAY_DIMENSIONS"] = dimensions return a @@ -79,22 +85,27 @@ class ZarrBuiltRegistry: flags = None z = None - def __init__(self, path, synchronizer_path=None): + def __init__(self, path, synchronizer_path=None, use_threads=False): import zarr assert isinstance(path, str), path self.zarr_path = path - if synchronizer_path is None: - synchronizer_path = self.zarr_path + ".sync" - self.synchronizer_path = synchronizer_path - self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) + if use_threads: + self.synchronizer = zarr.ThreadSynchronizer() + self.synchronizer_path = None + else: + if synchronizer_path is None: + synchronizer_path = self.zarr_path + ".sync" + self.synchronizer_path = synchronizer_path + self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) def clean(self): - try: - shutil.rmtree(self.synchronizer_path) - except FileNotFoundError: - pass + if self.synchronizer_path is not None: + try: + shutil.rmtree(self.synchronizer_path) + except FileNotFoundError: + pass def _open_write(self): import zarr @@ -112,7 +123,7 @@ def _open_read(self, sync=True): def new_dataset(self, *args, **kwargs): z = self._open_write() zarr_root = z["_build"] - add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, **kwargs) + add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) def add_to_history(self, action, **kwargs): new = dict( @@ -143,6 +154,9 @@ def set_flag(self, i, value=True): z.attrs["latest_write_timestamp"] = datetime.datetime.utcnow().isoformat() z["_build"][self.name_flags][i] = value + def ready(self): + return all(self.get_flags()) + def create(self, lengths, overwrite=False): self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4")) self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool)) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index c0baf437..33ace1d6 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -8,63 +8,35 @@ import calendar import datetime import logging -import os import re from pathlib import PurePath import numpy as np import zarr +from anemoi.utils.config import load_config as load_settings from .dataset import Dataset LOG = logging.getLogger(__name__) -CONFIG = None -try: - import tomllib # Only available since 3.11 -except ImportError: - import tomli as tomllib +def load_config(): + return load_settings(defaults={"datasets": {"named": {}, "path": []}}) def add_named_dataset(name, path, **kwargs): - load_config() - if name in CONFIG["datasets"]["named"]: + config = load_config() + if name in config["datasets"]["named"]: raise ValueError(f"Dataset {name} already exists") - CONFIG["datasets"]["named"][name] = path + config["datasets"]["named"][name] = path def add_dataset_path(path): - load_config() - - if path not in CONFIG["datasets"]["path"]: - CONFIG["datasets"]["path"].append(path) - - # save_config() - - -def load_config(): - global CONFIG - if CONFIG is not None: - return CONFIG - - conf = os.path.expanduser("~/.config/anemoi/settings.toml") - if not os.path.exists(conf): - conf = os.path.expanduser("~/.anemoi.toml") - - if os.path.exists(conf): - - with open(conf, "rb") as f: - CONFIG = tomllib.load(f) - else: - CONFIG = {} - - CONFIG.setdefault("datasets", {}) - 
CONFIG["datasets"].setdefault("path", []) - CONFIG["datasets"].setdefault("named", {}) + config = load_config() - return CONFIG + if path not in config["datasets"]["path"]: + config["datasets"]["path"].append(path) def _frequency_to_hours(frequency): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 7a62d329..5dc3bb80 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -9,6 +9,7 @@ import os import warnings from functools import cached_property +from urllib.parse import urlparse import numpy as np import zarr @@ -40,7 +41,9 @@ def __iter__(self): class HTTPStore(ReadOnlyStore): - """We write our own HTTPStore because the one used by zarr (fsspec) does not play well with fork() and multiprocessing.""" + """We write our own HTTPStore because the one used by zarr (s3fs) + does not play well with fork() and multiprocessing. + """ def __init__(self, url): self.url = url @@ -58,17 +61,16 @@ def __getitem__(self, key): class S3Store(ReadOnlyStore): - """We write our own S3Store because the one used by zarr (fsspec) - does not play well with fork() and multiprocessing. Also, we get - to control the s3 client. + """We write our own S3Store because the one used by zarr (s3fs) + does not play well with fork(). We also get to control the s3 client + options using the anemoi configs. """ - def __init__(self, url): + def __init__(self, url, region=None): from anemoi.utils.s3 import s3_client _, _, self.bucket, self.key = url.split("/", 3) - - self.s3 = s3_client(self.bucket) + self.s3 = s3_client(self.bucket, region=region) def __getitem__(self, key): try: @@ -101,15 +103,27 @@ def __contains__(self, key): return key in self.store -def open_zarr(path, dont_fail=False, cache=None): - try: - store = path +def name_to_zarr_store(path_or_url): + store = path_or_url + + if store.startswith("s3://"): + store = S3Store(store) - if store.startswith("http://") or store.startswith("https://"): + elif store.startswith("http://") or store.startswith("https://"): + parsed = urlparse(store) + bits = parsed.netloc.split(".") + if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): + s3_url = f"s3://{bits[0]}{parsed.path}" + store = S3Store(s3_url, region=bits[2]) + else: store = HTTPStore(store) - elif store.startswith("s3://"): - store = S3Store(store) + return store + + +def open_zarr(path, dont_fail=False, cache=None): + try: + store = name_to_zarr_store(path) if DEBUG_ZARR_LOADING: if isinstance(store, str): @@ -117,7 +131,8 @@ def open_zarr(path, dont_fail=False, cache=None): if not os.path.isdir(store): raise NotImplementedError( - "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. Please disable it for other backends." + "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. " + "Please disable it for other backends." 
) store = zarr.storage.DirectoryStore(store) store = DebugStore(store) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 2627016b..a66efc27 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -96,7 +96,7 @@ def as_dict(self): class StartEndDates(Dates): - def __init__(self, start, end, frequency=1, **kwargs): + def __init__(self, start, end, frequency=1, months=None, **kwargs): frequency = frequency_to_hours(frequency) def _(x): @@ -128,6 +128,12 @@ def _(x): date = start self.values = [] while date <= end: + + if months is not None: + if date.month not in months: + date += increment + continue + self.values.append(date) date += increment diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index a83e789b..d73bc91d 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -61,6 +61,9 @@ def __len__(self): count += 1 return count + def __repr__(self): + return f"{self.__class__.__name__}(dates={len(self)})" + class Filter: def __init__(self, missing): diff --git a/tests/xarray/test_kerchunk.py b/tests/xarray/test_kerchunk.py new file mode 100644 index 00000000..7ca5f6e8 --- /dev/null +++ b/tests/xarray/test_kerchunk.py @@ -0,0 +1,36 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import xarray as xr + +from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList + + +def dont_test_kerchunk(): + + ds = xr.open_dataset( + "reference://", + engine="zarr", + backend_kwargs={ + "consolidated": False, + "storage_options": { + "fo": "combined.json", + "remote_protocol": "s3", + "remote_options": {"anon": True}, + }, + }, + ) + + fs = XarrayFieldList.from_xarray(ds) + + print(fs[-1].metadata()) + + assert len(fs) == 12432 + + +if __name__ == "__main__": + dont_test_kerchunk() diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py new file mode 100644 index 00000000..bd1245c6 --- /dev/null +++ b/tests/xarray/test_netcdf.py @@ -0,0 +1,55 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+import os + +import xarray as xr +from multiurl import download + +from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList + +URLS = { + "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/efas.nc": dict(length=3), + "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/era5-2m-temperature-dec-1993.nc": dict(length=1), + "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/test.nc": dict(length=2), + "https://get.ecmwf.int/repository/test-data/metview/gallery/era5_2000_aug.nc": dict(length=3), + "https://get.ecmwf.int/repository/test-data/metview/gallery/era5_ozone_1999.nc": dict(length=4), + "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/fa_ta850.nc": dict(length=37), + # 'https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/htessel_points.nc': dict(length=1), + "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/test_single.nc": dict(length=1), + # 'https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/zgrid_rhgmet_metop_200701_R_2305_0010.nc': dict(length=1), + # 'https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/20210101-C3S-L2_GHG-GHG_PRODUCTS-TANSO2-GOSAT2-SRFP-DAILY-v2.0.0.nc': dict(length=1), + "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/20220401-C3S-L3S_FIRE-BA-OLCI-AREA_3-fv1.1.nc": dict( + length=3 + ), + # 'https://github.com/ecmwf/magics-test/raw/master/test/efas/tamir.nc': dict(length=1), + # 'https://github.com/ecmwf/magics-test/raw/master/test/gallery/C3S_OZONE-L4-TC-ASSIM_MSR-201608-fv0020.nc': dict(length=1), + # 'https://github.com/ecmwf/magics-test/raw/master/test/gallery/avg_data.nc': dict(length=1), + "https://github.com/ecmwf/magics-test/raw/master/test/gallery/era5_2000_aug_1.nc": dict(length=3), + "https://github.com/ecmwf/magics-test/raw/master/test/gallery/missing.nc": dict(length=20), + "https://github.com/ecmwf/magics-test/raw/master/test/gallery/netcdf3_t_z.nc": dict(length=30), + "https://github.com/ecmwf/magics-test/raw/master/test/gallery/tos_O1_2001-2002.nc": dict(length=24), + "https://github.com/ecmwf/magics-test/raw/master/test/gallery/z_500.nc": dict(length=1), +} + + +def skip_test_netcdf(): + + for url, checks in URLS.items(): + print(url) + path = os.path.join(os.path.dirname(__file__), os.path.basename(url)) + if not os.path.exists(path): + download(url, path) + + ds = xr.open_dataset(path) + + fs = XarrayFieldList.from_xarray(ds) + + assert len(fs) == checks["length"], (url, len(fs)) + + +if __name__ == "__main__": + skip_test_netcdf() diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py new file mode 100644 index 00000000..3f7ce3ad --- /dev/null +++ b/tests/xarray/test_opendap.py @@ -0,0 +1,24 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+import xarray as xr + +from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList + + +def test_opendap(): + + ds = xr.open_dataset( + "https://thredds.met.no/thredds/dodsC/meps25epsarchive/2023/01/01/meps_det_2_5km_20230101T00Z.nc", + ) + + fs = XarrayFieldList.from_xarray(ds) + + assert len(fs) == 79529 + + +if __name__ == "__main__": + test_opendap() diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py new file mode 100644 index 00000000..d138d371 --- /dev/null +++ b/tests/xarray/test_zarr.py @@ -0,0 +1,54 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import xarray as xr + +from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList + + +def test_arco_era5(): + + ds = xr.open_zarr( + "gs://gcp-public-data-arco-era5/ar/1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2", + chunks={"time": 48}, + consolidated=True, + ) + + fs = XarrayFieldList.from_xarray(ds) + print(len(fs)) + + print(fs[-1].metadata()) + print(fs[-1].to_numpy()) + + assert len(fs) == 128677526 + + +def test_weatherbench(): + ds = xr.open_zarr("gs://weatherbench2/datasets/pangu_hres_init/2020_0012_0p25.zarr") + + # https://weatherbench2.readthedocs.io/en/latest/init-vs-valid-time.html + + flavour = { + "rules": { + "latitude": {"name": "latitude"}, + "longitude": {"name": "longitude"}, + "step": {"name": "prediction_timedelta"}, + "date": {"name": "time"}, + "level": {"name": "level"}, + }, + "levtype": "pl", + } + + fs = XarrayFieldList.from_xarray(ds, flavour) + + assert len(fs) == 2430240 + + assert fs[0].metadata("valid_datetime") == "2020-01-01T00:00:00", fs[0].metadata("valid_datetime") + + +if __name__ == "__main__": + test_weatherbench()
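
The sketches below are illustrative notes on the changes above, not part of the patch itself. First, the progress bars: load_cube now derives a tqdm `position` from the `parts` specification, so that when several parts such as "2/5" are loaded concurrently each worker keeps its own progress-bar row. A minimal sketch of that behaviour, with a dummy loop standing in for the datacube iteration:

import time

import tqdm


def position(parts):
    # Same parsing as the helper added in load_cube(): a parts spec like
    # "2/5" (second part of five) pins this worker's bar to row 2.
    if isinstance(parts, str) and "/" in parts:
        return int(parts.split("/")[0])
    return None


bar = tqdm.tqdm(range(100), desc="Loading datacube", position=position("2/5"))
for _ in bar:
    time.sleep(0.01)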
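
add_zarr_dataset now refuses to create an array without a `dimensions` argument and records it in the `_ARRAY_DIMENSIONS` attribute. That attribute is the convention xarray's Zarr backend uses to name dimensions, so statistics arrays written with dimensions=("variable",) become directly readable with xarray. A minimal sketch of the effect, assuming a zarr 2.x API and a throw-away local store:

import numpy as np
import xarray as xr
import zarr

root = zarr.open_group("example.zarr", mode="w")
mean = root.create_dataset("mean", data=np.zeros(4), shape=(4,), dtype="f8")
mean.attrs["_ARRAY_DIMENSIONS"] = ["variable"]  # what add_zarr_dataset now records

ds = xr.open_zarr("example.zarr", consolidated=False)
print(ds["mean"].dims)  # ('variable',)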
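
The tracing helpers moved from create/template.py into the new create/trace.py keep their indentation counter in threading.local(), so concurrently running workers no longer share a global indent, and tracing is switched on programmatically with enable_trace(1) instead of an environment variable. A small sketch of how the decorator is exercised, using a hypothetical class that provides the attributes the wrapper expects:

from anemoi.datasets.create.trace import enable_trace, trace_datasource

enable_trace(1)  # emit the emoji trace lines; off by default


class DemoAction:
    # Hypothetical stand-in: trace_datasource expects `action_path` and
    # a `_trace_datasource()` method on the decorated object.
    action_path = ["input", "demo"]

    def _trace_datasource(self, dates):
        return f"demo({len(dates)} dates)"

    @trace_datasource
    def datasource(self, dates):
        return list(dates)


DemoAction().datasource([1, 2, 3])
# 🌍  => [input.demo] demo(3 dates)
# 🍎  <= [input.demo] [1, 2, 3]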
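
In data/stores.py, open_zarr now goes through name_to_zarr_store, which sends s3:// URLs, as well as virtual-hosted-style Amazon S3 HTTPS URLs of the form https://<bucket>.s3.<region>.amazonaws.com/<key>, to the fork-safe S3Store (with the region extracted from the host name), and everything else served over HTTP(S) to HTTPStore. A standalone sketch of the same URL classification, for illustration only (the example URL is hypothetical):

from urllib.parse import urlparse


def classify_store(path_or_url):
    # Mirrors the URL handling added to stores.py: plain s3:// URLs and
    # virtual-hosted S3 HTTPS URLs go to S3Store, other HTTP(S) URLs go
    # to HTTPStore, and local paths are returned unchanged.
    if path_or_url.startswith("s3://"):
        return ("S3Store", path_or_url, None)

    if path_or_url.startswith(("http://", "https://")):
        parsed = urlparse(path_or_url)
        bits = parsed.netloc.split(".")
        if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
            bucket, region = bits[0], bits[2]
            return ("S3Store", f"s3://{bucket}{parsed.path}", region)
        return ("HTTPStore", path_or_url, None)

    return ("local", path_or_url, None)


print(classify_store("https://my-bucket.s3.eu-west-1.amazonaws.com/datasets/example.zarr"))
# ('S3Store', 's3://my-bucket/datasets/example.zarr', 'eu-west-1')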
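
StartEndDates in anemoi/datasets/dates gains an optional `months` list: while stepping from start to end by the given frequency (in hours), any date whose month is not listed is skipped. The stand-alone loop below mirrors that filter for a DJF-only selection; the variable names are hypothetical and the real class is normally driven from a dataset recipe:

import datetime

start = datetime.datetime(2020, 12, 1)
end = datetime.datetime(2021, 3, 1)
increment = datetime.timedelta(hours=6)
months = [12, 1, 2]  # keep December, January and February only

values = []
date = start
while date <= end:
    # Same skip logic as the new `months` handling in StartEndDates
    if months is not None and date.month not in months:
        date += increment
        continue
    values.append(date)
    date += increment

print(len(values), values[0], values[-1])
# 360 2020-12-01 00:00:00 2021-02-28 18:00:00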