change iterable dataset to use worker torch rng
gokulavasan committed Apr 13, 2024
1 parent 0c862b1 commit c2cef69
Showing 1 changed file with 20 additions and 23 deletions.
test/stateful_dataloader/test_state_dict.py
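The change drops the iterator's private torch.Generator in favor of the global torch RNG of whichever process runs the iterator; under num_workers > 0 that is the DataLoader worker's own RNG, which PyTorch seeds per worker. As context, a minimal standalone sketch (not part of the commit) of the snapshot/restore round-trip that the updated state_dict/load_state_dict rely on:

# Sketch: snapshot the global torch RNG, draw, rewind, draw again.
import torch

torch.manual_seed(1)                  # seed the (worker's) global RNG
snapshot = torch.get_rng_state()      # ByteTensor holding the full RNG state

a = torch.randint(10, (1,)).item()    # draw with no generator= argument

torch.set_rng_state(snapshot)         # rewind to the snapshot
b = torch.randint(10, (1,)).item()

assert a == b                         # the restored RNG replays the draw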
@@ -12,8 +12,6 @@ class DummyIterator(Iterator, Stateful):
     def __init__(self, samples, shuffle):
         self.samples = samples
         self.shuffle = shuffle
-        self.g = torch.Generator()
-        self.g.manual_seed(1)
         self.size = len(self.samples)
         self.i = 0

@@ -24,19 +22,19 @@ def __next__(self):
         if self.i >= len(self.samples):
             raise StopIteration
         if self.shuffle:
-            i = torch.randint(self.size, (1,), generator=self.g).item()
+            i = torch.randint(self.size, (1,)).item()
         else:
             i = self.i
         sample = self.samples[i]
         self.i += 1
         return sample

     def state_dict(self):
-        return {"i": self.i, "g": self.g.get_state()}
+        return {"i": self.i, "g": torch.get_rng_state()}

     def load_state_dict(self, state_dict):
         self.i = state_dict["i"]
-        self.g.set_state(state_dict["g"])
+        torch.set_rng_state(state_dict["g"])
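Assembled from the hunks above, the post-commit iterator pattern looks roughly like this (a sketch; WorkerRNGIterator is an illustrative name, and the test's Stateful base is reduced to the plain state_dict/load_state_dict protocol):

import torch

class WorkerRNGIterator:
    """Mirrors the updated DummyIterator: no private Generator; every draw
    and every checkpoint goes through the global RNG of the running
    process (a DataLoader worker when num_workers > 0)."""

    def __init__(self, samples, shuffle):
        self.samples = samples
        self.shuffle = shuffle
        self.size = len(samples)
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        # Index randomly via the global RNG, or sequentially.
        i = torch.randint(self.size, (1,)).item() if self.shuffle else self.i
        self.i += 1
        return self.samples[i]

    def state_dict(self):
        # Position plus the global RNG state of this process.
        return {"i": self.i, "g": torch.get_rng_state()}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]
        torch.set_rng_state(state_dict["g"])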


@@ -124,21 +122,22 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
             snapshot_every_n_steps=every_n_steps,
             persistent_workers=pw,
         )
-        exp = list(dl)
+        list(dl)

-        if interrupt is None:
-            interrupt = len(exp)
-
-        batches = []
+        exp = []
         it = iter(dl)
-        for i in range(interrupt):
-            batches.append(next(it))
-        state_dict = dl.state_dict()
+        for _ in range(interrupt):
+            next(it)

-        self.assertEqual(batches, exp[:interrupt])
+        state_dict = dl.state_dict()
+        for data in it:
+            exp.append(data)

         # Restore new instance from state
         dataset = DummyIterableDataset([0, 100, 37], shuffle=shuffle)
+        batches = []
         dl = StatefulDataLoader(
             dataset=dataset,
             num_workers=num_workers,
@@ -147,10 +146,10 @@ def test_no_mp(self):
             persistent_workers=pw,
         )
         dl.load_state_dict(state_dict)
-        for batch in dl:
+        for batch in iter(dl):
             batches.append(batch)

-        self.assertEqual(batches, exp)
+        self.assertEqual(exp, batches)

     def test_no_mp(self):
         for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]):
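The reshaped helper boils down to a save/resume protocol: advance past the checkpoint, snapshot, collect the remainder of the epoch as the expectation, then prove a fresh loader resumed from that snapshot replays it. A condensed sketch of that flow (checkpoint_roundtrip and make_dataset are illustrative names, not test code):

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

def checkpoint_roundtrip(make_dataset, interrupt, num_workers=2):
    dl = StatefulDataLoader(dataset=make_dataset(), num_workers=num_workers)
    list(dl)                      # one throwaway epoch, as in the test, so
                                  # the worker RNG state is mid-stream
    it = iter(dl)
    for _ in range(interrupt):
        next(it)                  # advance past the checkpoint
    state = dl.state_dict()       # snapshot taken mid-epoch
    exp = list(it)                # the remainder of this epoch

    dl = StatefulDataLoader(dataset=make_dataset(), num_workers=num_workers)
    dl.load_state_dict(state)     # must happen before iteration starts
    batches = list(dl)            # resumed loader should replay the remainder
    assert len(batches) == len(exp)
    assert all(torch.equal(a, b) for a, b in zip(batches, exp))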
@@ -642,32 +641,31 @@ def test_iterable(self):
         every_n_steps = 10
         for pw, bs in itertools.product([False, True], [None, 4]):
             dataset = DummyIterableDataset([0, 100, 37], shuffle=True)
-            g = torch.Generator()
-            g.manual_seed(4)
             dl = StatefulDataLoader(
                 dataset=dataset,
                 num_workers=num_workers,
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                generator=g,
             )
-            exp = list(dl)
+            list(dl)
             state_end = dl.state_dict()
+            exp = list(dl)

             batches = list(dl)  # simple restart
             self.assertEqual(batches, exp)

             dataset = DummyIterableDataset([0, 100, 37], shuffle=True)
-            g.manual_seed(4)
             dl = StatefulDataLoader(
                 dataset=dataset,
                 num_workers=num_workers,
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                generator=g,
             )
+            it = iter(dl)
+            for _ in range(2):
+                next(it)
             dl.load_state_dict(state_end)
             batches = list(dl)
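The explicit generator and its reseeding disappear because each DataLoader worker process owns a global torch RNG that PyTorch seeds as base_seed + worker_id, so the draws the test snapshots are already per-worker deterministic. A small sketch (separate from the test; PeekSeed is an illustrative name) that makes the per-worker seed visible:

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class PeekSeed(IterableDataset):
    def __iter__(self):
        info = get_worker_info()          # None in the main process
        wid = info.id if info is not None else -1
        # torch.initial_seed() reports this process's RNG seed.
        yield (wid, torch.initial_seed())

if __name__ == "__main__":               # required for spawn-based workers
    loader = DataLoader(PeekSeed(), num_workers=2)
    for wid, seed in loader:
        print(int(wid), int(seed))       # one line per worker, distinct seeds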

@@ -693,7 +691,6 @@ def test_map(self):
             state_end = dl.state_dict()
             exp = list(dl)

-            dataset = DummyMapDataset(100, shuffle=True)
             generator.manual_seed(15)
             dl = StatefulDataLoader(
                 dataset=dataset,
