isolate some tests
andrewkho committed May 9, 2024
1 parent ad243c4 · commit fb6b709
Showing 1 changed file with 86 additions and 86 deletions.
test/stateful_dataloader/test_state_dict.py — 172 changes: 86 additions & 86 deletions
@@ -225,104 +225,104 @@ def test_random_state(self):
 )
 
 
-# class TestStatefulDataLoaderMap(TestStatefulDataLoaderIterable):
-#     def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
-#         if num_workers == 0:
-#             return
-#         dataset = DummyMapDataset(100, shuffle=shuffle)
-#         generator = torch.Generator()
-#         generator.manual_seed(13)
-#         sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
-#         dl = StatefulDataLoader(
-#             dataset=dataset,
-#             num_workers=num_workers,
-#             collate_fn=identity,
-#             snapshot_every_n_steps=every_n_steps,
-#             persistent_workers=pw,
-#             batch_size=batch_size,
-#             sampler=sampler,
-#             multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-#         )
+class TestStatefulDataLoaderMap(TestStatefulDataLoaderIterable):
+    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+        if num_workers == 0:
+            return
+        dataset = DummyMapDataset(100, shuffle=shuffle)
+        generator = torch.Generator()
+        generator.manual_seed(13)
+        sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )
 
-#         if interrupt is None:
-#             interrupt = len(dl)
+        if interrupt is None:
+            interrupt = len(dl)
 
-#         it = iter(dl)
-#         for _ in range(interrupt):
-#             next(it)
+        it = iter(dl)
+        for _ in range(interrupt):
+            next(it)
 
-#         state_dict = dl.state_dict()
-#         exp = []
-#         for batch in it:
-#             exp.append(batch)
+        state_dict = dl.state_dict()
+        exp = []
+        for batch in it:
+            exp.append(batch)
 
-#         # Restore new instance from state
-#         generator = torch.Generator()
-#         generator.manual_seed(13)
-#         sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
-#         dl = StatefulDataLoader(
-#             dataset=dataset,
-#             num_workers=num_workers,
-#             collate_fn=identity,
-#             snapshot_every_n_steps=every_n_steps,
-#             persistent_workers=pw,
-#             batch_size=batch_size,
-#             sampler=sampler,
-#             multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-#         )
-#         dl.load_state_dict(state_dict)
-#         batches = []
-#         for batch in dl:
-#             batches.append(batch)
+        # Restore new instance from state
+        generator = torch.Generator()
+        generator.manual_seed(13)
+        sampler = torch.utils.data.RandomSampler(dataset, generator=generator)
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )
+        dl.load_state_dict(state_dict)
+        batches = []
+        for batch in dl:
+            batches.append(batch)
 
-#         self.assertEqual(batches, exp)
+        self.assertEqual(batches, exp)
 
 
-# class TestStatefulSampler(TestStatefulDataLoaderIterable):
-#     def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
-#         dataset = DummyMapDataset(100, shuffle=shuffle)
-#         sampler = DummySampler(len(dataset))
-#         dl = StatefulDataLoader(
-#             dataset=dataset,
-#             num_workers=num_workers,
-#             collate_fn=identity,
-#             snapshot_every_n_steps=every_n_steps,
-#             persistent_workers=pw,
-#             batch_size=batch_size,
-#             sampler=sampler,
-#             multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-#         )
+class TestStatefulSampler(TestStatefulDataLoaderIterable):
+    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+        dataset = DummyMapDataset(100, shuffle=shuffle)
+        sampler = DummySampler(len(dataset))
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )
 
-#         if interrupt is None:
-#             interrupt = len(dl)
+        if interrupt is None:
+            interrupt = len(dl)
 
-#         it = iter(dl)
-#         for _ in range(interrupt):
-#             next(it)
+        it = iter(dl)
+        for _ in range(interrupt):
+            next(it)
 
-#         state_dict = dl.state_dict()
-#         exp = []
-#         for batch in it:
-#             exp.append(batch)
+        state_dict = dl.state_dict()
+        exp = []
+        for batch in it:
+            exp.append(batch)
 
-#         # Restore new instance from state
-#         sampler = DummySampler(len(dataset))
-#         dl = StatefulDataLoader(
-#             dataset=dataset,
-#             num_workers=num_workers,
-#             collate_fn=identity,
-#             snapshot_every_n_steps=every_n_steps,
-#             persistent_workers=pw,
-#             batch_size=batch_size,
-#             sampler=sampler,
-#             multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
-#         )
-#         dl.load_state_dict(state_dict)
-#         batches = []
-#         for batch in dl:
-#             batches.append(batch)
+        # Restore new instance from state
+        sampler = DummySampler(len(dataset))
+        dl = StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            collate_fn=identity,
+            snapshot_every_n_steps=every_n_steps,
+            persistent_workers=pw,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+        )
+        dl.load_state_dict(state_dict)
+        batches = []
+        for batch in dl:
+            batches.append(batch)
 
-#         self.assertEqual(batches, exp)
+        self.assertEqual(batches, exp)
 
 
 # class GeneratorIterable(torch.utils.data.IterableDataset):
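Both re-enabled test classes follow the same save/resume pattern: consume part of an epoch, call state_dict(), finish the epoch to record the expected batches, then load the snapshot into a fresh loader and check it yields exactly those batches. A minimal, self-contained sketch of that pattern, assuming torchdata's public import path; the list dataset and lambda collate are hypothetical stand-ins for the DummyMapDataset and identity helpers defined elsewhere in this test file:

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # stand-in map-style dataset: supports len() and indexing

dl = StatefulDataLoader(dataset=dataset, batch_size=4, collate_fn=lambda b: b)

it = iter(dl)
for _ in range(5):
    next(it)                 # consume the first 5 batches
state = dl.state_dict()      # snapshot loader progress mid-epoch
expected = list(it)          # the batches remaining after the snapshot point

# A fresh loader restored from the snapshot resumes where the first left off.
dl2 = StatefulDataLoader(dataset=dataset, batch_size=4, collate_fn=lambda b: b)
dl2.load_state_dict(state)
assert list(dl2) == expected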

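TestStatefulSampler drives the loader with DummySampler, whose definition falls outside the shown hunk. In torchdata's stateful_dataloader, a sampler participates in checkpointing by exposing state_dict() and load_state_dict(), which the loader calls when snapshotting and restoring. A hypothetical minimal sampler in that style (the class name and fields are illustrative, not the test file's DummySampler):

import torch

class SequentialStatefulSampler(torch.utils.data.Sampler):
    """Yields 0..size-1 in order and can checkpoint/restore its position."""

    def __init__(self, size):
        self.size = size
        self.i = 0  # next index to yield; this is the checkpointable state

    def __iter__(self):
        while self.i < self.size:
            i = self.i
            self.i += 1  # advance before yielding, so a snapshot taken after
            yield i      # this index is consumed resumes at the next index
        self.i = 0       # exhausted: reset so a new epoch starts from 0

    def __len__(self):
        return self.size

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]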