Skip to content

Commit

Permalink
update generator usage
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 11, 2025
1 parent 10ba260 commit 0a90c04
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .stateful import Stateful


class StatefulRandomSamplerIterator(Iterator[list[int]], Stateful):
class StatefulRandomSamplerIterator(Iterator[int], Stateful):
_GENERATOR = "generator"
_YIELDED = "yielded"

Expand All @@ -25,7 +25,6 @@ def __init__(self, sampler):
self.yielded = 0
self.next_yielded = None
self.n = len(sampler.data_source)
self.generator = sampler.generator
self.replacement = sampler.replacement
self.num_samples = sampler.num_samples
self.chunk_size = 32
Expand All @@ -46,7 +45,7 @@ def __next__(self):
high=self.n,
size=(self.chunk_size,),
dtype=torch.int64,
generator=self.generator,
generator=self.sampler.generator,
).tolist()
self.perm_index = 0
value = self.perm[self.perm_index]
Expand All @@ -62,7 +61,7 @@ def __next__(self):
high=self.n,
size=(remainder,),
dtype=torch.int64,
generator=self.generator,
generator=self.sampler.generator,
).tolist()
self.perm_index = 0
value = self.perm[self.perm_index]
Expand All @@ -78,7 +77,7 @@ def __next__(self):
remainder = self.num_samples % self.n
if self.chunk_index < num_full_perms:
if self.perm is None or not self.perm:
self.perm = torch.randperm(self.n, generator=self.generator).tolist()
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -89,7 +88,7 @@ def __next__(self):
return value
elif remainder > 0:
if self.perm is None or not self.perm:
self.perm = torch.randperm(self.n, generator=self.generator).tolist()[:remainder]
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder]
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -109,7 +108,7 @@ def state_dict(self) -> dict:
def load_state_dict(self, state_dict: dict) -> None:
self.next_yielded = state_dict[self._YIELDED]
self.generator_state = state_dict[self._GENERATOR]
self.generator.set_state(self.generator_state)
self.sampler.generator.set_state(self.generator_state)
if self.next_yielded is not None:
for _ in range(self.next_yielded - self.yielded):
next(self)
Expand Down

0 comments on commit 0a90c04

Please sign in to comment.