Merge pull request #200 from benfdking/let_remote_workers_print_their_own_stats

fix: http workers pass their own id
tobymao authored Jan 8, 2025
2 parents 5b97779 + 29eca74 commit b7d1760
Showing 8 changed files with 130 additions and 38 deletions.
32 changes: 24 additions & 8 deletions saq/queue/base.py
@@ -18,7 +18,7 @@
     Status,
     get_default_job_key,
 )
-from saq.utils import now, uuid1
+from saq.utils import now
 
 if t.TYPE_CHECKING:
     from collections.abc import AsyncIterator, Iterable, Sequence
@@ -30,7 +30,7 @@
         DumpType,
         LoadType,
         QueueInfo,
-        QueueStats,
+        WorkerStats,
     )
@@ -59,7 +59,6 @@ def __init__(
         load: LoadType | None,
     ) -> None:
         self.name = name
-        self.uuid: str = uuid1()
         self.started: int = now()
         self.complete = 0
         self.failed = 0
@@ -132,7 +131,16 @@ async def finish_abort(self, job: Job) -> None:
         await job.finish(Status.ABORTED, error=job.error)
 
     @abstractmethod
-    async def write_stats(self, stats: QueueStats, ttl: int) -> None:
+    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
+        """
+        Write the given worker's stats to the queue.
+
+        Args:
+            worker_id: The worker id, passed in rather than taken from the queue instance to ensure that
+                the stats are attributed to the worker and not to the queue instance in the proxy server.
+            stats: The stats to write.
+            ttl: The time-to-live in seconds.
+        """
         pass

     @abstractmethod
@@ -186,16 +194,24 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
             raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
         return Job(**job_dict, queue=self)
 
-    async def stats(self, ttl: int = 60) -> QueueStats:
-        stats: QueueStats = {
+    async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
+        """
+        Method to be used by workers to update stats.
+
+        Args:
+            worker_id: The worker id.
+            ttl: Time stats are valid for in seconds.
+        Returns: The stats.
+        """
+        stats: WorkerStats = {
             "complete": self.complete,
             "failed": self.failed,
             "retried": self.retried,
             "aborted": self.aborted,
             "uptime": now() - self.started,
         }
 
-        await self.write_stats(stats, ttl)
+        await self.write_stats(worker_id, stats, ttl)
         return stats
 
     def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
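
Taken together, the base-class changes move worker identity out of the queue: the queue no longer mints its own uuid, and Queue.stats snapshots the counters and hands them to write_stats under whatever id the caller supplies. A minimal sketch of the new call surface, assuming a local Redis URL (normally a Worker makes this call for you):

import asyncio

from saq import Queue

async def main() -> None:
    queue = Queue.from_url("redis://localhost")
    await queue.connect()
    # Stats are recorded under the id passed in, not under an id
    # owned by this queue instance.
    stats = await queue.stats("worker-1", ttl=60)
    print(stats)  # e.g. {'complete': 0, 'failed': 0, 'retried': 0, 'aborted': 0, 'uptime': 3}
    await queue.disconnect()

asyncio.run(main())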
10 changes: 6 additions & 4 deletions saq/queue/http.py
@@ -18,7 +18,7 @@
 from saq.types import (
     CountKind,
     QueueInfo,
-    QueueStats,
+    WorkerStats,
 )
 
 try:
@@ -103,7 +103,9 @@ async def process(self, body: str) -> str | None:
                 )
             )
         if kind == "write_stats":
-            await self.queue.write_stats(req["stats"], ttl=req["ttl"])
+            await self.queue.write_stats(
+                worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
+            )
             return None
         raise ValueError(f"Invalid request {body}")
 
@@ -206,8 +208,8 @@ async def finish_abort(self, job: Job) -> None:
     async def dequeue(self, timeout: float = 0) -> Job | None:
         return self.deserialize(await self._send("dequeue", timeout=timeout))
 
-    async def write_stats(self, stats: QueueStats, ttl: int) -> None:
-        await self._send("write_stats", stats=stats, ttl=ttl)
+    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
+        await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)
 
     async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
         return json.loads(await self._send("info", jobs=jobs, offset=offset, limit=limit))
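
On the wire, the worker id simply rides along in the JSON request. Assuming _send posts its kind plus keyword arguments as a single JSON object (which the proxy's dispatch on kind above implies), a proxied write_stats request looks roughly like this sketch; the id value is made up:

import json

payload = json.dumps(
    {
        "kind": "write_stats",
        "worker_id": "3f9c-worker-a",  # the worker's own id, not the HTTP queue's
        "stats": {"complete": 3, "failed": 0, "retried": 1, "aborted": 0, "uptime": 9876},
        "ttl": 60,
    }
)

Previously the proxy-side queue wrote stats under its own uuid, so every worker behind one proxy collapsed into a single entry; carrying worker_id end to end is what lets each remote worker report its own stats.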
6 changes: 3 additions & 3 deletions saq/queue/postgres.py
@@ -32,7 +32,7 @@
     DumpType,
     LoadType,
     QueueInfo,
-    QueueStats,
+    WorkerStats,
 )
 
 try:
@@ -670,7 +670,7 @@ async def _enqueue(self, job: Job) -> Job | None:
         logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
         return job
 
-    async def write_stats(self, stats: QueueStats, ttl: int) -> None:
+    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
         async with self.pool.connection() as conn:
             await conn.execute(
                 SQL(
@@ -684,7 +684,7 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None:
                     )
                 ).format(stats_table=self.stats_table),
                 {
-                    "worker_id": self.uuid,
+                    "worker_id": worker_id,
                     "stats": json.dumps(stats),
                     "ttl": ttl,
                 },
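
The collapsed SQL is an upsert keyed on worker_id, so after this change each worker owns a row in the stats table instead of all writers sharing the queue instance's uuid. A hedged read-back sketch, borrowing the query shape from tests/test_queue.py below:

from psycopg.sql import SQL

async def read_worker_stats(queue, worker_id: str):
    # queue is a saq Postgres queue; stats_table is its configured stats table.
    async with queue.pool.connection() as conn, conn.cursor() as cursor:
        await cursor.execute(
            SQL("SELECT stats FROM {} WHERE worker_id = %s").format(queue.stats_table),
            (worker_id,),
        )
        return await cursor.fetchone()  # None once the row's ttl has been swept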
12 changes: 3 additions & 9 deletions saq/queue/redis.py
@@ -38,7 +38,7 @@
     ListenCallback,
     LoadType,
     QueueInfo,
-    QueueStats,
+    WorkerStats,
     VersionTuple,
 )
@@ -389,16 +389,10 @@ async def finish_abort(self, job: Job) -> None:
         await self.redis.delete(job.abort_id)
         await super().finish_abort(job)
 
-    async def write_stats(self, stats: QueueStats, ttl: int) -> None:
-        """
-        Returns & updates stats on the queue
-
-        Args:
-            ttl: Time-to-live of stats saved in Redis
-        """
+    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
         current = now()
         async with self.redis.pipeline(transaction=True) as pipe:
-            key = self.namespace(f"stats:{self.uuid}")
+            key = self.namespace(f"stats:{worker_id}")
             await (
                 pipe.setex(key, ttl, json.dumps(stats))
                 .zremrangebyscore(self._stats, 0, current)
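
In Redis, each worker's stats now live under a namespaced stats:{worker_id} key written with SETEX, while a sorted set (self._stats, whose update is collapsed above) tracks expiries so sweeping can prune dead workers. A standalone sketch of that keying pattern using redis-py's asyncio client; the saq: prefix and the zadd score are assumptions inferred from the visible pipeline:

import json
import time

from redis import asyncio as aioredis

async def write_worker_stats(redis: aioredis.Redis, worker_id: str, stats: dict, ttl: int) -> None:
    now_ms = int(time.time() * 1000)
    key = f"saq:stats:{worker_id}"  # one stats blob per worker id
    async with redis.pipeline(transaction=True) as pipe:
        pipe.setex(key, ttl, json.dumps(stats))  # blob expires after ttl seconds
        pipe.zremrangebyscore("saq:stats", 0, now_ms)  # drop already-expired entries
        pipe.zadd("saq:stats", {key: now_ms + ttl * 1000})  # record this worker's expiry
        await pipe.execute()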
4 changes: 2 additions & 2 deletions saq/types.py
@@ -63,9 +63,9 @@ class QueueInfo(t.TypedDict):
     "A truncated list containing the jobs that are scheduled to execute soonest"
 
 
-class QueueStats(t.TypedDict):
+class WorkerStats(t.TypedDict):
     """
-    Queue Stats
+    Worker Stats
     """
 
     complete: int
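
The rename is purely semantic; the fields are unchanged. For reference, the full shape implied by Queue.stats in saq/queue/base.py above (the diff is truncated after the first field):

import typing as t

class WorkerStats(t.TypedDict):
    """
    Worker Stats
    """

    complete: int
    failed: int
    retried: int
    aborted: int
    uptime: int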
13 changes: 9 additions & 4 deletions saq/worker.py
@@ -19,7 +19,7 @@
 
 from saq.job import Status
 from saq.queue import Queue
-from saq.utils import cancel_tasks, millis, now, now_seconds
+from saq.utils import cancel_tasks, millis, now, now_seconds, uuid1
 
 if t.TYPE_CHECKING:
     from asyncio import Task
@@ -36,6 +36,7 @@
         ReceivesContext,
         SettingsDict,
         TimersDict,
+        WorkerStats,
     )
@@ -47,6 +48,7 @@ class Worker:
     Worker is used to process and monitor jobs.
 
     Args:
+        id: optional override for the worker id; if not provided, a uuid is used
         queue: instance of saq.queue.Queue
         functions: list of async functions
         concurrency: number of jobs to process concurrently
@@ -72,6 +74,7 @@ def __init__(
         queue: Queue,
         functions: Collection[Function | tuple[str, Function]],
         *,
+        id: t.Optional[str] = None,
         concurrency: int = 10,
         cron_jobs: Collection[CronJob] | None = None,
         startup: ReceivesContext | Collection[ReceivesContext] | None = None,
@@ -115,6 +118,7 @@ def __init__(
         self.burst_jobs_processed = 0
         self.burst_jobs_processed_lock = threading.Lock()
         self.burst_condition_met = False
+        self.id = uuid1() if id is None else id
 
         if self.burst:
             if self.dequeue_timeout <= 0:
@@ -213,6 +217,9 @@ async def schedule(self, lock: int = 1) -> None:
         if scheduled:
             logger.info("Scheduled %s", scheduled)
 
+    async def stats(self, ttl: int = 60) -> WorkerStats:
+        return await self.queue.stats(self.id, ttl)
+
     async def upkeep(self) -> list[Task[None]]:
         """Start various upkeep tasks async."""
 
@@ -233,9 +240,7 @@ async def poll(
             asyncio.create_task(poll(self.abort, self.timers["abort"])),
             asyncio.create_task(poll(self.schedule, self.timers["schedule"])),
             asyncio.create_task(poll(self.queue.sweep, self.timers["sweep"])),
-            asyncio.create_task(
-                poll(self.queue.stats, self.timers["stats"], self.timers["stats"] + 1)
-            ),
+            asyncio.create_task(poll(self.stats, self.timers["stats"], self.timers["stats"] + 1)),
         ]
 
     async def abort(self, abort_threshold: float) -> None:
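
With the id stored on the worker and Worker.stats delegating to queue.stats(self.id, ttl), the upkeep loop now heartbeats under the worker's identity, so two workers sharing one queue report separately. A short usage sketch; the echo function and the Redis URL are illustrative:

import asyncio

from saq import Queue, Worker

async def echo(ctx, *, a: int) -> int:
    return a

async def main() -> None:
    queue = Queue.from_url("redis://localhost")
    await queue.connect()
    w1 = Worker(queue, functions=[echo])           # id defaults to uuid1()
    w2 = Worker(queue, functions=[echo], id="w2")  # or pass an explicit id
    await w1.stats()
    await w2.stats()
    info = await queue.info()
    assert set(info["workers"]) == {w1.id, "w2"}
    await queue.disconnect()

asyncio.run(main())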
74 changes: 74 additions & 0 deletions tests/test_http_stats.py
@@ -0,0 +1,74 @@
"""Validate that the worker id in the context of the http proxy is the id of the worker rather than the queue."""

import unittest
from aiohttp import web

from saq import Queue, Worker
from saq.queue.http import HttpProxy
from saq.types import Context
from tests.helpers import setup_postgres, create_postgres_queue, teardown_postgres


async def echo(_ctx: Context, *, a: int) -> int:
return a


class TestQueue(unittest.IsolatedAsyncioTestCase):
async def handle_post(self, request):
body = await request.text()
response = await self.proxy.process(body)
if response:
return web.Response(text=response, content_type="application/json")
else:
return web.Response(status=200)

async def asyncSetUp(self) -> None:
await setup_postgres()
self.queue = await create_postgres_queue()
self.proxy = HttpProxy(queue=self.queue)
self.app = web.Application()
self.app.add_routes([web.post("/", self.handle_post)])
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8080)
await self.site.start()

async def asyncTearDown(self) -> None:
await teardown_postgres()
await self.site.stop()
await self.runner.cleanup()

async def test_http_proxy_with_two_workers(self) -> None:
queue1 = Queue.from_url("http://localhost:8080/")
await queue1.connect()
queue2 = Queue.from_url("http://localhost:8080/")
await queue2.connect()

worker = Worker(
queue=queue1,
functions=[echo],
)
await worker.stats()
worker2 = Worker(
queue=queue2,
functions=[echo],
)
await worker2.stats()
local_worker = Worker(
queue=self.queue,
functions=[echo],
)
await local_worker.stats()

root_info = await self.queue.info()
info1 = await queue1.info()
info2 = await queue2.info()

self.assertEqual(root_info["workers"], info1["workers"])
self.assertEqual(info1["workers"], info2["workers"])
self.assertEqual(info1["workers"].keys(), {worker.id, worker2.id, local_worker.id})
self.assertEqual(info1["workers"].keys(), info2["workers"].keys())

await queue1.disconnect()
await queue2.disconnect()
await self.queue.disconnect()
17 changes: 9 additions & 8 deletions tests/test_queue.py
@@ -193,14 +193,15 @@ async def test_retry_delay(self) -> None:
         self.assertEqual(job.status, Status.QUEUED)
 
     async def test_stats(self) -> None:
+        worker = Worker(self.queue, functions=functions)
         for _ in range(10):
             await self.enqueue("test")
             job = await self.dequeue()
             await job.retry(None)
             await job.finish(Status.ABORTED)
             await job.finish(Status.FAILED)
             await job.finish(Status.COMPLETE)
-        stats = await self.queue.stats()
+        stats = await worker.stats()
         self.assertEqual(stats["complete"], 10)
         self.assertEqual(stats["failed"], 10)
         self.assertEqual(stats["retried"], 10)
@@ -221,11 +222,10 @@ async def test_info(self) -> None:
         await self.enqueue("echo", a=1)
         await queue2.enqueue("echo", a=1)
         await worker.process()
-        await self.queue.stats()
-        await queue2.stats()
+        await worker.stats()
 
         info = await self.queue.info(jobs=True)
-        self.assertEqual(set(info["workers"].keys()), {self.queue.uuid, queue2.uuid})
+        self.assertEqual(set(info["workers"].keys()), {worker.id})
         self.assertEqual(info["active"], 0)
         self.assertEqual(info["queued"], 1)
         self.assertEqual(len(info["jobs"]), 1)
@@ -580,8 +580,9 @@ async def test_sweep_jobs(self) -> None:
         self.assertEqual(job2.status, Status.COMPLETE)
 
     async def test_sweep_stats(self) -> None:
+        worker = Worker(self.queue, functions=functions)
         # Stats are deleted
-        await self.queue.stats(ttl=1)
+        await worker.stats(ttl=1)
         await asyncio.sleep(1.5)
         await self.queue.sweep()
         async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
@@ -593,12 +594,12 @@ async def test_sweep_stats(self) -> None:
                     WHERE worker_id = %s
                     """
                 ).format(self.queue.stats_table),
-                (self.queue.uuid,),
+                (worker.id,),
             )
             self.assertIsNone(await cursor.fetchone())
 
         # Stats are not deleted
-        await self.queue.stats(ttl=60)
+        await worker.stats(ttl=60)
         await asyncio.sleep(1)
         await self.queue.sweep()
         async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
@@ -610,7 +611,7 @@ async def test_sweep_stats(self) -> None:
                     WHERE worker_id = %s
                     """
                 ).format(self.queue.stats_table),
-                (self.queue.uuid,),
+                (worker.id,),
             )
             self.assertIsNotNone(await cursor.fetchone())
