Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for out of order with checkpointing #1428

Merged
merged 3 commits into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import itertools
import json
import math
import time
import unittest
from copy import deepcopy

Expand Down Expand Up @@ -1632,5 +1634,134 @@ def test_mp(self):
self._run_test(2, CountIterCallsIter(100))


class _TestSlowIndexDataset(torch.utils.data.Dataset):
def __init__(self, end: int, slow_index: int):
self.end = end
self.slow_index = slow_index
self._worker_id = None

def __getitem__(self, idx):
if idx == self.slow_index:
time.sleep(1.0)
return idx

def __len__(self):
return self.end


class _TestSlowIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start: int, end: int):
self.start = start
self.end = end
self.mid = math.ceil((self.end - self.start) / 2)

def give_data(self, iter_start, iter_end):
for i in range(iter_start, iter_end):
if i == self.mid:
time.sleep(1.0)
yield i

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return self.give_data(iter_start, iter_end)


class TestOutOfOrderWithCheckpointing(TestCase):
    """Checkpoint/resume tests for StatefulDataLoader with ``in_order=False``.

    Each test pauses mid-epoch by taking ``state_dict()``, resumes in a fresh
    loader via ``load_state_dict()``, and asserts that every sample is
    produced exactly once across the pause, even though workers run out of
    order.
    """

    def test_out_of_order_index_ds(self):
        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            in_order=False,
        )

        # worker_id = 0 gets 'stuck' on 0 and also has 2 in its queue
        # due to prefetch_factor being 2
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                state_dict = dataloader.state_dict()
                break

        # 0 is the slow index, assert it isn't in the output before the pause
        self.assertNotIn(0, output)

        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # All 10 samples seen exactly once, but not in the in-order sequence.
        self.assertEqual(len(output), 10)
        self.assertNotEqual(output, list(range(10)))
        self.assertEqual(sorted(output), list(range(10)))

    def test_out_of_order_iterable_ds_one_completed_worker(self):
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break later on, as one of the workers will be finished
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 7:
                state_dict = dataloader.state_dict()
                break

        # The fast worker's fetcher must have ended; the slow one must not.
        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertTrue(worker_0_ended)
        self.assertFalse(worker_1_ended)

        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # Resumed run completes the epoch in the original (in-order) sequence,
        # not the interleaved order a fresh out-of-order epoch would produce.
        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])

    def test_out_of_order_iterable_ds_no_completed_workers(self):
        dataset = _TestSlowIterableDataset(start=0, end=10)
        dataloader = StatefulDataLoader(
            dataset,
            num_workers=2,
            prefetch_factor=2,
            in_order=False,
        )

        # break early - both workers will resume
        output = []
        for i, data in enumerate(dataloader):
            output.append(data)
            if i == 3:
                state_dict = dataloader.state_dict()
                break

        # Neither worker's fetcher has finished its shard yet.
        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
        self.assertFalse(worker_0_ended)
        self.assertFalse(worker_1_ended)

        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
        new_dataloader.load_state_dict(state_dict)
        for i, data in enumerate(new_dataloader):
            output.append(data)

        # Both workers resume and the full epoch comes out once, in order.
        self.assertEqual(len(output), 10)
        self.assertEqual(output, list(range(10)))
        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])


# Allow running this test module directly: `python test_state_dict.py`.
if __name__ == "__main__":
    unittest.main()