diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index 64d2ed06a..203a4e042 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -43,6 +43,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 
     def update_state_dict(self) -> None:
         self.generator_state = self.sampler.generator.get_state()
+        self.yielded = 0
 
     def state_dict(self) -> Dict[str, Any]:
         return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
@@ -109,15 +110,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             assert isinstance(self.sampler_iter, Stateful)
             self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])
 
-        if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
-            self.sampler, _InfiniteConstantSampler
-        ):
+        if not (
+            isinstance(self.sampler, Stateful)
+            or isinstance(self.sampler_iter, Stateful)
+        ) and not isinstance(self.sampler, _InfiniteConstantSampler):
             # We skip x samples if underlying sampler is not stateful
             for _ in range(self.samples_yielded):
                 next(self.sampler_iter)
 
     def update_state_dict(self) -> None:
-        if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
+        if isinstance(self.sampler_iter, Stateful) and hasattr(
+            self.sampler_iter, "update_state_dict"
+        ):
             self.sampler_iter.update_state_dict()
 
 
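A minimal, self-contained sketch of the replay-on-resume pattern the second hunk exercises, assuming a hypothetical FastForwardSampler wrapper (not a torchdata API): when the underlying sampler is not Stateful, the only way to restore its position is to re-create the iterator and consume the samples already yielded.

# Hypothetical illustration, not part of this patch or of torchdata.
from typing import Any, Dict, Iterator, List


class FastForwardSampler:
    """Wraps a plain (non-Stateful) sampler; restores position by replay."""

    def __init__(self, sampler) -> None:
        self.sampler = sampler
        self.sampler_iter: Iterator[int] = iter(sampler)
        self.samples_yielded = 0

    def __iter__(self) -> "FastForwardSampler":
        return self

    def __next__(self) -> int:
        idx = next(self.sampler_iter)
        self.samples_yielded += 1
        return idx

    def state_dict(self) -> Dict[str, Any]:
        return {"samples_yielded": self.samples_yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.samples_yielded = state_dict["samples_yielded"]
        self.sampler_iter = iter(self.sampler)
        # Skip the samples already consumed, mirroring the loop in the patch.
        for _ in range(self.samples_yielded):
            next(self.sampler_iter)


if __name__ == "__main__":
    data: List[int] = list(range(10))
    first_run = FastForwardSampler(data)
    head = [next(first_run) for _ in range(4)]  # consume 4 samples
    ckpt = first_run.state_dict()               # checkpoint mid-epoch

    resumed = FastForwardSampler(data)
    resumed.load_state_dict(ckpt)               # fast-forward past 4 samples
    assert head + list(resumed) == data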