Skip to content

Commit

Permalink
added heartbeat
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Jul 4, 2024
1 parent 4b61047 commit e21fac7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/anemoi/registry/commands/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--destination", help="Platform destination (e.g. leonardo, lumi)")
command_parser.add_argument("--request", help="Filter tasks to process (key=value list)", nargs="*", default=[])
command_parser.add_argument("--threads", help="Number of threads to use", type=int, default=1)
command_parser.add_argument("--heartbeat", help="Heartbeat interval", type=int, default=60)
command_parser.add_argument("--max-no-heartbeat", help="Max interval without heartbeat", type=int, default=3600)

def run(self, args):
kwargs = vars(args)
Expand Down
53 changes: 46 additions & 7 deletions src/anemoi/registry/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import datetime
import logging
import os
import signal
import threading
import time

from anemoi.utils.humanize import when

from anemoi.registry.entry.dataset import DatasetCatalogueEntry
from anemoi.registry.tasks import TaskCatalogueEntryList

Expand All @@ -29,6 +33,8 @@ def __init__(
target_dir=".",
auto_register=True,
threads=1,
heartbeat=60,
max_no_heartbeat=3600,
request={},
):
"""Run a worker that will process tasks in the queue.
Expand All @@ -46,6 +52,8 @@ def __init__(
self.target_dir = target_dir
self.request = request
self.threads = threads
self.heartbeat = heartbeat
self.max_no_heartbeat = max_no_heartbeat

self.wait = wait
self.stop_if_finished = stop_if_finished
Expand All @@ -57,6 +65,7 @@ def __init__(
raise ValueError(f"Target directory {target_dir} must already exist")

def run(self):

while True:
res = self.process_one_task()

Expand All @@ -67,20 +76,37 @@ def run(self):
LOG.info(f"Waiting {self.wait} seconds before checking again.")
time.sleep(self.wait)

def process_one_task(self):
def choose_task(self):
request = self.request.copy()
request["status"] = request.get("status", "queued")
request["destination"] = request.get("destination", self.destination)
cat = TaskCatalogueEntryList(**request)

# if a task is queued, take it
for entry in TaskCatalogueEntryList(status="queued", **request):
return entry

# else if a task is running, check if it has been running for too long, and free it
cat = TaskCatalogueEntryList(status="running", **request)
if not cat:
LOG.info("No tasks found")
LOG.info(cat.to_str(long=True))
else:
LOG.info("No queued tasks found")
for entry in cat:
updated = datetime.datetime.fromisoformat(entry.record["updated"])
LOG.info(f"Task {entry.key} is already running, last update {when(updated, use_utc=True)}.")
if (datetime.datetime.utcnow() - updated).total_seconds() > self.max_no_heartbeat:
LOG.warning(
f"Task {entry.key} has been running for more than {self.max_no_heartbeat} seconds, freeing it."
)
entry.release_ownership()

def process_one_task(self):
entry = self.choose_task()
if not entry:
return False

entry = cat[-1]
uuid = entry.key
LOG.info(f"Processing task {uuid}: {entry}")
self.parse_entry(entry)
self.parse_entry(entry) # for checking only

entry.take_ownership()
self.process_entry(entry)
Expand All @@ -90,10 +116,23 @@ def process_one_task(self):
return True

def process_entry(self, entry):

destination, source, dataset = self.parse_entry(entry)

dataset_entry = DatasetCatalogueEntry(key=dataset)

# create another thread to send heartbeat
def send_heartbeat():
    """Re-assert the task's "running" status in a loop, as a heartbeat.

    Each iteration calls ``entry.set_status("running")`` and then sleeps
    for ``self.heartbeat`` seconds. The loop ends on the first update
    that raises; the failure is logged so that a stopped heartbeat can
    be diagnosed instead of disappearing silently.

    NOTE(review): no stop signal is visible in this function -- the loop
    only ends when ``set_status`` raises. Confirm how the caller expects
    the heartbeat to end once the transfer has completed.
    """
    while True:
        try:
            entry.set_status("running")
        except Exception:
            # Best-effort heartbeat: give up on failure, but leave a trace
            # (with traceback) explaining why the updates stopped.
            LOG.exception("Heartbeat update failed, stopping heartbeat.")
            return
        time.sleep(self.heartbeat)

thread = threading.Thread(target=send_heartbeat)
thread.start()

LOG.info(f"Transferring {dataset} from '{source}' to '{destination}'")

def get_source_path():
Expand Down Expand Up @@ -140,7 +179,7 @@ def get_source_path():
os.rename(target_tmp_path, target_path)

if self.auto_register:
dataset_entry.add_location(self, platform=destination, path=target_path)
dataset_entry.add_location(platform=destination, path=target_path)

@classmethod
def parse_entry(cls, entry):
Expand Down

0 comments on commit e21fac7

Please sign in to comment.