Skip to content

Commit

Permalink
WIP started work on bringing thread logic into on_message
Browse files Browse the repository at this point in the history
  • Loading branch information
stchris committed Jan 16, 2024
1 parent 69dc59b commit 422e26f
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions servicelayer/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,42 +295,46 @@ def __init__(
self.queues = ensure_list(queues)
self.version = version
self.prefetch_count = prefetch_count
self.local_queue = Queue()
# self.local_queue = Queue()

def on_signal(self, signal, _):
log.warning(f"Shutting down worker (signal {signal})")
# Exit eagerly without waiting for current task to finish running
sys.exit(int(signal))

def on_message(self, channel, method, properties, body, args):
def on_message(self, channel, method_frame, properties, body, args):
"""RabbitMQ on_message event handler.
We have to make sure it doesn't block for long to ensure that RabbitMQ
heartbeats are not interrupted.
"""
connection = args[0]
task = get_task(body, method.delivery_tag)
self.local_queue.put((task, channel, connection))

def process_blocking(self):
thrds = args
delivery_tag = method_frame.delivery_tag
blocking = True
task = get_task(body, delivery_tag)
t = threading.Thread(target=self.process, args=(blocking, task, channel, delivery_tag, body))
t.start()
thrds.append(t)

def process_blocking(self, task, channel, delivery_tag, body):
"""Blocking worker thread - executes tasks from a queue and periodic tasks"""
while True:
try:
(task, channel, connection) = self.local_queue.get(timeout=TIMEOUT)
# (task, channel, connection) = self.local_queue.get(timeout=TIMEOUT)
apply_task_context(task, v=self.version)
self.handle(task)
cb = functools.partial(self.ack_message, task, channel)
connection.add_callback_threadsafe(cb)
channel.add_callback_threadsafe(cb)
except Empty:
pass
finally:
clear_contextvars()
self.periodic()

def process_nonblocking(self):
def process_nonblocking(self, task, channel, delivery_tag, body):
"""Non-blocking worker is used for tests only."""
connection = get_rabbitmq_connection()
channel = connection.channel()
# connection = get_rabbitmq_connection()
# channel = connection.channel()
queue_active = {queue: True for queue in self.queues}
while True:
for queue in self.queues:
Expand All @@ -345,11 +349,11 @@ def process_nonblocking(self):
task = get_task(body, method.delivery_tag)
self.handle(task)

def process(self, blocking=True):
def process(self, blocking=True, *args):
if blocking:
self.process_blocking()
self.process_blocking(args)
else:
self.process_nonblocking()
self.process_nonblocking(args)

def handle(self, task: Task):
"""Execute a task."""
Expand Down Expand Up @@ -416,27 +420,22 @@ def run(self):
signal.signal(signal.SIGTERM, self.on_signal)

# worker threads
def process():
return self.process(blocking=True)
# def process():
# return self.process(blocking=True)

if not self.num_threads:
# TODO - seems like we need at least one thread
# consuming and processing require separate threads
self.num_threads = 1

threads = []
for _ in range(self.num_threads):
thread = threading.Thread(target=process)
thread.daemon = True
thread.start()
threads.append(thread)

log.info(f"Worker has {self.num_threads} worker threads.")

connection = get_rabbitmq_connection()
channel = connection.channel()
channel.basic_qos(prefetch_count=self.prefetch_count)
on_message_callback = functools.partial(self.on_message, args=(connection,))
channel.basic_qos(prefetch_count=self.num_threads)
threads = []
on_message_callback = functools.partial(self.on_message, args=(threads,))

for queue in self.queues:
channel.queue_declare(
queue=queue,
Expand Down

0 comments on commit 422e26f

Please sign in to comment.