diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py index f8e12c351..8e83e90b8 100644 --- a/torchdata/nodes/map.py +++ b/torchdata/nodes/map.py @@ -115,6 +115,7 @@ def __init__( self._read_thread = threading.Thread( target=_populate_queue, args=(self.source, self._in_q, self._sem, self._stop, True), + daemon=True, ) self._map_threads: List[Union[threading.Thread, mp.Process]] = [] for worker_id in range(self.num_workers): @@ -131,7 +132,11 @@ def __init__( else mp_context.Process(target=_apply_udf, args=args) ) self._sort_q: queue.Queue = queue.Queue() - self._sort_thread = threading.Thread(target=_sort_worker, args=(self._intermed_q, self._sort_q, self._stop)) + self._sort_thread = threading.Thread( + target=_sort_worker, + args=(self._intermed_q, self._sort_q, self._stop), + daemon=True, + ) self._out_q = self._intermed_q if self.in_order: @@ -143,10 +148,10 @@ def __init__( if self.in_order: self._sort_thread.start() - def __iter__(self): + def __iter__(self) -> Iterator[T]: return self - def __next__(self): + def __next__(self) -> T: while True: if self._stop.is_set(): raise StopIteration() @@ -282,6 +287,7 @@ def __init__(self, source: BaseNode[T], prefetch_factor: int, worker: _WorkerTyp self._thread = threading.Thread( target=self.worker, args=(self.source, self._q, self._sem, self._stop_event), + daemon=True, ) self._thread.start() self._stopped = False @@ -289,7 +295,7 @@ def __init__(self, source: BaseNode[T], prefetch_factor: int, worker: _WorkerTyp def __iter__(self) -> Iterator[T]: return self - def __next__(self): + def __next__(self) -> T: if self._stopped: raise StopIteration() @@ -321,4 +327,5 @@ def __del__(self): def _shutdown(self): self._stop_event.set() - self._thread.join(timeout=QUEUE_TIMEOUT) + if self._thread.is_alive(): + self._thread.join(timeout=QUEUE_TIMEOUT)