-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththreads.py
executable file
·96 lines (80 loc) · 2.57 KB
/
threads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import threading
from contextlib import contextmanager
from six.moves import queue
import logger
class ShareSessionThread(threading.Thread):
    """A thread that runs under the TensorFlow default session captured at start().

    TensorFlow's default session is a property of the current thread, so a new
    thread does not see the session that was default in the thread starting it.
    This wrapper records that session in ``start()`` and re-enters it (via
    ``default_sess()``) inside the new thread.
    """

    def __init__(self, th=None):
        """
        Share TensorFlow session between threads

        Args:
            th: a ``threading.Thread`` whose ``run()`` should execute under the
                shared session, or None when a subclass overrides ``run()``.

        Raises:
            TypeError: if ``th`` is neither None nor a ``threading.Thread``.
        """
        super(ShareSessionThread, self).__init__()
        if th is not None and not isinstance(th, threading.Thread):
            raise TypeError(f"th must be a threading.Thread or None, got {th!r}")
        # Always define both attributes so run()/default_sess() behave
        # predictably even when no thread is wrapped or start() wasn't called.
        self._th = th
        self._sess = None
        if th is not None:
            # Mirror the wrapped thread's identity and daemon flag.
            self.name = th.name
            self.daemon = th.daemon

    @contextmanager
    def default_sess(self):
        """Context manager that makes the captured session the default one.

        Yields the captured session. If none was captured (no default session
        at ``start()`` time, or ``start()`` not called yet), logs a warning and
        yields None.
        """
        if self._sess is not None:
            with self._sess.as_default():
                yield self._sess
        else:
            logger.warn(f"ShareSessionThread {self.name} wasn't under a default session!")
            yield None

    def start(self):
        """Capture the calling thread's default session, then start the thread."""
        # Imported lazily; presumably so this module can be imported without TF.
        # NOTE: tf.get_default_session is the TF 1.x API
        # (tf.compat.v1.get_default_session under TF 2).
        import tensorflow as tf
        self._sess = tf.get_default_session()
        super(ShareSessionThread, self).start()

    def run(self):
        """Run the wrapped thread's ``run()`` under the captured session.

        Raises:
            NotImplementedError: if no thread was wrapped; subclasses are then
                expected to override ``run()`` (typically using
                ``with self.default_sess():``).
        """
        if self._th is None:
            raise NotImplementedError()
        with self.default_sess():
            self._th.run()
class StopableThread(threading.Thread):
    """A thread that can be asked to stop cooperatively through an Event.

    ``stop()`` only sets the stop event; code running in the thread must poll
    ``stopped()`` and exit on its own. The ``queue_*_stopable`` helpers do this
    polling around blocking queue operations.
    """

    def __init__(self, event=None):
        """
        Create a stopable thread

        Args:
            event: a ``threading.Event`` to use as the stop flag, or None to
                create a fresh one.
        """
        super(StopableThread, self).__init__()
        # Use the caller's event when given, so the caller controls (or can
        # share) the stop flag; otherwise create a private one.
        self._stop_evt = threading.Event() if event is None else event

    def stop(self):
        """Request the thread to stop by setting the stop event."""
        self._stop_evt.set()

    def stopped(self):
        """Return True once the stop event has been set."""
        return self._stop_evt.is_set()

    def queue_put_stopable(self, q, obj, timeout=5):
        """Put ``obj`` into ``q``, retrying while the queue is full.

        Args:
            q: a ``queue.Queue``.
            obj: the item to put.
            timeout: seconds each put attempt blocks before ``stopped()`` is
                re-checked.

        Gives up (returns without putting ``obj``) once the thread is stopped.
        """
        while not self.stopped():
            try:
                q.put(obj, timeout=timeout)
                return
            except queue.Full:
                pass  # still full; loop to re-check stopped()

    def queue_get_stopable(self, q, timeout=5):
        """Get an item from ``q``, retrying while the queue is empty.

        Args:
            q: a ``queue.Queue``.
            timeout: seconds each get attempt blocks before ``stopped()`` is
                re-checked.

        Returns:
            The dequeued item, or None if the thread is stopped before an item
            becomes available.
        """
        while not self.stopped():
            try:
                return q.get(timeout=timeout)
            except queue.Empty:
                pass  # still empty; loop to re-check stopped()
        return None
class LoopThread(StopableThread):
    """A daemon, stoppable thread that loops until ``stop()`` is called.

    This is a template: subclasses override ``run()`` with the real loop body.
    When ``pausable`` is true, ``pause()`` holds an internal lock until
    ``resume()`` releases it, and the loop waits on that lock while paused.

    Args:
        pausable: whether ``pause()``/``resume()`` are supported.
    """

    def __init__(self, pausable=True):
        super(LoopThread, self).__init__()
        self.paused = False
        self._pausable = pausable
        if pausable:
            self._lock = threading.Lock()
        self.daemon = True

    def run(self):
        """Sample loop; subclasses are expected to override this."""
        while not self.stopped():
            if self._pausable:
                # Block while pause() holds the lock instead of busy-spinning;
                # the timeout lets the loop re-check stopped() while paused.
                if not self._lock.acquire(timeout=0.1):
                    continue
                self._lock.release()
            if not self.paused:
                raise NotImplementedError  # This is a sample to override

    def pause(self):
        """Pause the loop; a no-op if it is already paused.

        Raises:
            RuntimeError: if the thread was created with ``pausable=False``.
        """
        self._require_pausable()
        if self.paused:
            # Acquiring the non-reentrant lock a second time would deadlock.
            return
        self.paused = True
        self._lock.acquire()

    def resume(self):
        """Resume a paused loop; a no-op if it is not paused.

        Raises:
            RuntimeError: if the thread was created with ``pausable=False``.
        """
        self._require_pausable()
        if not self.paused:
            # Releasing a lock that isn't held would raise RuntimeError.
            return
        self.paused = False
        self._lock.release()

    def _require_pausable(self):
        """Raise RuntimeError unless this thread supports pause/resume."""
        if not self._pausable:
            raise RuntimeError("LoopThread was created with pausable=False")