
Commit

generator to iterator
ramanishsingh committed Feb 11, 2025
1 parent db01e08 commit 6b7a106
Showing 2 changed files with 127 additions and 189 deletions.
104 changes: 74 additions & 30 deletions torchdata/stateful_dataloader/sampler.py
@@ -19,20 +19,86 @@ class StatefulRandomSamplerIterator(Iterator[list[int]], Stateful):
     _GENERATOR = "generator"
     _YIELDED = "yielded"
 
-    def __init__(self, sampler, parent_iterator=None):
+    def __init__(self, sampler):
         self.sampler = sampler
         self.generator_state = self.sampler.generator.get_state()
         self.yielded = 0
         self.next_yielded = None
-        self.parent_iterator = parent_iterator or self._get_iterator()
+        self.n = len(sampler.data_source)
+        self.generator = sampler.generator
+        self.replacement = sampler.replacement
+        self.num_samples = sampler.num_samples
+        self.chunk_size = 32
+        self.chunk_index = 0
+        self.perm_index = 0
+        self.perm = None
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        val = next(self.parent_iterator)
-        self.yielded += 1
-        return val
+        if self.replacement:
+            num_full_chunks = self.num_samples // self.chunk_size
+            remainder = self.num_samples % self.chunk_size
+            if self.chunk_index < num_full_chunks:
+                if self.perm is None or not self.perm:
+                    self.perm = torch.randint(
+                        high=self.n,
+                        size=(self.chunk_size,),
+                        dtype=torch.int64,
+                        generator=self.generator,
+                    ).tolist()
+                    self.perm_index = 0
+                value = self.perm[self.perm_index]
+                self.perm_index += 1
+                if self.perm_index == self.chunk_size:
+                    self.chunk_index += 1
+                    self.perm = None
+                self.yielded += 1
+                return value
+            elif remainder > 0:
+                if self.perm is None or not self.perm:
+                    self.perm = torch.randint(
+                        high=self.n,
+                        size=(remainder,),
+                        dtype=torch.int64,
+                        generator=self.generator,
+                    ).tolist()
+                    self.perm_index = 0
+                # Stop only after every value in the final, shorter chunk has been yielded.
+                if self.perm_index == remainder:
+                    raise StopIteration
+                value = self.perm[self.perm_index]
+                self.perm_index += 1
+                self.yielded += 1
+                return value
+            else:
+                raise StopIteration
+        else:
+            num_full_perms = self.num_samples // self.n
+            remainder = self.num_samples % self.n
+            if self.chunk_index < num_full_perms:
+                if self.perm is None or not self.perm:
+                    self.perm = torch.randperm(self.n, generator=self.generator).tolist()
+                    self.perm_index = 0
+                value = self.perm[self.perm_index]
+                self.perm_index += 1
+                if self.perm_index == self.n:
+                    self.chunk_index += 1
+                    self.perm = None
+                self.yielded += 1
+                return value
+            elif remainder > 0:
+                if self.perm is None or not self.perm:
+                    self.perm = torch.randperm(self.n, generator=self.generator).tolist()[:remainder]
+                    self.perm_index = 0
+                # Stop only after every value in the truncated final permutation has been yielded.
+                if self.perm_index == remainder:
+                    raise StopIteration
+                value = self.perm[self.perm_index]
+                self.perm_index += 1
+                self.yielded += 1
+                return value
+            else:
+                raise StopIteration

     def state_dict(self) -> dict:
         return {
@@ -43,35 +109,13 @@ def state_dict(self) -> dict:
     def load_state_dict(self, state_dict: dict) -> None:
         self.next_yielded = state_dict[self._YIELDED]
         self.generator_state = state_dict[self._GENERATOR]
-        self.sampler.generator.set_state(self.generator_state)
+        self.generator.set_state(self.generator_state)
         if self.next_yielded is not None:
-            for _ in range(self.next_yielded):
-                next(self.parent_iterator)
+            for _ in range(self.next_yielded - self.yielded):
+                next(self)
             self.yielded = self.next_yielded
             self.next_yielded = None
 
-    def _get_iterator(self) -> Iterator[int]:
-        n = len(self.sampler.data_source)
-        generator = self.sampler.generator
-        if self.sampler.replacement:
-            chunk_size = 32
-            full_chunks = self.sampler.num_samples // chunk_size
-            remainder = self.sampler.num_samples % chunk_size
-            for _ in range(full_chunks):
-                yield from torch.randint(high=n, size=(chunk_size,), dtype=torch.int64, generator=generator).tolist()
-            if remainder > 0:
-                yield from torch.randint(
-                    high=n,
-                    size=(remainder,),
-                    dtype=torch.int64,
-                    generator=generator,
-                ).tolist()
-        else:
-            for _ in range(self.sampler.num_samples // n):
-                yield from torch.randperm(n, generator=generator).tolist()
-            if self.sampler.num_samples % n > 0:
-                yield from torch.randperm(n, generator=generator).tolist()[: self.sampler.num_samples % n]
-
 
 class RandomSampler(Sampler[int]):
     def __init__(
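
For context on what the explicit state buys over the removed generator: a generator's progress cannot be captured in a state_dict, whereas the rewritten iterator can be rebuilt from the saved yielded count and generator state and fast-forwarded by replaying next(self) — presumably the motivation behind "generator to iterator". Below is a minimal sketch of that round trip; it is not part of the commit. The SimpleNamespace stand-in and the seed are illustrative only: the iterator needs just the data_source, generator, replacement, and num_samples attributes that the new __init__ reads, and in practice the sampler would be torchdata's own RandomSampler.

import types

import torch

from torchdata.stateful_dataloader.sampler import StatefulRandomSamplerIterator


def make_sampler(seed: int = 0):
    # Stand-in exposing exactly the attributes the iterator's __init__ reads.
    data = list(range(10))
    return types.SimpleNamespace(
        data_source=data,
        generator=torch.Generator().manual_seed(seed),
        replacement=False,
        num_samples=len(data),
    )


it = StatefulRandomSamplerIterator(make_sampler())
first_half = [next(it) for _ in range(5)]
checkpoint = it.state_dict()  # keyed by the _YIELDED / _GENERATOR constants

# Later (e.g. after a restart): rebuild an equivalent iterator and restore it.
resumed = StatefulRandomSamplerIterator(make_sampler())
resumed.load_state_dict(checkpoint)  # restores the generator state, then replays 5 draws
second_half = list(resumed)

# Together the two halves cover the same permutation a single uninterrupted run would produce.
assert sorted(first_half + second_half) == list(range(10))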
