From 29eca74a1c597ce7ea1ebde5c0705b84aa3aa44c Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:09:21 +0000 Subject: [PATCH] chore: address feeedback --- saq/queue/base.py | 8 +++--- saq/queue/http.py | 4 +-- saq/queue/postgres.py | 4 +-- saq/queue/redis.py | 4 +-- saq/types.py | 4 +-- saq/worker.py | 7 +++-- tests/test_http_stats.py | 57 ++++++++++++++-------------------------- 7 files changed, 35 insertions(+), 53 deletions(-) diff --git a/saq/queue/base.py b/saq/queue/base.py index 4dcff75..60c166f 100644 --- a/saq/queue/base.py +++ b/saq/queue/base.py @@ -30,7 +30,7 @@ DumpType, LoadType, QueueInfo, - QueueStats, + WorkerStats, ) @@ -131,7 +131,7 @@ async def finish_abort(self, job: Job) -> None: await job.finish(Status.ABORTED, error=job.error) @abstractmethod - async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None: + async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None: """ Returns & updates stats on the queue. @@ -194,7 +194,7 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None: raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}") return Job(**job_dict, queue=self) - async def stats(self, worker_id: str, ttl: int = 60) -> QueueStats: + async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats: """ Method to be used by workers to update stats. @@ -204,7 +204,7 @@ async def stats(self, worker_id: str, ttl: int = 60) -> QueueStats: Returns: The stats. """ - stats: QueueStats = { + stats: WorkerStats = { "complete": self.complete, "failed": self.failed, "retried": self.retried, diff --git a/saq/queue/http.py b/saq/queue/http.py index dc39c77..67d5d63 100644 --- a/saq/queue/http.py +++ b/saq/queue/http.py @@ -18,7 +18,7 @@ from saq.types import ( CountKind, QueueInfo, - QueueStats, + WorkerStats, ) try: @@ -208,7 +208,7 @@ async def finish_abort(self, job: Job) -> None: async def dequeue(self, timeout: float = 0) -> Job | None: return self.deserialize(await self._send("dequeue", timeout=timeout)) - async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None: + async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None: await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl) async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 8b6489e..587d151 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -32,7 +32,7 @@ DumpType, LoadType, QueueInfo, - QueueStats, + WorkerStats, ) try: @@ -670,7 +670,7 @@ async def _enqueue(self, job: Job) -> Job | None: logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) return job - async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None: + async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None: async with self.pool.connection() as conn: await conn.execute( SQL( diff --git a/saq/queue/redis.py b/saq/queue/redis.py index 26ca602..6907a32 100644 --- a/saq/queue/redis.py +++ b/saq/queue/redis.py @@ -38,7 +38,7 @@ ListenCallback, LoadType, QueueInfo, - QueueStats, + WorkerStats, VersionTuple, ) @@ -389,7 +389,7 @@ async def finish_abort(self, job: Job) -> None: await self.redis.delete(job.abort_id) await super().finish_abort(job) - async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None: + async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None: current = now() async with self.redis.pipeline(transaction=True) as pipe: key = self.namespace(f"stats:{worker_id}") diff --git a/saq/types.py b/saq/types.py index fc1b492..005e82e 100644 --- a/saq/types.py +++ b/saq/types.py @@ -63,9 +63,9 @@ class QueueInfo(t.TypedDict): "A truncated list containing the jobs that are scheduled to execute soonest" -class QueueStats(t.TypedDict): +class WorkerStats(t.TypedDict): """ - Queue Stats, could also be used for Worker Stats + Worker Stats """ complete: int diff --git a/saq/worker.py b/saq/worker.py index 1f2ebbc..301a544 100644 --- a/saq/worker.py +++ b/saq/worker.py @@ -36,7 +36,7 @@ ReceivesContext, SettingsDict, TimersDict, - QueueStats, + WorkerStats, ) @@ -118,7 +118,7 @@ def __init__( self.burst_jobs_processed = 0 self.burst_jobs_processed_lock = threading.Lock() self.burst_condition_met = False - self.id = id if id is not None else uuid1() + self.id = uuid1() if id is None else id if self.burst: if self.dequeue_timeout <= 0: @@ -155,7 +155,6 @@ async def start(self) -> None: """Start processing jobs and upkeep tasks.""" logger.info("Worker starting: %s", repr(self.queue)) logger.debug("Registered functions:\n%s", "\n".join(f" {key}" for key in self.functions)) - await self.stats() try: self.event = asyncio.Event() @@ -218,7 +217,7 @@ async def schedule(self, lock: int = 1) -> None: if scheduled: logger.info("Scheduled %s", scheduled) - async def stats(self, ttl: int = 60) -> QueueStats: + async def stats(self, ttl: int = 60) -> WorkerStats: return await self.queue.stats(self.id, ttl) async def upkeep(self) -> list[Task[None]]: diff --git a/tests/test_http_stats.py b/tests/test_http_stats.py index 32055ce..2e39e18 100644 --- a/tests/test_http_stats.py +++ b/tests/test_http_stats.py @@ -1,58 +1,44 @@ """Validate that the worker id in the context of the http proxy is the id of the worker rather than the queue.""" import unittest -from http.server import HTTPServer, BaseHTTPRequestHandler +from aiohttp import web from saq import Queue, Worker from saq.queue.http import HttpProxy from saq.types import Context from tests.helpers import setup_postgres, create_postgres_queue, teardown_postgres -import asyncio -import threading async def echo(_ctx: Context, *, a: int) -> int: return a -class ProxyRequestHandler(BaseHTTPRequestHandler): - def __init__(self, *args, proxy=None, **kwargs): - self.proxy = proxy - super().__init__(*args, **kwargs) - - def do_POST(self): - length = int(self.headers["Content-Length"]) - body = self.rfile.read(length).decode("utf-8") - response = asyncio.run(self.proxy.process(body)) +class TestQueue(unittest.IsolatedAsyncioTestCase): + async def handle_post(self, request): + body = await request.text() + response = await self.proxy.process(body) if response: - self.send_response(200) - self.send_header("Content-type", "application/json") - self.end_headers() - self.wfile.write(response.encode("utf-8")) + return web.Response(text=response, content_type="application/json") else: - self.send_response(200) - self.end_headers() + return web.Response(status=200) - -class TestQueue(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self) -> None: await setup_postgres() + self.queue = await create_postgres_queue() + self.proxy = HttpProxy(queue=self.queue) + self.app = web.Application() + self.app.add_routes([web.post("/", self.handle_post)]) + self.runner = web.AppRunner(self.app) + await self.runner.setup() + self.site = web.TCPSite(self.runner, "localhost", 8080) + await self.site.start() async def asyncTearDown(self) -> None: await teardown_postgres() + await self.site.stop() + await self.runner.cleanup() async def test_http_proxy_with_two_workers(self) -> None: - queue = await create_postgres_queue() - proxy = HttpProxy(queue=queue) - - server = HTTPServer( - ("localhost", 8080), - lambda *args, **kwargs: ProxyRequestHandler(*args, proxy=proxy, **kwargs), - ) - server_thread = threading.Thread(target=server.serve_forever) - server_thread.daemon = True - server_thread.start() - queue1 = Queue.from_url("http://localhost:8080/") await queue1.connect() queue2 = Queue.from_url("http://localhost:8080/") @@ -69,12 +55,12 @@ async def test_http_proxy_with_two_workers(self) -> None: ) await worker2.stats() local_worker = Worker( - queue=queue, + queue=self.queue, functions=[echo], ) await local_worker.stats() - root_info = await queue.info() + root_info = await self.queue.info() info1 = await queue1.info() info2 = await queue2.info() @@ -85,7 +71,4 @@ async def test_http_proxy_with_two_workers(self) -> None: await queue1.disconnect() await queue2.disconnect() - await queue.disconnect() - - server.shutdown() - server_thread.join() + await self.queue.disconnect()