Skip to content

Commit

Permalink
add logging and disable daemon
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkho committed Nov 6, 2024
1 parent b9fb830 commit 2dfc975
Showing 1 changed file with 38 additions and 9 deletions.
47 changes: 38 additions & 9 deletions torchdata/nodes/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
cur_idx += 1


def dump_threads() -> None:
    """Print a stack trace for every running thread (debug aid for hangs).

    Writes to stdout. Uses ``sys._current_frames()``, which is CPython's
    snapshot of each live thread's topmost frame; the snapshot is taken at
    call time, so traces may be slightly stale by the time they print.
    """
    # Local imports keep this debug-only helper dependency-free at module
    # import time.
    import sys
    import traceback

    for thread_id, frame in sys._current_frames().items():
        print(f"\nThread {thread_id}: ")
        traceback.print_stack(frame)


class _ParallelMapperIter(Iterator[T]):
"""_ParallelMapperIter will start at least two threads, one running
_populate_queue, and one for _apply_udf. If in_order == True, a
Expand Down Expand Up @@ -143,7 +153,7 @@ def __init__(
self._sem,
self._stop,
),
daemon=True,
daemon=False,
)
self._workers: List[Union[threading.Thread, mp.Process]] = []
for worker_id in range(self.num_workers):
Expand All @@ -155,15 +165,15 @@ def __init__(
self._stop if self.method == "thread" else self._mp_stop,
)
self._workers.append(
threading.Thread(target=_apply_udf, args=args, daemon=True)
threading.Thread(target=_apply_udf, args=args, daemon=False)
if self.method == "thread"
else mp_context.Process(target=_apply_udf, args=args, daemon=True)
else mp_context.Process(target=_apply_udf, args=args, daemon=False)
)
self._sort_q: queue.Queue = queue.Queue()
self._sort_thread = threading.Thread(
target=_sort_worker,
args=(self._intermed_q, self._sort_q, self._stop),
daemon=True,
daemon=False,
)

self._out_q = self._intermed_q
Expand Down Expand Up @@ -195,6 +205,11 @@ def __next__(self) -> T:
item, idx = self._out_q.get(block=True, timeout=QUEUE_TIMEOUT)
self._steps_since_snapshot += 1
except queue.Empty:
print(
f"Empty queue, {self._stop.is_set()=}, {self._mp_stop.is_set()=}, "
f"{self._read_thread.is_alive()=}, {self._sort_thread.is_alive()=}, {[w.is_alive() for w in self._workers]}"
)
dump_threads()
continue

if isinstance(item, StopIteration):
Expand Down Expand Up @@ -229,12 +244,21 @@ def _shutdown(self):
self._stop.set()
self._mp_stop.set()
if self._read_thread.is_alive():
self._read_thread.join(timeout=QUEUE_TIMEOUT)
self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
if self._read_thread.is_alive():
dump_threads()
raise RuntimeError("Read thread did not stop in time")
if self._sort_thread.is_alive():
self._sort_thread.join(timeout=QUEUE_TIMEOUT)
self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
if self._sort_thread.is_alive():
dump_threads()
raise RuntimeError("sort thread did not stop in time")
for t in self._workers:
if t.is_alive():
t.join(timeout=QUEUE_TIMEOUT)
t.join(timeout=QUEUE_TIMEOUT * 5)
if t.is_alive():
dump_threads()
raise RuntimeError("worker thread did not stop in time")


class ParallelMapper(BaseNode[T]):
Expand Down Expand Up @@ -385,7 +409,7 @@ def __init__(
self._sem,
self._stop_event,
),
daemon=True,
daemon=False,
)
self._thread.start()
self._stopped = False
Expand All @@ -408,6 +432,8 @@ def __next__(self) -> T:
self._steps_since_snapshot += 1
break
except queue.Empty:
print(f"Empty queue, {self._stop_event.is_set()=}, {self._thread.is_alive()=}")
dump_threads()
continue

if isinstance(item, StopIteration):
Expand Down Expand Up @@ -445,4 +471,7 @@ def __del__(self):
def _shutdown(self):
    """Signal the worker thread to stop and wait for it to exit.

    Raises:
        RuntimeError: if the thread is still alive after the (extended)
            join timeout; thread stacks are dumped first to aid debugging.
    """
    self._stop_event.set()
    if self._thread.is_alive():
        # 5x the queue timeout gives the thread time to notice the stop
        # event even if it is blocked on a queue operation.
        self._thread.join(timeout=QUEUE_TIMEOUT * 5)
        if self._thread.is_alive():
            dump_threads()
            # NOTE: message previously said "sort thread" (copy-paste from
            # _ParallelMapperIter._shutdown); this class has no sort thread.
            raise RuntimeError("thread did not stop in time")

0 comments on commit 2dfc975

Please sign in to comment.