From 21920a34bc00a114e430e1943e1fd1f572880919 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Tue, 10 Dec 2024 20:03:35 +0800
Subject: [PATCH 01/10] Add platform-specific constraints to setup.cfg (#8260)

Fixes #8258

### Description
Include platform_system conditions for the dependencies in setup.cfg.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 setup.cfg | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index ecfd717aff..0c69051218 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,10 +61,10 @@ all =
     tqdm>=4.47.0
     lmdb
     psutil
-    cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
+    cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
     openslide-python
-    tifffile
-    imagecodecs
+    tifffile; platform_system == "Linux" or platform_system == "Darwin"
+    imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
     pandas
     einops
     transformers>=4.36.0, <4.41.0; python_version <= '3.10'
@@ -78,7 +78,7 @@ all =
     pynrrd
     pydicom
     h5py
-    nni
+    nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
     optuna
     onnx>=1.13.0
     onnxruntime; python_version <= '3.10'
@@ -116,13 +116,13 @@ lmdb =
 psutil =
     psutil
 cucim =
-    cucim-cu12
+    cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
 openslide =
     openslide-python
 tifffile =
-    tifffile
+    tifffile; platform_system == "Linux" or platform_system == "Darwin"
 imagecodecs =
-    imagecodecs
+    imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
 pandas =
     pandas
 einops =
@@ -152,7 +152,7 @@ pydicom =
 h5py =
     h5py
 nni =
-    nni
+    nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
 optuna =
     optuna
 onnx =

From e1e3d8ebc1c7247aad9f1bffc649c5a20084340f Mon Sep 17 00:00:00 2001
From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Date: Thu, 19 Dec 2024 02:50:42 +0000
Subject: [PATCH 02/10] Modify Workflow to Allow IterableDataset Inputs (#8263)

### Description
This modifies the behaviour of `Workflow` to permit `IterableDataset` to be used correctly. A check against the `epoch_length` value is removed, to allow that value to be `None`, and a test is added to verify this. The length of a data loader is not defined when using iterable datasets, so a try/except is added to allow that length to be queried safely.

This is related to my work on the streaming support; in my [prototype gist](https://gist.github.com/ericspod/1904713716b45631260784ac3fcd6fb3) I had to provide a bogus epoch length value and then change it to `None` later once the evaluator object was created. This PR will remove the need for this hack.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Eric Kerfoot
Signed-off-by: Eric Kerfoot
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot
---
 monai/engines/workflow.py      | 22 +++++++++++-----------
 tests/test_iterable_dataset.py | 13 +++++++++++++
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py
index 3629659db1..0c36da6d3d 100644
--- a/monai/engines/workflow.py
+++ b/monai/engines/workflow.py
@@ -12,7 +12,7 @@
 from __future__ import annotations

 import warnings
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence, Sized
 from typing import TYPE_CHECKING, Any

 import torch
@@ -121,24 +121,24 @@ def __init__(
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
     ) -> None:
-        if iteration_update is not None:
-            super().__init__(iteration_update)
-        else:
-            super().__init__(self._iteration)
+        super().__init__(self._iteration if iteration_update is None else iteration_update)

         if isinstance(data_loader, DataLoader):
-            sampler = data_loader.__dict__["sampler"]
+            sampler = getattr(data_loader, "sampler", None)
+
+            # set the epoch value for DistributedSampler objects when an epoch starts
             if isinstance(sampler, DistributedSampler):

                 @self.on(Events.EPOCH_STARTED)
                 def set_sampler_epoch(engine: Engine) -> None:
                     sampler.set_epoch(engine.state.epoch)

-            if epoch_length is None:
+        # if the epoch_length isn't given, attempt to get it from the length of the data loader
+        if epoch_length is None and isinstance(data_loader, Sized):
+            try:
                 epoch_length = len(data_loader)
-        else:
-            if epoch_length is None:
-                raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+            except TypeError:  # raised when data_loader has an iterable dataset with no length, or is some other type
+                pass  # deliberately leave epoch_length as None

         # set all sharable data for the workflow based on Ignite engine.state
         self.state: Any = State(
@@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
             iteration=0,
             epoch=0,
             max_epochs=max_epochs,
-            epoch_length=epoch_length,
+            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
             output=None,
             batch=None,
             metrics={},
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index cfa711e4c0..fb554e391c 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -18,8 +18,10 @@
 import nibabel as nib
 import numpy as np
+import torch.nn as nn

 from monai.data import DataLoader, Dataset, IterableDataset
+from monai.engines import SupervisedEvaluator
 from monai.transforms import Compose, LoadImaged, SimulateDelayd

@@ -59,6 +61,17 @@ def test_shape(self):
         for d in dataloader:
             self.assertTupleEqual(d["image"].shape[1:], expected_shape)

+    def test_supervisedevaluator(self):
+        """
+        Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.
+        """
+        data = list(range(10))
+        dl = DataLoader(IterableDataset(data))
+        evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity())
+        evaluator.run()  # fails if the epoch length or other internal setup is not done correctly
+
+        self.assertEqual(evaluator.state.iteration, len(data))
+

 if __name__ == "__main__":
     unittest.main()

From efff647a332f9520e7b7d7565893bd16ab26e041 Mon Sep 17 00:00:00 2001
From: Hsin-Yuan Hsieh <84929237+Jerome-Hsieh@users.noreply.github.com>
Date: Sat, 21 Dec 2024 22:18:23 +0800
Subject: [PATCH 03/10] enhance download_and_extract (#8216)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #5463

### Description
According to the issue, the error messages are not very intuitive. I think we can check whether the given file name matches the downloaded file's base name before starting the download; if it doesn't match, the user will be notified.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: jerome_Hsieh
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/apps/utils.py                | 39 +++++++++++++++++++++++++++++-
 tests/test_download_and_extract.py |  3 ++-
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/monai/apps/utils.py b/monai/apps/utils.py
index c2e17d3247..95c1450f2a 100644
--- a/monai/apps/utils.py
+++ b/monai/apps/utils.py
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+import re
 import shutil
 import sys
 import tarfile
@@ -30,7 +31,9 @@
 from monai.config.type_definitions import PathLike
 from monai.utils import look_up_option, min_version, optional_import

+requests, has_requests = optional_import("requests")
 gdown, has_gdown = optional_import("gdown", "4.7.3")
+BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup")

 if TYPE_CHECKING:
     from tqdm import tqdm
@@ -298,6 +301,29 @@ def extractall(
     )


+def get_filename_from_url(data_url: str) -> str:
+    """
+    Get the filename from the URL link.
+    """
+    try:
+        response = requests.head(data_url, allow_redirects=True)
+        content_disposition = response.headers.get("Content-Disposition")
+        if content_disposition:
+            filename = re.findall('filename="?([^";]+)"?', content_disposition)
+            if filename:
+                return str(filename[0])
+        if "drive.google.com" in data_url:
+            response = requests.get(data_url)
+            if "text/html" in response.headers.get("Content-Type", ""):
+                soup = BeautifulSoup(response.text, "html.parser")
+                filename_div = soup.find("span", {"class": "uc-name-size"})
+                if filename_div:
+                    return str(filename_div.find("a").text)
+        return _basename(data_url)
+    except Exception as e:
+        raise Exception(f"Error processing URL: {e}") from e
+
+
 def download_and_extract(
     url: str,
     filepath: PathLike = "",
@@ -327,7 +353,18 @@ def download_and_extract(
             be False.
         progress: whether to display progress bar.
""" + url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes) + filepath_ext = "".join(Path(_basename(filepath)).suffixes) + if filepath not in ["", "."]: + if filepath_ext == "": + new_filepath = Path(filepath).with_suffix(url_filename_ext) + logger.warning( + f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}" + ) + filepath = new_filepath + if filepath_ext and filepath_ext != url_filename_ext: + raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}") with tempfile.TemporaryDirectory() as tmp_dir: - filename = filepath or Path(tmp_dir, _basename(url)).resolve() + filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve() download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress) extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base) diff --git a/tests/test_download_and_extract.py b/tests/test_download_and_extract.py index 555f7dc250..439a11bbc1 100644 --- a/tests/test_download_and_extract.py +++ b/tests/test_download_and_extract.py @@ -20,9 +20,10 @@ from parameterized import parameterized from monai.apps import download_and_extract, download_url, extractall -from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config +from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config +@SkipIfNoModule("requests") class TestDownloadAndExtract(unittest.TestCase): @skip_if_quick From d36f0c80f716c5ad040f0f2cad11407e68d0f33a Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:33:57 +0800 Subject: [PATCH 04/10] enable gpu load nifti (#8188) Related to https://github.com/Project-MONAI/MONAI/issues/8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
---------

Signed-off-by: Yiheng Wang
Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/data/image_reader.py   | 86 ++++++++++++++++++++++++++++++++----
 monai/data/meta_tensor.py    |  1 -
 monai/transforms/io/array.py |  1 -
 tests/test_init_reader.py    | 19 ++++++++
 tests/test_load_image.py     | 41 ++++++++++++++++-
 5 files changed, 136 insertions(+), 12 deletions(-)

diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index b4ae562911..5bc38f69ea 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -12,8 +12,11 @@
 from __future__ import annotations

 import glob
+import gzip
+import io
 import os
 import re
+import tempfile
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -51,6 +54,9 @@
 pydicom, has_pydicom = optional_import("pydicom")
 nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

+cp, has_cp = optional_import("cupy")
+kvikio, has_kvikio = optional_import("kvikio")
+
 __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
             )


-def _stack_images(image_list: list, meta_dict: dict):
+def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
     if len(image_list) <= 1:
         return image_list[0]
     if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
         channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
+        if to_cupy and has_cp:
+            return cp.concatenate(image_list, axis=channel_dim)
         return np.concatenate(image_list, axis=channel_dim)
     # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
     meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
+    if to_cupy and has_cp:
+        return cp.stack(image_list, axis=0)
     return np.stack(image_list, axis=0)


@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
     Load NIfTI format images based on Nibabel library.

     Args:
-        as_closest_canonical: if True, load the image as closest to canonical axis format.
-        squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
         channel_dim: the channel dimension of the input image, default is None.
             this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
             if None, `original_channel_dim` will be either `no_channel` or `-1`.
             most Nifti files are usually "channel last", no need to specify this argument for them.
+        as_closest_canonical: if True, load the image as closest to canonical axis format.
+        squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
+        to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
+            Default is False. CuPy and Kvikio are required for this option.
+            Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
+            and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
+            In practical use, it's recommended to add a warm up call before the actual loading.
+            A related tutorial will be prepared in the future, and the document will be updated accordingly.
         kwargs: additional args for `nibabel.load` API. more details about available args:
             https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py

     """

@@ -880,14 +896,42 @@ def __init__(
         self,
         channel_dim: str | int | None = None,
         as_closest_canonical: bool = False,
         squeeze_non_spatial_dims: bool = False,
+        to_gpu: bool = False,
         **kwargs,
     ):
         super().__init__()
         self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
         self.as_closest_canonical = as_closest_canonical
         self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
+        if to_gpu and (not has_cp or not has_kvikio):
+            warnings.warn(
+                "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
+            )
+            to_gpu = False
+
+        if to_gpu:
+            self.warmup_kvikio()
+
+        self.to_gpu = to_gpu
         self.kwargs = kwargs

+    def warmup_kvikio(self):
+        """
+        Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
+        This can accelerate the data loading process when `to_gpu` is set to True.
+        """
+        if has_cp and has_kvikio:
+            a = cp.arange(100)
+            with tempfile.NamedTemporaryFile() as tmp_file:
+                tmp_file_name = tmp_file.name
+                f = kvikio.CuFile(tmp_file_name, "w")
+                f.write(a)
+                f.close()
+
+                b = cp.empty_like(a)
+                f = kvikio.CuFile(tmp_file_name, "r")
+                f.read(b)
+
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
         Verify whether the specified file or files format is supported by Nibabel reader.
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_: list[Nifti1Image] = []

         filenames: Sequence[PathLike] = ensure_tuple(data)
+        self.filenames = filenames
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
         for name in filenames:
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
             img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.

         """
+        # TODO: the actual type is list[np.ndarray | cp.ndarray]
+        # should figure out how to define correct types without having cupy not found error
+        # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
         img_array: list[np.ndarray] = []
         compatible_meta: dict = {}

-        for i in ensure_tuple(img):
+        for i, filename in zip(ensure_tuple(img), self.filenames):
             header = self._get_meta_dict(i)
             header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                 header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
             header[MetaKeys.SPACE] = SpaceKeys.RAS
-            data = self._get_array_data(i)
+            data = self._get_array_data(i, filename)
             if self.squeeze_non_spatial_dims:
                 for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
                     if data.shape[d - 1] == 1:
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
             _copy_compatible_dict(header, compatible_meta)

-        return _stack_images(img_array, compatible_meta), compatible_meta
+        return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

     def _get_meta_dict(self, img) -> dict:
         """
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
         spatial_rank = max(min(ndim, 3), 1)
         return np.asarray(size[:spatial_rank])

-    def _get_array_data(self, img):
+    def _get_array_data(self, img, filename):
         """
         Get the raw array data of the image, converted to Numpy array.

         Args:
             img: a Nibabel image object loaded from an image file.
-
-        """
+            filename: file name of the image.
+
+        """
+        if self.to_gpu:
+            file_size = os.path.getsize(filename)
+            image = cp.empty(file_size, dtype=cp.uint8)
+            with kvikio.CuFile(filename, "r") as f:
+                f.read(image)
+            if filename.endswith(".nii.gz"):
+                # for compressed data, have to transfer to CPU to decompress
+                # and then transfer back to GPU. It is not efficient compared to a .nii file
+                # and may be slower than CPU loading in some cases.
+                warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
+                compressed_data = cp.asnumpy(image)
+                with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
+                    decompressed_data = gz_file.read()
+
+                image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
+            data_shape = img.shape
+            data_offset = img.dataobj.offset
+            data_dtype = img.dataobj.dtype
+            return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
         return np.asanyarray(img.dataobj, order="C")

diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index ac171e8508..c4c491e1b9 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
         However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
         """
         img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
-
         # if not tracking metadata, return `torch.Tensor`
         if not isinstance(img, MetaTensor):
             return img
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index 4e71870fc9..1023cd7a7d 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
                 " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
                 f"   The current registered: {self.readers}.\n{msg}"
             )
-
         img_array: NdarrayOrTensor
         img_array, meta_data = reader.get_data(img)
         img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
diff --git a/tests/test_init_reader.py b/tests/test_init_reader.py
index cb45cb5146..8331f742ec 100644
--- a/tests/test_init_reader.py
+++ b/tests/test_init_reader.py
@@ -30,6 +30,17 @@ def test_load_image(self):
             inst = LoadImaged("image", reader=r)
             self.assertIsInstance(inst, LoadImaged)

+    @SkipIfNoModule("nibabel")
+    @SkipIfNoModule("cupy")
+    @SkipIfNoModule("kvikio")
+    def test_load_image_to_gpu(self):
+        for to_gpu in [True, False]:
+            instance1 = LoadImage(reader="NibabelReader", to_gpu=to_gpu)
+            self.assertIsInstance(instance1, LoadImage)
+
+            instance2 = LoadImaged("image", reader="NibabelReader", to_gpu=to_gpu)
+            self.assertIsInstance(instance2, LoadImaged)
+
     @SkipIfNoModule("itk")
     @SkipIfNoModule("nibabel")
     @SkipIfNoModule("PIL")
@@ -58,6 +69,14 @@ def test_readers(self):
         inst = NrrdReader()
         self.assertIsInstance(inst, NrrdReader)

+    @SkipIfNoModule("nibabel")
+    @SkipIfNoModule("cupy")
+    @SkipIfNoModule("kvikio")
+    def test_readers_to_gpu(self):
+        for to_gpu in [True, False]:
+            inst = NibabelReader(to_gpu=to_gpu)
+            self.assertIsInstance(inst, NibabelReader)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index 0207079d7d..a3e6d7bcfc 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -29,7 +29,7 @@
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import LoadImage
 from monai.utils import optional_import
-from tests.utils import assert_allclose, skip_if_downloading_fails, testing_data_config
+from tests.utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config

 itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
 ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
@@ -74,6 +74,22 @@ def get_data(self, _obj):

 TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]

+TEST_CASE_GPU_1 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii.gz"], (128, 128, 128)]
+
+TEST_CASE_GPU_2 = [{"reader": "nibabelreader", "to_gpu": True}, ["test_image.nii"], (128, 128, 128)]
+
+TEST_CASE_GPU_3 = [
+    {"reader": "nibabelreader", "to_gpu": True},
+    ["test_image.nii", "test_image2.nii", "test_image3.nii"],
+    (3, 128, 128, 128),
+]
+
+TEST_CASE_GPU_4 = [
+    {"reader": "nibabelreader", "to_gpu": True},
+    ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
+    (3, 128, 128, 128),
+]
+
 TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]

 TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
@@ -196,6 +212,29 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
             assert_allclose(result.affine, torch.eye(4))
             self.assertTupleEqual(result.shape, expected_shape)

+    @SkipIfNoModule("nibabel")
+    @SkipIfNoModule("cupy")
+    @SkipIfNoModule("kvikio")
+    @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
+    def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
+        test_image = np.random.rand(128, 128, 128)
+        with tempfile.TemporaryDirectory() as tempdir:
+            for i, name in enumerate(filenames):
+                filenames[i] = os.path.join(tempdir, name)
+                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
+            result = LoadImage(image_only=True, **input_param)(filenames)
+            ext = "".join(Path(name).suffixes)
+            self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
+            self.assertEqual(result.meta["space"], "RAS")
+            assert_allclose(result.affine, torch.eye(4))
+            self.assertTupleEqual(result.shape, expected_shape)
+
+            # verify gpu and cpu loaded data are the same
+            input_param_cpu = input_param.copy()
+            input_param_cpu["to_gpu"] = False
+            result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
+            self.assertTrue(torch.equal(result_cpu, result.cpu()))
+
     @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
     def test_itk_reader(self, input_param, filenames, expected_shape):
         test_image = np.random.rand(128, 128, 128)

From 996e876e7542f683508aa04e74b97e284bbde72b Mon Sep 17 00:00:00 2001
From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Date: Mon, 30 Dec 2024 21:14:55 +0800
Subject: [PATCH 05/10] 8274-mitigate-gpu-load-check (#8275)

Fixes #8274.

### Description
I tried to test with an A100 using the same container, but could not reproduce the issue. Therefore, I think we can make a small change to the test, and if the same issues still occur, I will investigate further.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Yiheng Wang
---
 tests/test_load_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index a3e6d7bcfc..aa8b71b7fa 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -233,7 +233,7 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
             input_param_cpu = input_param.copy()
             input_param_cpu["to_gpu"] = False
             result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
-            self.assertTrue(torch.equal(result_cpu, result.cpu()))
+            self.assertTrue(torch.allclose(result_cpu, result.cpu(), atol=1e-6))

     @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
     def test_itk_reader(self, input_param, filenames, expected_shape):

From 9eb0a8c41bcb0f8d95c8c7e99ea9d40d0835b7dc Mon Sep 17 00:00:00 2001
From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Date: Mon, 6 Jan 2025 18:45:21 +0800
Subject: [PATCH 06/10] 8274 Relax gpu load check (#8282)

Related to #8274, this PR is intended to help diagnose the remaining issue. When I used the same environment as the nightly test, the error was not reproduced, so I hope the new change will surface more information about the error.

### Description
Relax the GPU-load comparison in `test_load_image.py` by replacing `torch.allclose` with the `assert_allclose` test helper, which reports the mismatching values on failure.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Yiheng Wang
---
 tests/test_load_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index aa8b71b7fa..dc0af5e97e 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -233,7 +233,7 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
             input_param_cpu = input_param.copy()
             input_param_cpu["to_gpu"] = False
             result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)
-            self.assertTrue(torch.allclose(result_cpu, result.cpu(), atol=1e-6))
+            assert_allclose(result_cpu, result.cpu(), atol=1e-6)

     @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
     def test_itk_reader(self, input_param, filenames, expected_shape):

From 1c00ea22ead498551379cdcdbd7970ca7a6d9464 Mon Sep 17 00:00:00 2001
From: Pooya Mohammadi Kazaj
Date: Fri, 10 Jan 2025 16:35:44 +0100
Subject: [PATCH 07/10] bug: Fix PatchMerging duplicate merging (#8285)

Fixes # .

### Description
Fixes duplicate octant selection in `PatchMerging.forward` (issue #8284), so that all eight 2x2x2 sub-volumes are concatenated instead of three of them being repeated; see the explanation below.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
Fixing issue #8284. In this format there are no duplicates:

```
t = [
    (0, 0, 0),
    (1, 0, 0),
    (0, 1, 0),
    (0, 0, 1),
    (1, 0, 1),
    (1, 1, 0),
    (0, 1, 1),
    (1, 1, 1),
]
print(set(t))
# {(1, 0, 1), (1, 1, 0), (0, 1, 0), (0, 0, 0), (1, 0, 0), (0, 0, 1), (1, 1, 1), (0, 1, 1)}
```

---------

Signed-off-by: pooya-mohammadi
---
 monai/networks/nets/swin_unetr.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 77f0d2ec2f..cfc5dda41f 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -782,9 +782,9 @@ def forward(self, x):
         x1 = x[:, 1::2, 0::2, 0::2, :]
         x2 = x[:, 0::2, 1::2, 0::2, :]
         x3 = x[:, 0::2, 0::2, 1::2, :]
-        x4 = x[:, 1::2, 0::2, 1::2, :]
-        x5 = x[:, 0::2, 1::2, 0::2, :]
-        x6 = x[:, 0::2, 0::2, 1::2, :]
+        x4 = x[:, 1::2, 1::2, 0::2, :]
+        x5 = x[:, 1::2, 0::2, 1::2, :]
+        x6 = x[:, 0::2, 1::2, 1::2, :]
         x7 = x[:, 1::2, 1::2, 1::2, :]
         x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
         x = self.norm(x)

From eaa901ce5624391f7ae7a707ee14a26a6244e3e7 Mon Sep 17 00:00:00 2001
From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Date: Tue, 14 Jan 2025 15:27:12 +0800
Subject: [PATCH 08/10] Fix test load image issue (#8297)

Fixes https://github.com/Project-MONAI/MONAI/issues/8274.

### Description
The new test has already been tested with the same 24.08 + A100 environment. I ran some tests but could not reproduce the original test case error (NaN values or significantly small/large data). Since only the 24.08 base image has the issue (24.10 does not), I decided to use a different test case for 24.08 and prepared this PR.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
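As a hedged diagnostic sketch (not part of this change), the failure mode described above can be surfaced by loading the same hypothetical file on CPU and GPU and checking for NaNs or extreme values before any exact comparison:

```
# Hypothetical repro helper for the 24.08 issue: report NaNs and value ranges
# for CPU- and GPU-loaded copies of one file before comparing them.
import numpy as np
from monai.transforms import LoadImage

for to_gpu in (False, True):
    img = LoadImage(image_only=True, reader="NibabelReader", to_gpu=to_gpu)("test_image.nii")
    arr = np.asarray(img.cpu())
    print(f"to_gpu={to_gpu}: nan={np.isnan(arr).any()}, min={arr.min()}, max={arr.max()}")
```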
---------

Signed-off-by: Yiheng Wang
---
 tests/test_load_image.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index dc0af5e97e..498b9972b4 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -217,7 +217,12 @@ def test_nibabel_reader(self, input_param, filenames, expected_shape):
     @SkipIfNoModule("kvikio")
     @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])
     def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
-        test_image = np.random.rand(128, 128, 128)
+        if torch.__version__.endswith("nv24.8"):
+            # related issue: https://github.com/Project-MONAI/MONAI/issues/8274
+            # for this version, use randint test case to avoid the issue
+            test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()
+        else:
+            test_image = np.random.rand(128, 128, 128)
         with tempfile.TemporaryDirectory() as tempdir:
             for i, name in enumerate(filenames):
                 filenames[i] = os.path.join(tempdir, name)

From 56d1f621964ba07b0f50d775a8b46c33c2fb1784 Mon Sep 17 00:00:00 2001
From: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Date: Wed, 15 Jan 2025 15:08:47 +0800
Subject: [PATCH 09/10] Using LocalStore in Zarr v3 (#8299)

Fixes #8298

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
---
 tests/test_zarr_avg_merger.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py
index de7fad48da..a52dbceb4c 100644
--- a/tests/test_zarr_avg_merger.py
+++ b/tests/test_zarr_avg_merger.py
@@ -19,11 +19,18 @@
 from torch.nn.functional import pad

 from monai.inferers import ZarrAvgMerger
-from monai.utils import optional_import
+from monai.utils import get_package_version, optional_import, version_geq
 from tests.utils import assert_allclose

 np.seterr(divide="ignore", invalid="ignore")
 zarr, has_zarr = optional_import("zarr")
+if has_zarr:
+    if version_geq(get_package_version("zarr"), "3.0.0"):
+        directory_store = zarr.storage.LocalStore("test.zarr")
+    else:
+        directory_store = zarr.storage.DirectoryStore("test.zarr")
+else:
+    directory_store = None
 numcodecs, has_numcodecs = optional_import("numcodecs")

 TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)
@@ -154,7 +161,7 @@

 # explicit directory store
 TEST_CASE_10_DIRECTORY_STORE = [
-    dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")),
+    dict(merged_shape=TENSOR_4x4.shape, store=directory_store),
     [
         (TENSOR_4x4[..., :2, :2], (0, 0)),
         (TENSOR_4x4[..., :2, 2:], (0, 2)),

From e39bad9ab433d150929b82a1eba3117276d7c254 Mon Sep 17 00:00:00 2001
From: advcu <65158236+advcu987@users.noreply.github.com>
Date: Mon, 20 Jan 2025 07:26:06 +0100
Subject: [PATCH 10/10] 8267 fix normalize intensity (#8286)

Fixes #8267.
### Description
Fix channel-wise intensity normalization for integer type inputs.

### Types of changes
- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: advcu987
Signed-off-by: advcu <65158236+advcu987@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
---
 monai/transforms/intensity/array.py |  4 ++++
 tests/test_normalize_intensity.py   | 21 +++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 20000c52c4..8fe658ad3e 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
     mean and std on each channel separately.
     When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
     be the number of image channels if they are not None.
+    If the input is not of floating point type, it will be converted to float32.

     Args:
         subtrahend: the amount to subtract by (usually the mean).
@@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
             if self.divisor is not None and len(self.divisor) != len(img):
                 raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

+            if not img.dtype.is_floating_point:
+                img, *_ = convert_data_type(img, dtype=torch.float32)
+
             for i, d in enumerate(img):
                 img[i] = self._normalize(  # type: ignore
                     d,
diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py
index 72ebf579e1..7efd0d83e5 100644
--- a/tests/test_normalize_intensity.py
+++ b/tests/test_normalize_intensity.py
@@ -108,6 +108,27 @@ def test_channel_wise(self, im_type):
         normalized = normalizer(input_data)
         assert_allclose(normalized, im_type(expected), type_test="tensor")

+    @parameterized.expand([[p] for p in TEST_NDARRAYS])
+    def test_channel_wise_int(self, im_type):
+        normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
+        input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))
+        expected = np.array(
+            [
+                [
+                    [-1.593255, -1.3035723, -1.0138896, -0.7242068],
+                    [-0.4345241, -0.1448414, 0.1448414, 0.4345241],
+                    [0.7242068, 1.0138896, 1.3035723, 1.593255],
+                ],
+                [
+                    [-1.593255, -1.3035723, -1.0138896, -0.7242068],
+                    [-0.4345241, -0.1448414, 0.1448414, 0.4345241],
+                    [0.7242068, 1.0138896, 1.3035723, 1.593255],
+                ],
+            ]
+        )
+        normalized = normalizer(input_data)
+        assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7)  # tolerance
+
     @parameterized.expand([[p] for p in TEST_NDARRAYS])
     def test_value_errors(self, im_type):
         input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
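For illustration, a minimal sketch of the fixed behaviour, mirroring the new unit test above (the input values are arbitrary):

```
# Sketch: an integer input is promoted to float32 before channel-wise
# normalization, so per-channel results are no longer truncated to integers.
import torch
from monai.transforms import NormalizeIntensity

img = torch.arange(1, 25).reshape(2, 3, 4)  # two-channel integer image, as in the new test
out = NormalizeIntensity(nonzero=True, channel_wise=True)(img)
print(out.dtype)             # torch.float32
print(out[0].mean().item())  # approximately 0: each channel normalized independently
```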