Skip to content

Commit

Permalink
add a method to generate permutations
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 11, 2025
1 parent 5167a94 commit 34dc402
Showing 1 changed file with 19 additions and 18 deletions.
37 changes: 19 additions & 18 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,24 @@ def __init__(self, sampler):
def __iter__(self):
return self

def _get_perm(self, replacement: bool, num_samples: int) -> List[int]:
if replacement:
return torch.randint(
high=self.n,
size=(num_samples,),
dtype=torch.int64,
generator=self.sampler.generator,
).tolist()
else:
return torch.randperm(self.n, generator=self.sampler.generator).tolist()[:num_samples]

def __next__(self):
if self.replacement:
num_full_chunks = self.num_samples // self.chunk_size
remainder = self.num_samples % self.chunk_size
if self.chunk_index < num_full_chunks:
if self.perm is None or not self.perm:
self.perm = torch.randint(
high=self.n,
size=(self.chunk_size,),
dtype=torch.int64,
generator=self.sampler.generator,
).tolist()
if self.perm is None:
self.perm = self._get_perm(self.replacement, self.chunk_size)
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -56,13 +62,8 @@ def __next__(self):
self.yielded += 1
return value
elif remainder > 0:
if self.perm is None or not self.perm:
self.perm = torch.randint(
high=self.n,
size=(remainder,),
dtype=torch.int64,
generator=self.sampler.generator,
).tolist()
if self.perm is None:
self.perm = self._get_perm(self.replacement, remainder)
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -76,8 +77,8 @@ def __next__(self):
num_full_perms = self.num_samples // self.n
remainder = self.num_samples % self.n
if self.chunk_index < num_full_perms:
if self.perm is None or not self.perm:
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()
if self.perm is None:
self.perm = self._get_perm(self.replacement, self.n)
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand All @@ -87,8 +88,8 @@ def __next__(self):
self.yielded += 1
return value
elif remainder > 0:
if self.perm is None or not self.perm:
self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder]
if self.perm is None:
self.perm = self._get_perm(self.replacement, remainder)
self.perm_index = 0
value = self.perm[self.perm_index]
self.perm_index += 1
Expand Down

0 comments on commit 34dc402

Please sign in to comment.