From bc9352ed33952da6be6d4e6d6af9fe9db9faf756 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 23:00:48 +0800 Subject: [PATCH] Fix deprecated usage in zarr (#8313) Fixes #8298 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/inferers/merger.py | 23 +++++++++++++++++++---- tests/test_zarr_avg_merger.py | 7 ++++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index d01d334142..1344207e18 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -15,12 +15,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import nullcontext +from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any import numpy as np import torch -from monai.utils import ensure_tuple_size, optional_import, require_pkg +from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq if TYPE_CHECKING: import zarr @@ -233,7 +234,7 @@ def __init__( store: zarr.storage.Store | str = "merged.zarr", value_store: zarr.storage.Store | str | None = None, count_store: zarr.storage.Store | str | None = None, - compressor: str = "default", + compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, chunks: Sequence[int] | bool = True, @@ -246,8 +247,22 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store - self.value_store = zarr.storage.TempStore() if value_store is None else value_store - self.count_store = zarr.storage.TempStore() if count_store is None else count_store + self.tmpdir: TemporaryDirectory | None + if version_geq(get_package_version("zarr"), "3.0.0"): + if value_store is None: + self.tmpdir = TemporaryDirectory() + self.value_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.value_store = value_store + if count_store is None: + self.tmpdir = TemporaryDirectory() + self.count_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.count_store = count_store + else: + self.tmpdir = None + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor self.value_compressor = value_compressor diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index a52dbceb4c..3c89e4fb03 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -287,15 +287,16 @@ class ZarrAvgMergerTests(unittest.TestCase): ] ) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + codec_reg = numcodecs.registry.codec_registry if "compressor" in arguments: if arguments["compressor"] != "default": - arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + arguments["compressor"] = codec_reg[arguments["compressor"].lower()]() if "value_compressor" in arguments: if arguments["value_compressor"] != "default": - arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]() if "count_compressor" in arguments: if arguments["count_compressor"] != "default": - arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() + arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1])