Skip to content

Commit

Permalink
add _extension import
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed Apr 24, 2024
1 parent 358f475 commit e0b12ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
2 changes: 1 addition & 1 deletion test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def test_lazy_imports(self) -> None:

self.assertFalse("datapipes" in torchdata.__dict__)

from torchdata import _extension, datapipes as dp, janitor # noqa # noqa
from torchdata import datapipes as dp, janitor # noqa # noqa

self.assertTrue("datapipes" in torchdata.__dict__)
dp.iter.IterableWrapper([1, 2])
Expand Down
35 changes: 16 additions & 19 deletions torchdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# import importlib

from torchdata import _extension # noqa: F401

from . import datapipes

janitor = datapipes.utils.janitor
import importlib

try:
from .version import __version__ # noqa: F401
Expand All @@ -26,15 +20,18 @@
assert __all__ == sorted(__all__)


# # Lazy import all modules
# def __getattr__(name):
# if name == "janitor":
# return importlib.import_module(".datapipes.utils." + name, __name__)
# else:
# try:
# return importlib.import_module("." + name, __name__)
# except ModuleNotFoundError:
# if name in globals():
# return globals()[name]
# else:
# raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None
# Lazy import all modules
def __getattr__(name):
if name in ("janitor", "datapipes"):
from torchdata import _extension # noqa: F401

if name == "janitor":
return importlib.import_module(".datapipes.utils." + name, __name__)
else:
try:
return importlib.import_module("." + name, __name__)
except ModuleNotFoundError:
if name in globals():
return globals()[name]
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

0 comments on commit e0b12ef

Please sign in to comment.