diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 16e32f0c6..57b36d8f0 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -784,7 +784,7 @@ def test_lazy_imports(self) -> None: self.assertFalse("datapipes" in torchdata.__dict__) - from torchdata import _extension, datapipes as dp, janitor # noqa # noqa + from torchdata import datapipes as dp, janitor # noqa # noqa self.assertTrue("datapipes" in torchdata.__dict__) dp.iter.IterableWrapper([1, 2]) diff --git a/torchdata/__init__.py b/torchdata/__init__.py index a4d8bcc39..d5257639d 100644 --- a/torchdata/__init__.py +++ b/torchdata/__init__.py @@ -4,13 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# import importlib - -from torchdata import _extension # noqa: F401 - -from . import datapipes - -janitor = datapipes.utils.janitor +import importlib try: from .version import __version__ # noqa: F401 @@ -26,15 +20,18 @@ assert __all__ == sorted(__all__) -# # Lazy import all modules -# def __getattr__(name): -# if name == "janitor": -# return importlib.import_module(".datapipes.utils." + name, __name__) -# else: -# try: -# return importlib.import_module("." + name, __name__) -# except ModuleNotFoundError: -# if name in globals(): -# return globals()[name] -# else: -# raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None +# Lazy import all modules +def __getattr__(name): + if name in ("janitor", "datapipes"): + from torchdata import _extension # noqa: F401 + + if name == "janitor": + return importlib.import_module(".datapipes.utils." + name, __name__) + else: + try: + return importlib.import_module("." + name, __name__) + except ModuleNotFoundError: + if name in globals(): + return globals()[name] + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None