diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index eb31ed17d2d..a6a346b6d11 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -407,8 +407,10 @@ def update_consumer_state_for_worker( def worker_waiting(self, worker: Worker): """This worker is now waiting for work.""" # Queue to broker and service waiting lists - self.waiting.append(worker) - worker.service.waiting.append(worker) + if worker not in self.waiting: + self.waiting.append(worker) + if worker not in worker.service.waiting: + worker.service.waiting.append(worker) worker.reset_expiry() self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) self.dispatch(worker.service, None) @@ -529,7 +531,10 @@ def process_worker(self, address: bytes, msg: List[bytes]): elif QueueMsgProtocol.W_HEARTBEAT == command: if worker_ready: - worker.reset_expiry() + # If worker is ready then reset expiry + # and add it to worker waiting list + # if not already present + self.worker_waiting(worker) else: # extract the syft worker id and worker pool name from the message # Get the corresponding worker pool and worker @@ -700,7 +705,7 @@ def _run(self): finally: self.clear_job() elif command == QueueMsgProtocol.W_HEARTBEAT: - pass + self.set_producer_alive() elif command == QueueMsgProtocol.W_DISCONNECT: self.reconnect_to_producer() else: