Merge remote-tracking branch 'upstream/dev' into 8185-tests-refactor
garciadias committed Jan 20, 2025
2 parents d2e7a2e + e39bad9 commit 5632fdf
Showing 14 changed files with 251 additions and 39 deletions.
39 changes: 38 additions & 1 deletion monai/apps/utils.py
@@ -15,6 +15,7 @@
import json
import logging
import os
import re
import shutil
import sys
import tarfile
@@ -30,7 +31,9 @@
from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

requests, has_requests = optional_import("requests")
gdown, has_gdown = optional_import("gdown", "4.7.3")
BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup")

if TYPE_CHECKING:
from tqdm import tqdm
@@ -298,6 +301,29 @@ def extractall(
)


def get_filename_from_url(data_url: str) -> str:
"""
Get the filename from the URL link.
"""
try:
response = requests.head(data_url, allow_redirects=True)
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename = re.findall('filename="?([^";]+)"?', content_disposition)
if filename:
return str(filename[0])
if "drive.google.com" in data_url:
response = requests.get(data_url)
if "text/html" in response.headers.get("Content-Type", ""):
soup = BeautifulSoup(response.text, "html.parser")
filename_div = soup.find("span", {"class": "uc-name-size"})
if filename_div:
return str(filename_div.find("a").text)
return _basename(data_url)
except Exception as e:
raise Exception(f"Error processing URL: {e}") from e
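As a side note, a minimal sketch of how the Content-Disposition parsing above behaves; the header strings here are illustrative, not from a real server:

import re

# Both quoted and unquoted filename parameters are handled by the pattern.
for header in ('attachment; filename="brain_mr.nii.gz"', "attachment; filename=data.tar.gz"):
    print(re.findall('filename="?([^";]+)"?', header)[0])
# -> brain_mr.nii.gz
# -> data.tar.gz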


def download_and_extract(
url: str,
filepath: PathLike = "",
@@ -327,7 +353,18 @@ def download_and_extract(
be False.
progress: whether to display progress bar.
"""
url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes)
filepath_ext = "".join(Path(_basename(filepath)).suffixes)
if filepath not in ["", "."]:
if filepath_ext == "":
new_filepath = Path(filepath).with_suffix(url_filename_ext)
logger.warning(
f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
)
filepath = new_filepath
if filepath_ext and filepath_ext != url_filename_ext:
raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but got {filepath_ext}")
with tempfile.TemporaryDirectory() as tmp_dir:
- filename = filepath or Path(tmp_dir, _basename(url)).resolve()
+ filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
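A hedged usage sketch of the new extension handling (the URL and paths below are placeholders): a filepath with no suffix gets the suffix inferred from the URL appended, while a conflicting suffix raises a ValueError.

from monai.apps.utils import download_and_extract

url = "https://example.com/dataset.tar.gz"  # placeholder URL
download_and_extract(url, filepath="./downloads/dataset", output_dir="./data")
# warns, then saves to ./downloads/dataset.tar.gz (suffix taken from the URL)
# download_and_extract(url, filepath="./downloads/dataset.zip")  # would raise ValueError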
86 changes: 77 additions & 9 deletions monai/data/image_reader.py
@@ -12,8 +12,11 @@
from __future__ import annotations

import glob
import gzip
import io
import os
import re
import tempfile
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -51,6 +54,9 @@
pydicom, has_pydicom = optional_import("pydicom")
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)

cp, has_cp = optional_import("cupy")
kvikio, has_kvikio = optional_import("kvikio")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]


@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
)


-def _stack_images(image_list: list, meta_dict: dict):
+def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
if len(image_list) <= 1:
return image_list[0]
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
if to_cupy and has_cp:
return cp.concatenate(image_list, axis=channel_dim)
return np.concatenate(image_list, axis=channel_dim)
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
if to_cupy and has_cp:
return cp.stack(image_list, axis=0)
return np.stack(image_list, axis=0)
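A small sketch of what _stack_images returns in each branch; the NumPy path is shown, and the CuPy path mirrors it when to_cupy=True and CuPy is available:

import numpy as np

# channel dim recorded as 0: pieces are concatenated along that axis
print(np.concatenate([np.zeros((1, 4, 4)), np.zeros((1, 4, 4))], axis=0).shape)  # (2, 4, 4)
# no channel dim recorded: a new leading channel axis is stacked on
print(np.stack([np.zeros((4, 4)), np.zeros((4, 4))], axis=0).shape)  # (2, 4, 4)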


@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
Load NIfTI format images based on Nibabel library.
Args:
- as_closest_canonical: if True, load the image as closest to canonical axis format.
- squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
channel_dim: the channel dimension of the input image, default is None.
this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
if None, `original_channel_dim` will be either `no_channel` or `-1`.
most Nifti files are usually "channel last", no need to specify this argument for them.
+ as_closest_canonical: if True, load the image as closest to canonical axis format.
+ squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
Default is False. CuPy and Kvikio are required for this option.
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
In practical use, it's recommended to add a warm-up call before the actual loading.
A related tutorial will be prepared in the future, and the document will be updated accordingly.
kwargs: additional args for `nibabel.load` API. more details about available args:
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
@@ -880,14 +896,42 @@ def __init__(
channel_dim: str | int | None = None,
as_closest_canonical: bool = False,
squeeze_non_spatial_dims: bool = False,
to_gpu: bool = False,
**kwargs,
):
super().__init__()
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
self.as_closest_canonical = as_closest_canonical
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
if to_gpu and (not has_cp or not has_kvikio):
warnings.warn(
"NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
)
to_gpu = False

if to_gpu:
self.warmup_kvikio()

self.to_gpu = to_gpu
self.kwargs = kwargs

def warmup_kvikio(self):
"""
Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
This can accelerate the data loading process when `to_gpu` is set to True.
"""
if has_cp and has_kvikio:
a = cp.arange(100)
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file_name = tmp_file.name
f = kvikio.CuFile(tmp_file_name, "w")
f.write(a)
f.close()

b = cp.empty_like(a)
f = kvikio.CuFile(tmp_file_name, "r")
f.read(b)

def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
"""
Verify whether the specified file or files format is supported by Nibabel reader.
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
img_: list[Nifti1Image] = []

filenames: Sequence[PathLike] = ensure_tuple(data)
self.filenames = filenames
kwargs_ = self.kwargs.copy()
kwargs_.update(kwargs)
for name in filenames:
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
"""
# TODO: the actual type is list[np.ndarray | cp.ndarray]
# should figure out how to define correct types without having cupy not found error
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
img_array: list[np.ndarray] = []
compatible_meta: dict = {}

-for i in ensure_tuple(img):
+for i, filename in zip(ensure_tuple(img), self.filenames):
header = self._get_meta_dict(i)
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.AFFINE] = self._get_affine(i)
header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
header[MetaKeys.SPACE] = SpaceKeys.RAS
-data = self._get_array_data(i)
+data = self._get_array_data(i, filename)
if self.squeeze_non_spatial_dims:
for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
if data.shape[d - 1] == 1:
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
_copy_compatible_dict(header, compatible_meta)

-return _stack_images(img_array, compatible_meta), compatible_meta
+return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta

def _get_meta_dict(self, img) -> dict:
"""
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
spatial_rank = max(min(ndim, 3), 1)
return np.asarray(size[:spatial_rank])

-def _get_array_data(self, img):
+def _get_array_data(self, img, filename):
"""
Get the raw array data of the image, converted to Numpy array.
Args:
img: a Nibabel image object loaded from an image file.
"""
filename: file name of the image.
"""
if self.to_gpu:
file_size = os.path.getsize(filename)
image = cp.empty(file_size, dtype=cp.uint8)
with kvikio.CuFile(filename, "r") as f:
f.read(image)
if filename.endswith(".nii.gz"):
# for compressed data, have to transfer to CPU to decompress
# and then transfer back to GPU. It is not efficient compared to .nii file
# and may be slower than CPU loading in some cases.
warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
compressed_data = cp.asnumpy(image)
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
decompressed_data = gz_file.read()

image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
data_shape = img.shape
data_offset = img.dataobj.offset
data_dtype = img.dataobj.dtype
return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
return np.asanyarray(img.dataobj, order="C")
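A hedged usage sketch of the GPU loading path (CuPy and Kvikio required; the file path below is a placeholder). Uncompressed .nii files benefit most, since .nii.gz data still round-trips through the CPU for decompression:

from monai.data import NibabelReader

reader = NibabelReader(to_gpu=True)      # warms up Kvikio at construction
img = reader.read("/data/example.nii")   # placeholder path
array, meta = reader.get_data(img)       # array is a cupy.ndarray on the GPU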


1 change: 0 additions & 1 deletion monai/data/meta_tensor.py
@@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None) # potentially ascontiguousarray

# if not tracking metadata, return `torch.Tensor`
if not isinstance(img, MetaTensor):
return img
22 changes: 11 additions & 11 deletions monai/engines/workflow.py
@@ -12,7 +12,7 @@
from __future__ import annotations

import warnings
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence, Sized
from typing import TYPE_CHECKING, Any

import torch
@@ -121,24 +121,24 @@ def __init__(
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
-if iteration_update is not None:
-    super().__init__(iteration_update)
-else:
-    super().__init__(self._iteration)
+super().__init__(self._iteration if iteration_update is None else iteration_update)

-if isinstance(data_loader, DataLoader):
-    sampler = data_loader.__dict__["sampler"]
+sampler = getattr(data_loader, "sampler", None)

# set the epoch value for DistributedSampler objects when an epoch starts
if isinstance(sampler, DistributedSampler):

@self.on(Events.EPOCH_STARTED)
def set_sampler_epoch(engine: Engine) -> None:
sampler.set_epoch(engine.state.epoch)

-if epoch_length is None:
-    epoch_length = len(data_loader)
-else:
-    if epoch_length is None:
-        raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+# if the epoch_length isn't given, attempt to get it from the length of the data loader
+if epoch_length is None and isinstance(data_loader, Sized):
+    try:
+        epoch_length = len(data_loader)
+    except TypeError:  # raised when data_loader has an iterable dataset with no length, or is some other type
+        pass  # deliberately leave epoch_length as None

# set all sharable data for the workflow based on Ignite engine.state
self.state: Any = State(
@@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
iteration=0,
epoch=0,
max_epochs=max_epochs,
-epoch_length=epoch_length,
+epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
output=None,
batch=None,
metrics={},
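To illustrate the new epoch_length behavior, a sketch assuming an IterableDataset with no __len__: len(data_loader) raises TypeError, so epoch_length deliberately stays None unless the caller supplies it:

from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)

loader = DataLoader(Stream())
epoch_length = None
try:
    epoch_length = len(loader)  # the underlying iterable dataset has no __len__
except TypeError:
    pass  # left as None, mirroring the workflow logic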
6 changes: 3 additions & 3 deletions monai/networks/nets/swin_unetr.py
@@ -782,9 +782,9 @@ def forward(self, x):
x1 = x[:, 1::2, 0::2, 0::2, :]
x2 = x[:, 0::2, 1::2, 0::2, :]
x3 = x[:, 0::2, 0::2, 1::2, :]
-x4 = x[:, 1::2, 0::2, 1::2, :]
-x5 = x[:, 0::2, 1::2, 0::2, :]
-x6 = x[:, 0::2, 0::2, 1::2, :]
+x4 = x[:, 1::2, 1::2, 0::2, :]
+x5 = x[:, 1::2, 0::2, 1::2, :]
+x6 = x[:, 0::2, 1::2, 1::2, :]
x7 = x[:, 1::2, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
x = self.norm(x)
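The corrected slices enumerate all eight (d, h, w) parity combinations exactly once; previously x5 and x6 duplicated x2 and x3, so two octants were dropped. A quick coverage check (a sketch, independent of the model code; the enumeration order here differs from the layer's, but the coverage property is the same):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2, 1)
octants = [x[:, d::2, h::2, w::2, :] for d in (0, 1) for h in (0, 1) for w in (0, 1)]
merged = torch.cat(octants, -1)
print(sorted(merged.flatten().tolist()))  # [0.0, 1.0, ..., 7.0]: each voxel appears once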
4 changes: 4 additions & 0 deletions monai/transforms/intensity/array.py
@@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
mean and std on each channel separately.
When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
be the number of image channels if they are not None.
If the input is not of floating-point type, it will be converted to float32.
Args:
subtrahend: the amount to subtract by (usually the mean).
@@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.divisor is not None and len(self.divisor) != len(img):
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

if not img.dtype.is_floating_point:
img, *_ = convert_data_type(img, dtype=torch.float32)

for i, d in enumerate(img):
img[i] = self._normalize( # type: ignore
d,
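A short sketch of the new dtype handling, assuming a torch tensor input: integer images are promoted to float32 before the channel-wise in-place writes, so normalized values are no longer truncated to integers:

import torch
from monai.transforms import NormalizeIntensity

img = torch.tensor([[0, 2, 4]])  # int64 input, one channel
out = NormalizeIntensity(channel_wise=True)(img)
print(out.dtype)  # torch.float32; the normalized values keep their fractional parts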
1 change: 0 additions & 1 deletion monai/transforms/io/array.py
@@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered: {self.readers}.\n{msg}"
)

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]