diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index 9172b3e6d..bddef097e 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -107,10 +107,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             assert isinstance(self.sampler_iter, Stateful)
             self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])
 
-        if not (
-            isinstance(self.sampler, Stateful)
-            or isinstance(self.sampler_iter, Stateful)
-        ) and not isinstance(self.sampler, _InfiniteConstantSampler):
+        if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
+            self.sampler, _InfiniteConstantSampler
+        ):
             # We skip x samples if underlying sampler is not stateful
             for _ in range(self.samples_yielded):
                 next(self.sampler_iter)
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index f43e23dd7..c986a6dfa 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -46,10 +46,7 @@
 )
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _InfiniteConstantSampler
 
-from torch.utils.data.datapipes.datapipe import (
-    _IterDataPipeSerializationWrapper,
-    _MapDataPipeSerializationWrapper,
-)
+from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
 
 from .incremental_state import (
     _DATASET_ITER_STATE,
@@ -218,8 +215,7 @@ def __init__(
 
         if num_workers < 0:
             raise ValueError(
-                "num_workers option should be non-negative; "
-                "use num_workers=0 to disable multiprocessing."
+                "num_workers option should be non-negative; " "use num_workers=0 to disable multiprocessing."
             )
 
         if timeout < 0:
@@ -295,9 +291,7 @@ def __init__(
         #     specific workers.
         if isinstance(dataset, IterDataPipe):
             if shuffle is not None:
-                dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
-                    dataset, shuffle=shuffle
-                )
+                dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
         # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
         elif shuffle not in {False, None}:
             raise ValueError(
@@ -326,9 +320,7 @@ def __init__(
             # auto_collation with custom batch_sampler
             if batch_size != 1 or shuffle or sampler is not None or drop_last:
                 raise ValueError(
-                    "batch_sampler option is mutually exclusive "
-                    "with batch_size, shuffle, sampler, and "
-                    "drop_last"
+                    "batch_sampler option is mutually exclusive " "with batch_size, shuffle, sampler, and " "drop_last"
                 )
             batch_size = None
             drop_last = False
@@ -336,8 +328,7 @@ def __init__(
             # no auto_collation
             if drop_last:
                 raise ValueError(
-                    "batch_size=None option disables auto-batching "
-                    "and is mutually exclusive with drop_last"
+                    "batch_size=None option disables auto-batching " "and is mutually exclusive with drop_last"
                 )
 
         if sampler is None:  # give default samplers
@@ -371,9 +362,7 @@ def __init__(
         # set DataLoader's __initialized attribute.
         self._DataLoader__initialized = True
 
-        self._IterableDataset_len_called = (
-            None  # See NOTE [ IterableDataset and __len__ ]
-        )
+        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]
 
         self._iterator = None
 
@@ -484,9 +473,7 @@ def __init__(self, loader, next_iter_state=None):
         # Taking care of distributed sharding
         if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
             # For BC, use default SHARDING_PRIORITIES
-            torch.utils.data.graph_settings.apply_sharding(
-                self._dataset, self._world_size, self._rank
-            )
+            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
 
         if next_iter_state is not None:
             self.load_state_dict(next_iter_state)
@@ -509,9 +496,7 @@ def _next_data(self):
     def state_dict(self):
         if self._dataset_kind == _DatasetKind.Iterable:
             fetcher_state = {
-                _DATASET_ITER_STATE: try_to_serialize(
-                    self._dataset_fetcher.dataset_iter
-                ),
+                _DATASET_ITER_STATE: try_to_serialize(self._dataset_fetcher.dataset_iter),
                 _FETCHER_ENDED: self._dataset_fetcher.ended,
             }
             dataset_state = None
@@ -541,17 +526,11 @@ def load_state_dict(self, state_dict):
         self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED]
 
         # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict
-        if isinstance(self._index_sampler, Stateful) or isinstance(
-            self._sampler_iter, Stateful
-        ):
-            self._index_sampler = try_to_deserialize(
-                self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]
-            )
+        if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful):
+            self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE])
             self._sampler_iter = iter(self._index_sampler)
             if state_dict[_SAMPLER_ITER_STATE] is not None:
-                self._sampler_iter = try_to_deserialize(
-                    self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]
-                )
+                self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
             if state_dict[_ITERATOR_FINISHED]:
                 try:
                     next(self._sampler_iter)
@@ -563,9 +542,7 @@ def load_state_dict(self, state_dict):
                 torch.utils.data.dataloader._InfiniteConstantSampler,
             ):
                 # Fallback to fastforward
-                self._sampler_iter = itertools.islice(
-                    self._index_sampler, self._sampler_iter_yielded, None
-                )
+                self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
         self._num_yielded = state_dict[self._NUM_YIELDED]
         self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED]
         self._shared_seed = state_dict[_SHARED_SEED]
@@ -574,12 +551,8 @@ def load_state_dict(self, state_dict):
             # 1. try to restore dataset state
             # 2. generate dataset iterator
            # 3. try to restore iterator state
-            if state_dict[_DATASET_STATE] is not None and isinstance(
-                self._dataset, Stateful
-            ):
-                self._dataset = try_to_deserialize(
-                    self._dataset, state_dict[_DATASET_STATE]
-                )
+            if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful):
+                self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE])
             self._dataset_fetcher = _DatasetKind.create_fetcher(
                 self._dataset_kind,
                 self._dataset,
@@ -589,18 +562,14 @@ def load_state_dict(self, state_dict):
         )
         if self._dataset_kind == _DatasetKind.Iterable:
             # If either dataset or it's iter is stateful, we don't fast-forward
-            if isinstance(self._dataset, Stateful) or isinstance(
-                self._dataset_fetcher.dataset_iter, Stateful
-            ):
+            if isinstance(self._dataset, Stateful) or isinstance(self._dataset_fetcher.dataset_iter, Stateful):
                 if state_dict[_FETCHER_STATE] is not None:
                     if state_dict[_FETCHER_STATE][_DATASET_ITER_STATE] is not None:
                         self._dataset_fetcher.dataset_iter = try_to_deserialize(
                             self._dataset_fetcher.dataset_iter,
                             state_dict[_FETCHER_STATE][_DATASET_ITER_STATE],
                         )
-                    self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][
-                        _FETCHER_ENDED
-                    ]
+                    self._dataset_fetcher.ended = state_dict[_FETCHER_STATE][_FETCHER_ENDED]
             else:
                 # No state, just try to fastforward
                 if self._num_yielded > 0:
@@ -975,20 +944,16 @@ def __init__(self, loader, next_iter_state):
             assert (
                 self._SNAPSHOT in next_iter_state
             ), f"State doesn't contain key '{self._SNAPSHOT}' expected for multiprocess dataloader"
             wstates = next_iter_state[self._SNAPSHOT].get(self._WORKER_SNAPSHOTS, {})
-            assert set(map(self._worker_key, range(len(wstates)))) == set(
-                wstates.keys()
-            ), (
+            assert set(map(self._worker_key, range(len(wstates)))) == set(wstates.keys()), (
                 len(wstates),
                 wstates.keys(),
             )
             for worker_key, sd in wstates.items():
                 worker_states[worker_key] = sd
-            self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(
-                self._BASE_SEED, self._base_seed
+            self._base_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(self._BASE_SEED, self._base_seed)
+            self._shared_seed = next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT].get(
+                _SHARED_SEED, self._shared_seed
             )
-            self._shared_seed = next_iter_state[self._SNAPSHOT][
-                self._MAIN_SNAPSHOT
-            ].get(_SHARED_SEED, self._shared_seed)
 
         for i in range(self._num_workers):
             # No certainty which module multiprocessing_context is
@@ -1036,9 +1001,7 @@ def __init__(self, loader, next_iter_state):
                 if self._pin_memory_device == "xpu":
                     current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
                 elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
-                    custom_device_mod = getattr(
-                        torch, torch._C._get_privateuse1_backend_name()
-                    )
+                    custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
                     current_device = custom_device_mod.current_device()
                 else:
                     current_device = torch.cuda.current_device()  # choose cuda for default
@@ -1071,9 +1034,7 @@ def __init__(self, loader, next_iter_state):
             import atexit
 
             for w in self._workers:
-                atexit.register(
-                    _StatefulMultiProcessingDataLoaderIter._clean_up_worker, w
-                )
+                atexit.register(_StatefulMultiProcessingDataLoaderIter._clean_up_worker, w)
 
         # .pid can be None only before process is spawned (not the case, so ignore)
         _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
@@ -1089,23 +1050,17 @@ def __init__(self, loader, next_iter_state):
         # We need to send initial worker state back to the main process to handle state_dict() requests
        # before n >= num_workers steps are taken.
         # self._worker_snapshots: Dict[str, _IncrementalWorkerState] = {}
-        self._worker_snapshots = {
-            key: _IncrementalWorkerState(state) for key, state in worker_states.items()
-        }
+        self._worker_snapshots = {key: _IncrementalWorkerState(state) for key, state in worker_states.items()}
         self._reset(loader, first_iter=True, prime_prefetch=next_iter_state is None)
 
         # Try to restore main state
         if next_iter_state is not None:
-            self._restore_main_state(
-                next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT]
-            )
+            self._restore_main_state(next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT])
             self._num_yielded = next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP]
 
             self._update_snapshot(
                 snapshot_step=next_iter_state[self._SNAPSHOT][self._SNAPSHOT_STEP],
-                last_yielded_worker_id=next_iter_state[self._SNAPSHOT][
-                    self._LAST_YIELDED_WORKER_ID
-                ],
+                last_yielded_worker_id=next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID],
                 num_workers=self._num_workers,
                 main_snapshot=next_iter_state[self._SNAPSHOT][self._MAIN_SNAPSHOT],
                 worker_snapshots=self._worker_snapshots,
@@ -1116,10 +1071,7 @@ def __init__(self, loader, next_iter_state):
             for state in worker_states.values():
                 if state is None:
                     continue
-                if (
-                    state[_DATASET_STATE] is None
-                    and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None
-                ):
+                if state[_DATASET_STATE] is None and state[_FETCHER_STATE][_DATASET_ITER_STATE] is None:
                     fast_forward = True
                     break
 
@@ -1136,17 +1088,10 @@ def __init__(self, loader, next_iter_state):
                 for _ in range(self._num_yielded):
                     next(self)
                 # Check if last_yielded_worker_id matches
-                if (
-                    self._last_yielded_worker_id
-                    != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]
-                ):
-                    raise ValueError(
-                        "last_yielded_worker_id does not match, the dataset may have changed"
-                    )
+                if self._last_yielded_worker_id != next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]:
+                    raise ValueError("last_yielded_worker_id does not match, the dataset may have changed")
             else:
-                self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][
-                    self._LAST_YIELDED_WORKER_ID
-                ]
+                self._last_yielded_worker_id = next_iter_state[self._SNAPSHOT][self._LAST_YIELDED_WORKER_ID]
                 for _ in range(self._last_yielded_worker_id + 1):
                     next(self._worker_queue_idx_cycle)
                 for _ in range(self._prefetch_factor * self._num_workers):
@@ -1164,9 +1109,7 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
         # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
         #                  \ (worker_id, data)   if data is already fetched (out-of-order)
         self._task_info = {}
-        self._tasks_outstanding = (
-            0  # always equal to count(v for v in task_info.values() if len(v) == 1)
-        )
+        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
         # A list of booleans representing whether each worker still has work to
         # do, i.e., not having exhausted its iterable dataset object. It always
         # contains all `True`s if not using an iterable-style dataset
@@ -1191,9 +1134,7 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
                while remaining > 0:
                     _, data = self._get_data()
                     if not all(self._workers_status):
-                        raise ValueError(
-                            f"A worker has failed during startup! {self._workers_status}"
-                        )
+                        raise ValueError(f"A worker has failed during startup! {self._workers_status}")
                     elif isinstance(data, _AckStartup):
                         if isinstance(data.initial_state, ExceptionWrapper):
                             data.initial_state.reraise()
@@ -1201,37 +1142,27 @@ def _reset(self, loader, first_iter=False, prime_prefetch=True):
                         if data.is_delta:
                             self._worker_snapshots[self._worker_key(data.worker_id)].apply_delta(data.initial_state)  # type: ignore[arg-type]
                         else:
-                            self._worker_snapshots[
-                                self._worker_key(data.worker_id)
-                            ] = _IncrementalWorkerState(
+                            self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState(
                                 data.initial_state  # type: ignore[arg-type]
                             )
                         remaining -= 1
                     else:
-                        raise ValueError(
-                            f"Invalid response from worker after startup: {data}"
-                        )
+                        raise ValueError(f"Invalid response from worker after startup: {data}")
         else:
             # We resume the prefetching in case it was enabled
             for idx in range(self._num_workers):
-                self._index_queues[idx].put(
-                    _utils.worker._ResumeIteration(self._shared_seed)
-                )
+                self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
             resume_iteration_cnt = self._num_workers
             while resume_iteration_cnt > 0:
                 return_idx, data = self._get_data()
                 if not all(self._workers_status):
-                    raise ValueError(
-                        f"A worker has failed during Resume! {self._workers_status}"
-                    )
+                    raise ValueError(f"A worker has failed during Resume! {self._workers_status}")
                 if isinstance(return_idx, _utils.worker._ResumeIteration):
                     assert isinstance(data, _AckStartup), (return_idx, data)
                     if isinstance(data.initial_state, ExceptionWrapper):
                         data.initial_state.reraise()
                     assert data.initial_state is not None, data
-                    self._worker_snapshots[
-                        self._worker_key(data.worker_id)
-                    ] = _IncrementalWorkerState(
+                    self._worker_snapshots[self._worker_key(data.worker_id)] = _IncrementalWorkerState(
                         data.initial_state  # type: ignore[arg-type]
                     )
                     resume_iteration_cnt -= 1
@@ -1299,9 +1230,7 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
                     self._mark_worker_as_unavailable(worker_id)
             if len(failed_workers) > 0:
                 pids_str = ", ".join(str(w.pid) for w in failed_workers)
-                raise RuntimeError(
-                    f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
-                ) from e
+                raise RuntimeError(f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly") from e
             if isinstance(e, queue.Empty):
                 return (False, None)
             import errno
@@ -1439,9 +1368,7 @@ def _get_data(self):
             if success:
                 return data
             else:
-                raise RuntimeError(
-                    f"DataLoader timed out after {self._timeout} seconds"
-                )
+                raise RuntimeError(f"DataLoader timed out after {self._timeout} seconds")
         elif self._pin_memory:
             while self._pin_memory_thread.is_alive():
                 success, data = self._try_get_data()
@@ -1474,9 +1401,7 @@ def _next_data(self):
                 info = self._task_info.get(self._rcvd_idx, None)
                 if info:
                     worker_id = info[0]
-                    if (
-                        len(info) == 2 or self._workers_status[worker_id]
-                    ):  # has data or is still active
+                    if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                         break
                 del self._task_info[self._rcvd_idx]
                 self._rcvd_idx += 1
@@ -1492,9 +1417,7 @@ def _next_data(self):
             if len(self._task_info[self._rcvd_idx]) == 2:
                 data, worker_id, state_dict = self._task_info.pop(self._rcvd_idx)[1]
                 if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                    self._update_worker_snapshot(
-                        self._worker_key(data.worker_id), state_dict
-                    )
+                    self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
                     self._rcvd_idx += 1
                     continue
                 else:
@@ -1511,9 +1434,7 @@ def _next_data(self):
                     self._workers_status[data.worker_id] = False
                 else:
                     self._mark_worker_as_unavailable(data.worker_id)
-                assert (
-                    state_dict is not None
-                ), "StopIteration should always be accompanied by a state_dict"
+                assert state_dict is not None, "StopIteration should always be accompanied by a state_dict"
                 self._try_put_index()
                 # We want to process states until we get to that position
                 # in the worker cycle, therefore if out-of-order we want
@@ -1524,9 +1445,7 @@ def _next_data(self):
                 if not self._in_order:
                     # don't store it for later, process now
                     if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                        self._update_worker_snapshot(
-                            self._worker_key(data.worker_id), state_dict
-                        )
+                        self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
                         continue
                     del self._task_info[idx]
                     return self._process_data(data, worker_id, state_dict)
@@ -1534,9 +1453,7 @@ def _next_data(self):
             else:
                 del self._task_info[idx]
                 if isinstance(data, _utils.worker._IterableDatasetStopIteration):
-                    self._update_worker_snapshot(
-                        self._worker_key(data.worker_id), state_dict
-                    )
+                    self._update_worker_snapshot(self._worker_key(data.worker_id), state_dict)
                     self._rcvd_idx += 1
                     continue
                 else:
@@ -1558,26 +1475,18 @@ def _restore_main_state(self, state_dict):
         assert self._num_workers == state_dict[self._NUM_WORKERS]
         # Try to restore from either _index_sampler state_dict or _sampler_iter state_dict
         self._sampler_iter_yielded = state_dict[_SAMPLER_ITER_YIELDED]
-        if isinstance(self._index_sampler, Stateful) or isinstance(
-            self._sampler_iter, Stateful
-        ):
-            self._index_sampler = try_to_deserialize(
-                self._index_sampler, state_dict[_INDEX_SAMPLER_STATE]
-            )
+        if isinstance(self._index_sampler, Stateful) or isinstance(self._sampler_iter, Stateful):
+            self._index_sampler = try_to_deserialize(self._index_sampler, state_dict[_INDEX_SAMPLER_STATE])
             self._sampler_iter = iter(self._index_sampler)
             if state_dict[_SAMPLER_ITER_STATE] is not None:
-                self._sampler_iter = try_to_deserialize(
-                    self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]
-                )
+                self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
         else:
             if not isinstance(
                 self._index_sampler,
                 torch.utils.data.dataloader._InfiniteConstantSampler,
             ):
                 # Fallback to fastforward
-                self._sampler_iter = itertools.islice(
-                    self._index_sampler, self._sampler_iter_yielded, None
-                )
+                self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
         self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED]
         self._shared_seed = state_dict[_SHARED_SEED]
         self._base_seed = state_dict[self._BASE_SEED]
@@ -1614,9 +1523,7 @@ def _try_put_index(self):
                 if self._workers_status[worker_queue_idx]:
                     if self._in_order:
                         break
-                    elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(
-                        self._workers_status
-                    ):
+                    elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(self._workers_status):
                         # when self._in_order is False, distribute work to a worker if it has capacity
                         # _workers_status is updated only in this thread, so the sum is guaranteed > 0
                         break
@@ -1642,20 +1549,14 @@ def _process_data(self, data, worker_id, state_dict):
         self._last_yielded_worker_id = worker_id
         # Update latest worker state
         if state_dict is not None:
-            self._update_worker_snapshot(
-                self._worker_key(state_dict[_WORKER_ID]), state_dict
-            )
-        if self._snapshot_interval and (
-            (self._num_yielded + 1) % self._snapshot_interval == 0
-        ):
+            self._update_worker_snapshot(self._worker_key(state_dict[_WORKER_ID]), state_dict)
+        if self._snapshot_interval and ((self._num_yielded + 1) % self._snapshot_interval == 0):
             self._take_snapshot()
         return data
 
     def _take_snapshot(self):
         main_snapshot_idx = None
-        while len(self._main_snapshots) and (
-            self._main_snapshots[0][0] <= self._rcvd_idx - 1
-        ):
+        while len(self._main_snapshots) and (self._main_snapshots[0][0] <= self._rcvd_idx - 1):
             main_snapshot_idx, main_snapshot = self._main_snapshots.popleft()
         if not self._in_order and main_snapshot_idx is None:
             # in_order is False and no main snapshot is available as we're ahead of rcvd_idx
@@ -1685,10 +1586,7 @@ def _update_snapshot(
             self._SNAPSHOT_STEP: snapshot_step,
             self._LAST_YIELDED_WORKER_ID: last_yielded_worker_id,
             self._MAIN_SNAPSHOT: main_snapshot,
-            self._WORKER_SNAPSHOTS: {
-                key: worker_state.get_state()
-                for key, worker_state in worker_snapshots.items()
-            },
+            self._WORKER_SNAPSHOTS: {key: worker_state.get_state() for key, worker_state in worker_snapshots.items()},
         }
 
     def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
@@ -1721,11 +1619,7 @@ def _shutdown_workers(self):
         # Called when shutting down this `_MultiProcessingDataLoaderIter`.
         # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
         # the logic of this function.
-        if (
-            _utils is None
-            or _utils.python_exit_status is True
-            or _utils.python_exit_status is None
-        ):
+        if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
             # See (2) of the note. If Python is shutting down, do no-op.
             return
         # Normal exit when last reference is gone / iterator is depleted.
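
Note for reviewers (not part of the patch): the hunks above are a formatting-only reflow, rejoining statements that fit within the wider line-length limit, with no intended behavior change. Two pieces of context on the code being reflowed follow.

The sampler.py hunk is the fallback restore path: a sampler that does not implement the Stateful protocol (state_dict/load_state_dict) is fast-forwarded on restore by calling next() once per sample already yielded, which can be slow for long epochs. A sampler opts into direct snapshotting simply by providing those two methods. A minimal sketch, assuming only that protocol; the class name and its offset field are illustrative, not from this PR:

    import torch.utils.data

    class CheckpointableSampler(torch.utils.data.Sampler[int]):
        # Sequential sampler whose position survives checkpointing: because it
        # exposes state_dict/load_state_dict, the restore path above snapshots
        # it directly instead of replaying `samples_yielded` next() calls.
        def __init__(self, length: int) -> None:
            self.length = length
            self.offset = 0  # next index to yield

        def __iter__(self):
            while self.offset < self.length:
                yield self.offset
                self.offset += 1

        def __len__(self) -> int:
            return self.length

        def state_dict(self) -> dict:
            return {"offset": self.offset}

        def load_state_dict(self, state: dict) -> None:
            self.offset = state["offset"]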
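The stateful_dataloader.py hunks reflow the checkpoint/restore machinery (state_dict(), load_state_dict(), _restore_main_state, and the per-worker snapshot bookkeeping). For orientation, the public round trip this machinery serves looks roughly like the sketch below, following torchdata's documented StatefulDataLoader usage; the toy dataset is a stand-in:

    from torch.utils.data import Dataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    class SquaresDataset(Dataset):
        # Tiny map-style dataset, just enough to drive the loader.
        def __len__(self) -> int:
            return 64

        def __getitem__(self, idx: int) -> int:
            return idx * idx

    loader = StatefulDataLoader(SquaresDataset(), batch_size=4, num_workers=2)

    state = None
    for step, batch in enumerate(loader):
        if step == 3:
            state = loader.state_dict()  # snapshot mid-epoch, after 4 batches
            break

    # Later, possibly in a fresh process: rebuild an identical loader and resume.
    resumed = StatefulDataLoader(SquaresDataset(), batch_size=4, num_workers=2)
    resumed.load_state_dict(state)
    batch = next(iter(resumed))  # picks up at the 5th batch, not the 1st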