From 34dc402fc992bf0db74b73206e3c7a2dd32cf728 Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Tue, 11 Feb 2025 13:46:37 -0800
Subject: [PATCH] add a method to generate permutations

---
Notes (review):
    * This patch factors the four inline torch.randint / torch.randperm
      call sites in __next__ into one helper, _get_perm(replacement,
      num_samples).
    * Behavior change to confirm: the regeneration guard goes from
      "self.perm is None or not self.perm" to "self.perm is None", so an
      empty list stored in self.perm no longer triggers regeneration and
      would be indexed directly. Verify that no code path (for example a
      restored state) can leave self.perm == [].
    * _get_perm is annotated "-> List[int]" and this patch adds no import;
      presumably sampler.py already imports List from typing - confirm.
    * Possible follow-up: in the non-replacement branch, slicing the
      tensor before converting ("[:num_samples].tolist()") yields the same
      values without materializing all n elements as Python ints.
    * The line structure of this patch was restored from a copy in which
      all line breaks had been collapsed. Leading indentation was
      reconstructed from Python block structure (4 spaces per level);
      verify it against the target file before applying.

 torchdata/stateful_dataloader/sampler.py | 37 ++++++++++++------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index 0c4164976..45aaefb11 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -35,18 +35,24 @@ def __init__(self, sampler):
     def __iter__(self):
         return self
 
+    def _get_perm(self, replacement: bool, num_samples: int) -> List[int]:
+        if replacement:
+            return torch.randint(
+                high=self.n,
+                size=(num_samples,),
+                dtype=torch.int64,
+                generator=self.sampler.generator,
+            ).tolist()
+        else:
+            return torch.randperm(self.n, generator=self.sampler.generator).tolist()[:num_samples]
+
     def __next__(self):
         if self.replacement:
             num_full_chunks = self.num_samples // self.chunk_size
             remainder = self.num_samples % self.chunk_size
             if self.chunk_index < num_full_chunks:
-                if self.perm is None or not self.perm:
-                    self.perm = torch.randint(
-                        high=self.n,
-                        size=(self.chunk_size,),
-                        dtype=torch.int64,
-                        generator=self.sampler.generator,
-                    ).tolist()
+                if self.perm is None:
+                    self.perm = self._get_perm(self.replacement, self.chunk_size)
                     self.perm_index = 0
                 value = self.perm[self.perm_index]
                 self.perm_index += 1
@@ -56,13 +62,8 @@ def __next__(self):
                 self.yielded += 1
                 return value
             elif remainder > 0:
-                if self.perm is None or not self.perm:
-                    self.perm = torch.randint(
-                        high=self.n,
-                        size=(remainder,),
-                        dtype=torch.int64,
-                        generator=self.sampler.generator,
-                    ).tolist()
+                if self.perm is None:
+                    self.perm = self._get_perm(self.replacement, remainder)
                     self.perm_index = 0
                 value = self.perm[self.perm_index]
                 self.perm_index += 1
@@ -76,8 +77,8 @@ def __next__(self):
             num_full_perms = self.num_samples // self.n
             remainder = self.num_samples % self.n
             if self.chunk_index < num_full_perms:
-                if self.perm is None or not self.perm:
-                    self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()
+                if self.perm is None:
+                    self.perm = self._get_perm(self.replacement, self.n)
                     self.perm_index = 0
                 value = self.perm[self.perm_index]
                 self.perm_index += 1
@@ -87,8 +88,8 @@ def __next__(self):
                 self.yielded += 1
                 return value
             elif remainder > 0:
-                if self.perm is None or not self.perm:
-                    self.perm = torch.randperm(self.n, generator=self.sampler.generator).tolist()[:remainder]
+                if self.perm is None:
+                    self.perm = self._get_perm(self.replacement, remainder)
                     self.perm_index = 0
                 value = self.perm[self.perm_index]
                 self.perm_index += 1