Skip to content

Commit

Permalink
update randomsampleriter state_dict fully
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 7, 2025
1 parent 6d49b4f commit 20a14e5
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

def update_state_dict(self) -> None:
self.generator_state = self.sampler.generator.get_state()
self.yielded = 0

def state_dict(self) -> Dict[str, Any]:
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
Expand Down Expand Up @@ -109,15 +110,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert isinstance(self.sampler_iter, Stateful)
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
if not (
isinstance(self.sampler, Stateful)
or isinstance(self.sampler_iter, Stateful)
) and not isinstance(self.sampler, _InfiniteConstantSampler):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.samples_yielded):
next(self.sampler_iter)

def update_state_dict(self) -> None:
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
if isinstance(self.sampler_iter, Stateful) and hasattr(
self.sampler_iter, "update_state_dict"
):
self.sampler_iter.update_state_dict()


Expand Down

0 comments on commit 20a14e5

Please sign in to comment.