Skip to content

Commit

Permalink
feat(datasets): folder access layer allows recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
LutingWang committed Dec 17, 2023
1 parent db6f651 commit 8f6802d
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions todd/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
'FolderAccessLayer',
]

import itertools
import enum
import pathlib
from abc import ABC
from typing import Iterator, TypeVar
Expand All @@ -13,21 +13,29 @@
VT = TypeVar('VT')


class Action(enum.Enum):
    """How `FolderAccessLayer` treats subfolders of its folder root.

    Members can be looked up by their lowercase string value, e.g.
    ``Action('walk') is Action.WALK``.
    """

    # List every entry directly inside the root, directories included.
    NONE = 'none'
    # Search the root recursively and keep only the files that are found.
    WALK = 'walk'
    # List only the files located directly inside the root.
    FILTER = 'filter'


class FolderAccessLayer(BaseAccessLayer[str, VT], ABC):

def __init__(
    self,
    *args,
    folder_root: Config | None = None,
    filter_directories: bool = False,
    subfolder_action: str | Action = Action.NONE,
    **kwargs,
) -> None:
    """Initialize the folder-backed access layer.

    Args:
        *args: forwarded unchanged to the parent initializer.
        folder_root: config passed to `_build_folder_root`; an empty
            `Config` is used when omitted.
        filter_directories: legacy switch kept for backward
            compatibility. When true and no explicit subfolder action
            is given, it is treated as `Action.FILTER`.
        subfolder_action: an `Action` member, or its string value
            (matched case-insensitively).
        **kwargs: forwarded unchanged to the parent initializer.

    Raises:
        ValueError: if `subfolder_action` is a string that does not
            name an `Action` member.
    """
    super().__init__(*args, **kwargs)
    if folder_root is None:
        folder_root = Config()
    self._build_folder_root(folder_root)

    if isinstance(subfolder_action, str):
        subfolder_action = Action(subfolder_action.lower())
    if filter_directories and subfolder_action is Action.NONE:
        # Keep the legacy flag effective instead of silently ignoring it.
        # NOTE(review): the old flag dropped entries for which `is_dir()`
        # is true, whereas FILTER keeps entries for which `is_file()` is
        # true; confirm the difference does not matter to callers.
        subfolder_action = Action.FILTER

    # The raw flag is still stored for any code that reads it.
    self._filter_directories = filter_directories
    self._subfolder_action = subfolder_action

def _build_folder_root(self, config: Config) -> None:
self._folder_root = pathlib.Path(self._data_root) / self._task_name
Expand All @@ -40,19 +48,35 @@ def touch(self) -> None:
self._folder_root.mkdir(parents=True, exist_ok=True)

def _files(self) -> Iterator[pathlib.Path]:
    """Iterate over the paths selected by the subfolder action.

    * `Action.WALK`: search the folder root recursively and keep the
      paths for which `is_file()` is true.
    * `Action.FILTER`: look only at the direct children of the folder
      root and keep the paths for which `is_file()` is true.
    * otherwise: yield every direct child of the folder root,
      directories included.

    Returns:
        A lazy iterator of paths; nothing is listed until it is
        consumed.
    """
    action = self._subfolder_action

    files: Iterator[pathlib.Path]
    if action is Action.WALK:
        files = self._folder_root.rglob('*')
    else:
        files = self._folder_root.iterdir()

    if action in (Action.WALK, Action.FILTER):
        # rglob('*') also matches directories, so both modes need the
        # same file-only filter.
        files = (path for path in files if path.is_file())
    return files

def _file(self, key: str) -> pathlib.Path:
return self._folder_root / key

def _name(self, path: pathlib.Path) -> str:
return path.name

def _relative_to(self, path: pathlib.Path) -> str:
return str(path.relative_to(self._folder_root))

def __iter__(self) -> Iterator[str]:
    """Iterate over the keys of the entries returned by `_files`.

    With `Action.NONE` and `Action.FILTER` the key is the bare file
    name. With `Action.WALK` the key is the path relative to the folder
    root, so files found in different subfolders keep distinct keys.

    Raises:
        NotImplementedError: if the stored subfolder action is none of
            the three actions handled here.
    """
    action = self._subfolder_action
    if action in (Action.NONE, Action.FILTER):
        to_key = self._name
    elif action is Action.WALK:
        to_key = self._relative_to
    else:
        raise NotImplementedError(action)
    return map(to_key, self._files())

def __len__(self) -> int:
return len(list(self._files()))
Expand Down

0 comments on commit 8f6802d

Please sign in to comment.