diff --git a/src/torchio/__init__.py b/src/torchio/__init__.py
index d4b83ece..776ab99b 100644
--- a/src/torchio/__init__.py
+++ b/src/torchio/__init__.py
@@ -16,6 +16,9 @@
 from .data import LabelSampler
 from .data import Queue
 from .data import ScalarImage
+from .data import LazyImage
+from .data import LazyScalarImage
+from .data import LazyLabelMap
 from .data import Subject
 from .data import SubjectsDataset
 from .data import SubjectsLoader
@@ -36,6 +39,9 @@
     'Image',
     'ScalarImage',
     'LabelMap',
+    'LazyImage',
+    'LazyScalarImage',
+    'LazyLabelMap',
     'Queue',
     'Subject',
     'datasets',
diff --git a/src/torchio/data/__init__.py b/src/torchio/data/__init__.py
index d3250c98..a875158d 100644
--- a/src/torchio/data/__init__.py
+++ b/src/torchio/data/__init__.py
@@ -2,6 +2,9 @@
 from .image import Image
 from .image import LabelMap
 from .image import ScalarImage
+from .image import LazyImage
+from .image import LazyScalarImage
+from .image import LazyLabelMap
 from .inference import GridAggregator
 from .loader import SubjectsLoader
 from .queue import Queue
@@ -20,6 +23,9 @@
     'Image',
     'ScalarImage',
     'LabelMap',
+    'LazyImage',
+    'LazyScalarImage',
+    'LazyLabelMap',
     'GridSampler',
     'GridAggregator',
     'PatchSampler',
diff --git a/src/torchio/data/image.py b/src/torchio/data/image.py
index aba6c15f..c89ec93c 100644
--- a/src/torchio/data/image.py
+++ b/src/torchio/data/image.py
@@ -917,3 +917,67 @@ def count_labels(self) -> dict[int, int]:
         counter = Counter(values_list)
         counts = {label: counter[label] for label in sorted(counter)}
         return counts
+
+
+class LazyImage(Image):
+    """Image that keeps its data on disk until it is explicitly loaded."""
+
+    def load(self):
+        if self._is_multipath():
+            message = 'Multiple paths are not supported for LazyImage'
+            raise RuntimeError(message)
+
+        tensor, affine = self.read_and_check(self.path)
+        self.set_data(tensor)
+        self.affine = affine
+        self._loaded = True
+
+    def _parse_tensor(
+        self,
+        tensor: Optional[TypeData],
+        none_ok: bool = True,
+    ) -> Optional[TypeData]:
+        if tensor is None:
+            if none_ok:
+                return None
+            raise RuntimeError('Input tensor cannot be None')
+
+        ndim = tensor.ndim
+        if ndim != 4:
+            raise ValueError(f'Input tensor must be 4D, but it is {ndim}D')
+
+        return tensor
+
+    @staticmethod
+    def _parse_tensor_shape(tensor: torch.Tensor) -> TypeData:
+        # Skip the reshaping done by the parent class, as it would require
+        # loading the whole volume. A bad shape cannot be repaired here, but
+        # _parse_tensor already checks that the input is 4D.
+        return tensor
+
+    def __repr__(self):
+        # An alternative would be to modify Image.__repr__ so that it does
+        # not call self.data.type(), which is only defined for tensors.
+        properties = [
+            f'shape: {self.shape}',
+            f'spacing: {self.get_spacing_string()}',
+            f'orientation: {"".join(self.orientation)}+',
+        ]
+        if self._loaded:
+            # Report .dtype, which tensors and lazy array proxies both
+            # define, and skip the memory usage shown by the parent class
+            properties.append(f'dtype: {self.data.dtype}')
+        else:
+            properties.append(f'path: "{self.path}"')
+
+        properties = '; '.join(properties)
+        string = f'{self.__class__.__name__}({properties})'
+        return string
+
+
+class LazyScalarImage(LazyImage, ScalarImage):
+    pass
+
+
+class LazyLabelMap(LazyImage, LabelMap):
+    pass
diff --git a/src/torchio/transforms/preprocessing/spatial/crop.py b/src/torchio/transforms/preprocessing/spatial/crop.py
index 2895d266..3527a5f1 100644
--- a/src/torchio/transforms/preprocessing/spatial/crop.py
+++ b/src/torchio/transforms/preprocessing/spatial/crop.py
@@ -1,5 +1,6 @@
 import nibabel as nib
 import numpy as np
+import torch

 from ....data.subject import Subject
 from .bounds_transform import BoundsTransform
@@ -48,7 +49,11 @@ def apply_transform(self, sample) -> Subject:
         new_affine[:3, 3] = new_origin
         i0, j0, k0 = index_ini
         i1, j1, k1 = index_fin
-        image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
+        if isinstance(image.data, torch.Tensor):
+            image.set_data(image.data[:, i0:i1, j0:j1, k0:k1].clone())
+        else:
+            # Slicing a lazy array proxy reads only the cropped region
+            image.set_data(torch.as_tensor(image.data[:, i0:i1, j0:j1, k0:k1]))
         image.affine = new_affine
         return sample
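Usage sketch (not part of the diff): a minimal example of the classes added above. The file name is a placeholder, the example assumes a 4D NIfTI file (since _parse_tensor requires 4D input and a lazy proxy cannot be reshaped without reading it), and the lazy-loading behaviour noted in the comments reflects the intent of this diff rather than tested output.

    import nibabel as nib
    import torchio as tio

    # Placeholder path; must point to a 4D (channels-first) NIfTI file.
    nii = nib.load('subject1_t1.nii.gz')

    # Passing the nibabel array proxy instead of a tensor keeps the voxels
    # on disk: _parse_tensor only checks ndim, so the proxy is stored as-is.
    image = tio.LazyScalarImage(tensor=nii.dataobj, affine=nii.affine)
    print(image)  # the custom __repr__ avoids calling .type() on the proxy

    # Crop slices image.data directly, so only the cropped region is read
    # from disk before torch.as_tensor converts it.
    subject = tio.Subject(t1=image)
    cropped = tio.Crop((10, 10, 10, 10, 10, 10))(subject)
    print(cropped['t1'].data.shape)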