Skip to content

Commit

Permalink
Fix unclean shutdown by setting daemon property on utility threads (#1348)
Browse files Browse the repository at this point in the history

* Fix unclean shutdown by setting threads as daemons

* UDF worker threads are not set as daemons
  • Loading branch information
andrewkho authored Oct 25, 2024
1 parent dcaa6ea commit 7fdd0e9
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions torchdata/nodes/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
self._read_thread = threading.Thread(
target=_populate_queue,
args=(self.source, self._in_q, self._sem, self._stop, True),
daemon=True,
)
self._map_threads: List[Union[threading.Thread, mp.Process]] = []
for worker_id in range(self.num_workers):
Expand All @@ -131,7 +132,11 @@ def __init__(
else mp_context.Process(target=_apply_udf, args=args)
)
self._sort_q: queue.Queue = queue.Queue()
self._sort_thread = threading.Thread(target=_sort_worker, args=(self._intermed_q, self._sort_q, self._stop))
self._sort_thread = threading.Thread(
target=_sort_worker,
args=(self._intermed_q, self._sort_q, self._stop),
daemon=True,
)

self._out_q = self._intermed_q
if self.in_order:
Expand All @@ -143,10 +148,10 @@ def __init__(
if self.in_order:
self._sort_thread.start()

def __iter__(self):
def __iter__(self) -> Iterator[T]:
return self

def __next__(self):
def __next__(self) -> T:
while True:
if self._stop.is_set():
raise StopIteration()
Expand Down Expand Up @@ -282,14 +287,15 @@ def __init__(self, source: BaseNode[T], prefetch_factor: int, worker: _WorkerTyp
self._thread = threading.Thread(
target=self.worker,
args=(self.source, self._q, self._sem, self._stop_event),
daemon=True,
)
self._thread.start()
self._stopped = False

def __iter__(self) -> Iterator[T]:
return self

def __next__(self):
def __next__(self) -> T:
if self._stopped:
raise StopIteration()

Expand Down Expand Up @@ -321,4 +327,5 @@ def __del__(self):

def _shutdown(self):
self._stop_event.set()
self._thread.join(timeout=QUEUE_TIMEOUT)
if self._thread.is_alive():
self._thread.join(timeout=QUEUE_TIMEOUT)

0 comments on commit 7fdd0e9

Please sign in to comment.