From 1d2f57e730188ca195415f833a4a995f2c59a01d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 13 Dec 2024 20:46:47 +0000 Subject: [PATCH 01/11] Modify Workflow to Allow IterableDataset Inputs Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 21 ++++++++++----------- tests/test_iterable_dataset.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3629659db1..8554dc38ef 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -121,24 +121,23 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, ) -> None: - if iteration_update is not None: - super().__init__(iteration_update) - else: - super().__init__(self._iteration) + super().__init__(self._iteration if iteration_update is None else iteration_update) if isinstance(data_loader, DataLoader): - sampler = data_loader.__dict__["sampler"] - if isinstance(sampler, DistributedSampler): + sampler = getattr(data_loader, "sampler", None) + # set the epoch value for DistributedSampler objects when an epoch starts + if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) + # if the epoch_length isn't given, attempt to get it from the length of the data loader if epoch_length is None: - epoch_length = len(data_loader) - else: - if epoch_length is None: - raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") + try: + epoch_length = len(data_loader) + except TypeError: # raised when data_loader is given an iterable dataset which has no length + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( @@ -147,7 +146,7 @@ def set_sampler_epoch(engine: Engine) -> None: iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=epoch_length, + epoch_length=epoch_length, # None when the dataset is iterable and so has no length output=None, batch=None, metrics={}, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cfa711e4c0..5bedf09ba3 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -21,6 +21,9 @@ from monai.data import DataLoader, Dataset, IterableDataset from monai.transforms import Compose, LoadImaged, SimulateDelayd +from monai.engines import SupervisedEvaluator + +import torch.nn as nn class _Stream: @@ -59,6 +62,17 @@ def test_shape(self): for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) + def test_supervisedevaluator(self): + """ + Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader. + """ + data = list(range(10)) + dl = DataLoader(IterableDataset(data)) + evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity()) + evaluator.run() # fails if the epoch length or other internal setup is not done correctly + + self.assertEqual(evaluator.state.iteration, len(data)) + if __name__ == "__main__": unittest.main() From b1db08b46ef09586d2fab00fd34b160a10f13871 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Dec 2024 20:56:29 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8554dc38ef..8fc56449d7 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -127,7 +127,7 @@ def __init__( sampler = getattr(data_loader, "sampler", None) # set the epoch value for DistributedSampler objects when an epoch starts - if isinstance(sampler, DistributedSampler): + if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) From 8e82e12c30815afee4f2efea9f5798f8ea97ed10 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 13 Dec 2024 20:58:14 +0000 Subject: [PATCH 03/11] Modify Workflow to Allow IterableDataset Inputs Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8554dc38ef..15eac1af83 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -127,7 +127,8 @@ def __init__( sampler = getattr(data_loader, "sampler", None) # set the epoch value for DistributedSampler objects when an epoch starts - if isinstance(sampler, DistributedSampler): + if isinstance(sampler, DistributedSampler): + @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) From dfcc77ce864a6455580623dbdfad2a6121ac3c48 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 13 Dec 2024 21:06:10 +0000 Subject: [PATCH 04/11] Modify Workflow to Allow IterableDataset Inputs Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 1 + tests/test_iterable_dataset.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 8fc56449d7..15eac1af83 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -128,6 +128,7 @@ def __init__( # set the epoch value for DistributedSampler objects when an epoch starts if isinstance(sampler, DistributedSampler): + @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 5bedf09ba3..fb554e391c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -18,12 +18,11 @@ import nibabel as nib import numpy as np +import torch.nn as nn from monai.data import DataLoader, Dataset, IterableDataset -from monai.transforms import Compose, LoadImaged, SimulateDelayd from monai.engines import SupervisedEvaluator - -import torch.nn as nn +from monai.transforms import Compose, LoadImaged, SimulateDelayd class _Stream: From 10d8c9a703273bcafbbb1385b4d5bd1c61b5fdd7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 13:56:23 +0000 Subject: [PATCH 05/11] Update epoch length block Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 15eac1af83..6ddec837d4 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -133,12 +133,12 @@ def __init__( def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - # if the epoch_length isn't given, attempt to get it from the length of the data loader - if epoch_length is None: - try: - epoch_length = len(data_loader) - except TypeError: # raised when data_loader is given an iterable dataset which has no length - pass # deliberately leave epoch_length as None + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None: + try: + epoch_length = len(data_loader) + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( From aac82f567f3647f7b33636545fd760125b44e841 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 14:20:18 +0000 Subject: [PATCH 06/11] Trying a better way of getting length Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 6ddec837d4..da1fefe8fd 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Sized import torch import torch.distributed as dist @@ -133,12 +133,9 @@ def __init__( def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - # if the epoch_length isn't given, attempt to get it from the length of the data loader - if epoch_length is None: - try: - epoch_length = len(data_loader) - except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type - pass # deliberately leave epoch_length as None + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader.dataset, Sized): + epoch_length = len(data_loader.dataset) # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( From 4aca7e6f35bef84e3bba359572f48077c1a6c4f8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:21:50 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/workflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index da1fefe8fd..ddffc31abf 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -13,7 +13,8 @@ import warnings from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Any, Sized +from typing import TYPE_CHECKING, Any +from collections.abc import Sized import torch import torch.distributed as dist From 83e0ae6f9134af759bca0a27e2931518b9716b1a Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 14:25:54 +0000 Subject: [PATCH 08/11] DCO Remediation Commit for Eric Kerfoot I, Eric Kerfoot , hereby add my Signed-off-by to this commit: 10d8c9a703273bcafbbb1385b4d5bd1c61b5fdd7 I, Eric Kerfoot , hereby add my Signed-off-by to this commit: aac82f567f3647f7b33636545fd760125b44e841 Signed-off-by: Eric Kerfoot From 97dd0df6f9f764bc8e79bc7800ec4e59206028ab Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 14:38:22 +0000 Subject: [PATCH 09/11] Trying a better way of getting length Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index ddffc31abf..6734294cb1 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -12,9 +12,8 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence, Sized from typing import TYPE_CHECKING, Any -from collections.abc import Sized import torch import torch.distributed as dist From ce4360cd57fc2e5244119c6397e75ad05c62defc Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 15:44:32 +0000 Subject: [PATCH 10/11] Slight fix to how epoch length is guessed Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 6734294cb1..0c36da6d3d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -133,9 +133,12 @@ def __init__( def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - # if the epoch_length isn't given, attempt to get it from the length of the data loader - if epoch_length is None and isinstance(data_loader.dataset, Sized): - epoch_length = len(data_loader.dataset) + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader, Sized): + try: + epoch_length = len(data_loader) + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( From dc4e7d33b2cc4893195a8f1b767d1d1d0ff5f1b4 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 15:46:54 +0000 Subject: [PATCH 11/11] DCO Remediation Commit for Eric Kerfoot I, Eric Kerfoot , hereby add my Signed-off-by to this commit: 10d8c9a703273bcafbbb1385b4d5bd1c61b5fdd7 I, Eric Kerfoot , hereby add my Signed-off-by to this commit: aac82f567f3647f7b33636545fd760125b44e841 Signed-off-by: Eric Kerfoot