From 2dfc975545938147d5521f908935900d4fd0532c Mon Sep 17 00:00:00 2001
From: andrewkho
Date: Tue, 5 Nov 2024 18:23:00 -0800
Subject: [PATCH] add logging and disable daemon

---
 torchdata/nodes/map.py | 47 ++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 38 insertions(+), 9 deletions(-)

diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index b64c3dca4..666d71bc7 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -82,6 +82,16 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
             cur_idx += 1
 
 
+def dump_threads():
+    """Prints stack traces for all running threads."""
+    import sys
+    import traceback
+
+    for thread_id, frame in sys._current_frames().items():
+        print(f"\nThread {thread_id}: ")
+        traceback.print_stack(frame)
+
+
 class _ParallelMapperIter(Iterator[T]):
     """_ParallelMapperIter will start at least two threads, one running
     _populate_queue, and one for _apply_udf. If in_order == True, a
@@ -143,7 +153,7 @@ def __init__(
                 self._sem,
                 self._stop,
             ),
-            daemon=True,
+            daemon=False,
         )
         self._workers: List[Union[threading.Thread, mp.Process]] = []
         for worker_id in range(self.num_workers):
@@ -155,15 +165,15 @@ def __init__(
                 self._stop if self.method == "thread" else self._mp_stop,
             )
             self._workers.append(
-                threading.Thread(target=_apply_udf, args=args, daemon=True)
+                threading.Thread(target=_apply_udf, args=args, daemon=False)
                 if self.method == "thread"
-                else mp_context.Process(target=_apply_udf, args=args, daemon=True)
+                else mp_context.Process(target=_apply_udf, args=args, daemon=False)
             )
         self._sort_q: queue.Queue = queue.Queue()
         self._sort_thread = threading.Thread(
             target=_sort_worker,
             args=(self._intermed_q, self._sort_q, self._stop),
-            daemon=True,
+            daemon=False,
         )
 
         self._out_q = self._intermed_q
@@ -195,6 +205,11 @@ def __next__(self) -> T:
                 item, idx = self._out_q.get(block=True, timeout=QUEUE_TIMEOUT)
                 self._steps_since_snapshot += 1
             except queue.Empty:
+                print(
+                    f"Empty queue, {self._stop.is_set()=}, {self._mp_stop.is_set()=}, "
+                    f"{self._read_thread.is_alive()=}, {self._sort_thread.is_alive()=}, {[w.is_alive() for w in self._workers]}"
+                )
+                dump_threads()
                 continue
 
             if isinstance(item, StopIteration):
@@ -229,12 +244,21 @@ def _shutdown(self):
         self._stop.set()
         self._mp_stop.set()
         if self._read_thread.is_alive():
-            self._read_thread.join(timeout=QUEUE_TIMEOUT)
+            self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
+            if self._read_thread.is_alive():
+                dump_threads()
+                raise RuntimeError("Read thread did not stop in time")
         if self._sort_thread.is_alive():
-            self._sort_thread.join(timeout=QUEUE_TIMEOUT)
+            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
+            if self._sort_thread.is_alive():
+                dump_threads()
+                raise RuntimeError("Sort thread did not stop in time")
         for t in self._workers:
             if t.is_alive():
-                t.join(timeout=QUEUE_TIMEOUT)
+                t.join(timeout=QUEUE_TIMEOUT * 5)
+                if t.is_alive():
+                    dump_threads()
+                    raise RuntimeError("Worker thread did not stop in time")
 
 
 class ParallelMapper(BaseNode[T]):
@@ -385,7 +409,7 @@ def __init__(
                 self._sem,
                 self._stop_event,
             ),
-            daemon=True,
+            daemon=False,
         )
         self._thread.start()
         self._stopped = False
@@ -408,6 +432,8 @@ def __next__(self) -> T:
                 self._steps_since_snapshot += 1
                 break
             except queue.Empty:
+                print(f"Empty queue, {self._stop_event.is_set()=}, {self._thread.is_alive()=}")
+                dump_threads()
                 continue
 
         if isinstance(item, StopIteration):
@@ -445,4 +471,7 @@ def __del__(self):
     def _shutdown(self):
         self._stop_event.set()
         if self._thread.is_alive():
-            self._thread.join(timeout=QUEUE_TIMEOUT)
+            self._thread.join(timeout=QUEUE_TIMEOUT * 5)
+            if self._thread.is_alive():
+                dump_threads()
+                raise RuntimeError("Thread did not stop in time")