Skip to content

Commit

Permalink
Fix data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Dec 28, 2021
1 parent 82895c4 commit 48288dd
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions lightkit/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,20 @@ def __init__(self, dataset: Dataset[T_co], **kwargs: Any):
kwargs.setdefault("collate_fn", collate_tuple)

super().__init__(dataset, **kwargs) # type: ignore
self.custom_batching = self.num_workers == 0 and (
isinstance(self.batch_sampler, RangeBatchSampler)
or (
self.batch_sampler is not None
and hasattr(self.batch_sampler, "sampler")
and isinstance(self.batch_sampler.sampler, RangeBatchSampler)
)
)

def __iter__(self) -> Iterator[Any]: # pylint: disable=inconsistent-return-statements
if not (isinstance(self.dataset, TensorDataset) and self.num_workers == 0):
return super().__iter__()
if not self.custom_batching:
for item in super().__iter__():
yield item
return

for indices in self.batch_sampler:
if isinstance(indices, range):
Expand Down

0 comments on commit 48288dd

Please sign in to comment.