diff --git a/saq/queue/base.py b/saq/queue/base.py
index c405ea3..5a2a9c0 100644
--- a/saq/queue/base.py
+++ b/saq/queue/base.py
@@ -31,6 +31,7 @@
     LoadType,
     QueueInfo,
     WorkerStats,
+    WorkerInfo,
 )
@@ -131,24 +132,23 @@ async def finish_abort(self, job: Job) -> None:
         await job.finish(Status.ABORTED, error=job.error)

     @abstractmethod
-    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
+    async def write_worker_info(
+        self,
+        worker_id: str,
+        info: WorkerInfo,
+        ttl: int,
+    ) -> None:
         """
-        Returns & updates stats on the queue.
+        Write stats and metadata for a worker.

         Args:
             worker_id: The worker id, passed in rather than taken from the queue instance to ensure that the stats
                 are attributed to the worker and not the queue instance in the proxy server.
-            stats: The stats to write.
+            info: The worker info to write.
             ttl: The time-to-live in seconds.
         """
         pass

-    @abstractmethod
-    async def write_worker_metadata(
-        self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
-    ) -> None:
-        pass
-
     @abstractmethod
     async def _retry(self, job: Job, error: str | None) -> None:
         pass
@@ -200,15 +200,19 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
             raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
         return Job(**job_dict, queue=self)

-    async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
+    async def worker_info(
+        self, worker_id: str, queue_key: str, metadata: t.Optional[dict] = None, ttl: int = 60
+    ) -> WorkerInfo:
         """
-        Method to be used by workers to update stats.
+        Method to be used by workers to update worker info.

         Args:
             worker_id: The worker id.
-            ttl: Time stats are valid for in seconds.
+            queue_key: The key of the queue.
+            metadata: The metadata to write.
+            ttl: Time the worker info is valid for in seconds.

-        Returns: The stats.
+        Returns: Worker info.
""" stats: WorkerStats = { "complete": self.complete, @@ -217,8 +223,17 @@ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats: "aborted": self.aborted, "uptime": now() - self.started, } - await self.write_stats(worker_id, stats, ttl) - return stats + info: WorkerInfo = { + "stats": stats, + "queue_key": queue_key, + "metadata": metadata, + } + await self.write_worker_info( + worker_id, + info, + ttl=ttl, + ) + return info def register_before_enqueue(self, callback: BeforeEnqueueType) -> None: self._before_enqueues[id(callback)] = callback diff --git a/saq/queue/http.py b/saq/queue/http.py index 6d1156c..5050ab9 100644 --- a/saq/queue/http.py +++ b/saq/queue/http.py @@ -7,7 +7,6 @@ import json import typing as t - from saq.errors import MissingDependencyError from saq.job import Job, Status from saq.queue.base import Queue @@ -18,7 +17,7 @@ from saq.types import ( CountKind, QueueInfo, - WorkerStats, + WorkerInfo, ) try: @@ -102,17 +101,15 @@ async def process(self, body: str) -> str | None: jobs=req["jobs"], offset=req["offset"], limit=req["limit"] ) ) - if kind == "write_stats": - await self.queue.write_stats( - worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"] - ) - return None - if kind == "write_worker_metadata": - await self.queue.write_worker_metadata( - metadata=req["metadata"], - ttl=req["ttl"], - queue_key=req["queue_key"], + if kind == "write_worker_info": + await self.queue.write_worker_info( worker_id=req["worker_id"], + info={ + "stats": req["stats"], + "queue_key": req["queue_key"], + "metadata": req["metadata"], + }, + ttl=req["ttl"], ) return None raise ValueError(f"Invalid request {body}") @@ -216,18 +213,19 @@ async def finish_abort(self, job: Job) -> None: async def dequeue(self, timeout: float = 0) -> Job | None: return self.deserialize(await self._send("dequeue", timeout=timeout)) - async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None: - await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl) - - async def write_worker_metadata( - self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int + async def write_worker_info( + self, + worker_id: str, + info: WorkerInfo, + ttl: int, ) -> None: await self._send( - "write_worker_metadata", + "write_worker_info", worker_id=worker_id, - metadata=metadata, + stats=info["stats"], ttl=ttl, - queue_key=queue_key, + queue_key=info["queue_key"], + metadata=info["metadata"], ) async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 187992c..2f2e7df 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -32,7 +32,6 @@ DumpType, LoadType, QueueInfo, - WorkerStats, WorkerInfo, ) @@ -51,7 +50,6 @@ JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" VERSIONS_TABLE = "saq_versions" -METADATA_TABLE = "saq_worker_metadata" class PostgresQueue(Queue): @@ -95,7 +93,6 @@ def __init__( versions_table: str = VERSIONS_TABLE, jobs_table: str = JOBS_TABLE, stats_table: str = STATS_TABLE, - metadata_table: str = METADATA_TABLE, dump: DumpType | None = None, load: LoadType | None = None, min_size: int = 4, @@ -110,7 +107,6 @@ def __init__( self.versions_table = Identifier(versions_table) self.jobs_table = Identifier(jobs_table) self.stats_table = Identifier(stats_table) - self.metadata_table = Identifier(metadata_table) self.pool = pool self.min_size = min_size self.max_size = max_size @@ -152,7 +148,6 @@ async def init_db(self) -> None: migrations = 
                 jobs_table=self.jobs_table,
                 stats_table=self.stats_table,
-                worker_metadata_table=self.metadata_table,
             )
             target_version = migrations[-1][0]
             await cursor.execute(
@@ -223,43 +218,20 @@ async def disconnect(self) -> None:
             await self.pool.close()
         self._has_sweep_lock = False

-    async def write_worker_metadata(
-        self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
-    ) -> None:
-        async with self.pool.connection() as conn:
-            await conn.execute(
-                SQL(
-                    dedent(
-                        """
-                        INSERT INTO {worker_metadata_table} (worker_id, queue_key, metadata, expire_at)
-                        VALUES
-                        (%(worker_id)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
-                        ON CONFLICT (worker_id) DO UPDATE
-                        SET metadata = %(metadata)s, queue_key = %(queue_key)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
-                        """
-                    )
-                ).format(worker_metadata_table=self.metadata_table),
-                {
-                    "worker_id": worker_id,
-                    "metadata": json.dumps(metadata),
-                    "queue_key": queue_key,
-                    "ttl": ttl,
-                },
-            )
-
     async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
         async with self.pool.connection() as conn, conn.cursor() as cursor:
             await cursor.execute(
                 SQL(
                     dedent(
                         """
-                        SELECT stats.worker_id, stats.stats, meta.queue_key, meta.metadata
-                        FROM {stats_table} stats
-                        LEFT JOIN {worker_metadata_table} meta ON meta.worker_id = stats.worker_id
-                        WHERE NOW() <= TO_TIMESTAMP(stats.expire_at)
-                        """
+                        SELECT worker_id, stats, queue_key, metadata
+                        FROM {stats_table}
+                        WHERE expire_at >= EXTRACT(EPOCH FROM NOW())
+                        AND queue_key = %(queue)s
+                        """
                     )
-                ).format(stats_table=self.stats_table, worker_metadata_table=self.metadata_table),
+                ).format(stats_table=self.stats_table),
+                {"queue": self.name},
             )
             results = await cursor.fetchall()
             workers: dict[str, WorkerInfo] = {
@@ -413,21 +385,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
                 {"now": now_ts},
             )

-            await cursor.execute(
-                SQL(
-                    dedent(
-                        """
-                        -- Delete expired worker metadata
-                        DELETE FROM {worker_metadata_table}
-                        WHERE %(now)s >= expire_at;
-                        """
-                    )
-                ).format(
-                    worker_metadata_table=self.metadata_table,
-                ),
-                {"now": now_ts},
-            )
-
             await cursor.execute(
                 SQL(
                     dedent(
@@ -778,23 +735,30 @@ async def _enqueue(self, job: Job) -> Job | None:
         logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
         return job

-    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
+    async def write_worker_info(
+        self,
+        worker_id: str,
+        info: WorkerInfo,
+        ttl: int,
+    ) -> None:
         async with self.pool.connection() as conn:
             await conn.execute(
                 SQL(
                     dedent(
                         """
-                        INSERT INTO {stats_table} (worker_id, stats, expire_at)
-                        VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
+                        INSERT INTO {stats_table} (worker_id, stats, queue_key, metadata, expire_at)
+                        VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
                         ON CONFLICT (worker_id) DO UPDATE
-                        SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
+                        SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
                         """
                     )
                 ).format(stats_table=self.stats_table),
                 {
                     "worker_id": worker_id,
-                    "stats": json.dumps(stats),
+                    "stats": json.dumps(info["stats"]),
                     "ttl": ttl,
+                    "queue_key": info["queue_key"],
+                    "metadata": json.dumps(info["metadata"]),
                 },
             )
diff --git a/saq/queue/postgres_migrations.py b/saq/queue/postgres_migrations.py
index e2c328e..f354892 100644
--- a/saq/queue/postgres_migrations.py
+++ b/saq/queue/postgres_migrations.py
@@ -7,7 +7,6 @@ def get_migrations(
     jobs_table: Identifier,
     stats_table: Identifier,
-    worker_metadata_table: Identifier,
 ) -> t.List[t.Tuple[int, t.List[Composed]]]:
     return [
         (
@@ -49,14 +48,12 @@
             [
                 SQL(
                     dedent("""
-                    CREATE TABLE IF NOT EXISTS {worker_metadata_table} (
-                        worker_id TEXT PRIMARY KEY,
-                        queue_key TEXT NOT NULL,
-                        expire_at BIGINT NOT NULL,
-                        metadata JSONB
-                    );
+                    ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS metadata JSONB;
+                    ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS queue_key TEXT;
+                    CREATE INDEX IF NOT EXISTS saq_stats_expire_at_idx ON {stats_table} (expire_at);
+                    CREATE INDEX IF NOT EXISTS saq_stats_queue_key_idx ON {stats_table} (queue_key);
                     """)
-                ).format(worker_metadata_table=worker_metadata_table),
+                ).format(stats_table=stats_table),
             ],
         ),
     ]
diff --git a/saq/queue/redis.py b/saq/queue/redis.py
index 883e1da..a0e13cb 100644
--- a/saq/queue/redis.py
+++ b/saq/queue/redis.py
@@ -38,7 +38,6 @@
     ListenCallback,
     LoadType,
     QueueInfo,
-    WorkerStats,
     VersionTuple,
     WorkerInfo,
 )
@@ -131,30 +130,15 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
             *_, worker_uuid = key_str.split(":")
             worker_uuids.append(worker_uuid)

-        worker_stats = await self.redis.mget(
-            self.namespace(f"stats:{worker_uuid}") for worker_uuid in worker_uuids
-        )
         worker_metadata = await self.redis.mget(
-            self.namespace(f"metadata:{worker_uuid}") for worker_uuid in worker_uuids
+            self.namespace(f"worker_info:{worker_uuid}") for worker_uuid in worker_uuids
         )

         workers: dict[str, WorkerInfo] = {}
         worker_metadata_dict = dict(zip(worker_uuids, worker_metadata))
-        worker_stats_dict = dict(zip(worker_uuids, worker_stats))
         for worker in worker_uuids:
-            workers[worker] = {
-                "queue_key": None,
-                "stats": None,
-                "metadata": None,
-            }
-            metadata = worker_metadata_dict[worker]
+            metadata = worker_metadata_dict.get(worker)
             if metadata:
-                metadata_obj = json.loads(metadata)
-                workers[worker]["metadata"] = metadata_obj["metadata"]
-                workers[worker]["queue_key"] = metadata_obj["queue_key"]
-            stats = worker_stats_dict[worker]
-            if stats:
-                stats_obj = json.loads(stats)
-                workers[worker]["stats"] = stats_obj
+                workers[worker] = json.loads(metadata)

         queued = await self.count("queued")
         active = await self.count("active")
@@ -405,26 +389,23 @@ async def finish_abort(self, job: Job) -> None:
         await self.redis.delete(job.abort_id)
         await super().finish_abort(job)

-    async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
+    async def write_worker_info(
+        self,
+        worker_id: str,
+        info: WorkerInfo,
+        ttl: int,
+    ) -> None:
         current = now()
         async with self.redis.pipeline(transaction=True) as pipe:
-            key = self.namespace(f"stats:{worker_id}")
+            key = self.namespace(f"worker_info:{worker_id}")
             await (
-                pipe.setex(key, ttl, json.dumps(stats))
+                pipe.setex(key, ttl, json.dumps(info))
                 .zremrangebyscore(self._stats, 0, current)
                 .zadd(self._stats, {key: current + millis(ttl)})
                 .expire(self._stats, ttl)
                 .execute()
             )

-    async def write_worker_metadata(
-        self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
-    ) -> None:
-        async with self.redis.pipeline(transaction=True) as pipe:
-            key = self.namespace(f"metadata:{worker_id}")
-            metadata = {"queue_key": queue_key, "metadata": metadata}
-            await pipe.setex(key, ttl, json.dumps(metadata)).expire(key, ttl).execute()
-
     async def _retry(self, job: Job, error: str | None) -> None:
         job_id = job.id
         next_retry_delay = job.next_retry_delay()
a/saq/types.py b/saq/types.py index e3e75e5..9b4e55e 100644 --- a/saq/types.py +++ b/saq/types.py @@ -97,14 +97,12 @@ class TimersDict(t.TypedDict): schedule: int "How often we poll to schedule jobs in seconds (default 1)" - stats: int - "How often to update stats in seconds (default 10)" + worker_info: int + "How often to update worker info, stats and metadata in seconds (default 10)" sweep: int "How often to clean up stuck jobs in seconds (default 60)" abort: int "How often to check if a job is aborted in seconds (default 1)" - metadata: int - "How often to write worker metadata in seconds (default 60)" class PartialTimersDict(TimersDict, total=False): diff --git a/saq/worker.py b/saq/worker.py index e5e1841..3f8a658 100644 --- a/saq/worker.py +++ b/saq/worker.py @@ -36,7 +36,7 @@ ReceivesContext, SettingsDict, TimersDict, - WorkerStats, + WorkerInfo, ) @@ -62,7 +62,7 @@ class Worker: after_process: async function to call after a job processes timers: dict with various timer overrides in seconds schedule: how often we poll to schedule jobs - stats: how often to update stats + worker_info: how often to update worker info, stats and metadata sweep: how often to clean up stuck jobs abort: how often to check if a job is aborted dequeue_timeout: how long it will wait to dequeue @@ -104,10 +104,9 @@ def __init__( ) self.timers: TimersDict = { "schedule": 1, - "stats": 10, + "worker_info": 10, "sweep": 60, "abort": 1, - "metadata": 60, } if timers is not None: self.timers.update(timers) @@ -224,17 +223,11 @@ async def schedule(self, lock: int = 1) -> None: if scheduled: logger.info("Scheduled %s", scheduled) - async def metadata(self, ttl: int = 60) -> None: - await self.queue.write_worker_metadata( - queue_key=self.queue.name, - metadata=self._metadata, - ttl=ttl, - worker_id=self.id, + async def worker_info(self, ttl: int = 60) -> WorkerInfo: + return await self.queue.worker_info( + self.id, queue_key=self.queue.name, metadata=self._metadata, ttl=ttl ) - async def stats(self, ttl: int = 60) -> WorkerStats: - return await self.queue.stats(self.id, ttl) - async def upkeep(self) -> list[Task[None]]: """Start various upkeep tasks async.""" @@ -255,8 +248,13 @@ async def poll( asyncio.create_task(poll(self.abort, self.timers["abort"])), asyncio.create_task(poll(self.schedule, self.timers["schedule"])), asyncio.create_task(poll(self.queue.sweep, self.timers["sweep"])), - asyncio.create_task(poll(self.stats, self.timers["stats"], self.timers["stats"] + 1)), - asyncio.create_task(poll(self.metadata, self.timers["metadata"])), + asyncio.create_task( + poll( + self.worker_info, + self.timers["worker_info"], + self.timers["worker_info"] + 1, + ) + ), ] async def abort(self, abort_threshold: float) -> None: diff --git a/tests/test_http_stats.py b/tests/test_http_stats.py index 2e39e18..76d9912 100644 --- a/tests/test_http_stats.py +++ b/tests/test_http_stats.py @@ -48,17 +48,17 @@ async def test_http_proxy_with_two_workers(self) -> None: queue=queue1, functions=[echo], ) - await worker.stats() + await worker.worker_info() worker2 = Worker( queue=queue2, functions=[echo], ) - await worker2.stats() + await worker2.worker_info() local_worker = Worker( queue=self.queue, functions=[echo], ) - await local_worker.stats() + await local_worker.worker_info() root_info = await self.queue.info() info1 = await queue1.info() diff --git a/tests/test_migrations.py b/tests/test_migrations.py index 86263cc..aef7ca8 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -10,7 +10,6 @@ def 
@@ -10,7 +10,6 @@ def test_migration_versions_are_increasing(self) -> None:
         migrations = get_migrations(
             Identifier("jobs_table"),
             Identifier("stats_table"),
-            Identifier("worker_metadata_table"),
         )
         versions = [migration[0] for migration in migrations]
         assert versions == sorted(versions)
@@ -20,7 +19,6 @@ def test_highest_migration_equal_or_smaller_than_length(self) -> None:
         migrations = get_migrations(
             Identifier("jobs_table"),
             Identifier("stats_table"),
-            Identifier("worker_metadata_table"),
         )
         versions = [migration[0] for migration in migrations]
         assert max(versions) <= len(migrations)
diff --git a/tests/test_queue.py b/tests/test_queue.py
index 396c989..4278b87 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -201,12 +201,12 @@ async def test_stats(self) -> None:
             await job.finish(Status.ABORTED)
             await job.finish(Status.FAILED)
             await job.finish(Status.COMPLETE)
-        stats = await worker.stats()
-        self.assertEqual(stats["complete"], 10)
-        self.assertEqual(stats["failed"], 10)
-        self.assertEqual(stats["retried"], 10)
-        self.assertEqual(stats["aborted"], 10)
-        self.assertGreater(stats["uptime"], 0)
+        worker_info = await worker.worker_info()
+        self.assertEqual(worker_info["stats"]["complete"], 10)
+        self.assertEqual(worker_info["stats"]["failed"], 10)
+        self.assertEqual(worker_info["stats"]["retried"], 10)
+        self.assertEqual(worker_info["stats"]["aborted"], 10)
+        self.assertGreater(worker_info["stats"]["uptime"], 0)

     async def test_info(self) -> None:
         queue2 = await self.create_queue(name=self.queue.name)
@@ -226,8 +226,7 @@ async def test_info(self) -> None:
         await self.enqueue("echo", a=1)
         await queue2.enqueue("echo", a=1)
         await worker.process()
-        await worker.stats()
-        await worker.metadata(3)
+        await worker.worker_info(3)

         info = await self.queue.info(jobs=True)
         self.assertEqual(set(info["workers"].keys()), {worker.id})
@@ -239,9 +238,11 @@ async def test_info(self) -> None:
         self.assertEqual(len(info["jobs"]), 1)

         time.sleep(4)
+        info = await self.queue.info(jobs=True)
+        self.assertEqual(info["workers"], {})
         await worker.queue.sweep()
         info = await self.queue.info(jobs=True)
-        self.assertEqual(info["workers"][worker.id]["metadata"], None)
+        self.assertEqual(info["workers"], {})

     @mock.patch("saq.utils.time")
     async def test_schedule(self, mock_time: MagicMock) -> None:
@@ -595,7 +596,7 @@ async def test_sweep_jobs(self) -> None:
     async def test_sweep_stats(self) -> None:
         worker = Worker(self.queue, functions=functions)
         # Stats are deleted
-        await worker.stats(ttl=1)
+        await worker.worker_info(ttl=1)
         await asyncio.sleep(1.5)
         await self.queue.sweep()
         async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
@@ -612,7 +613,7 @@ async def test_sweep_stats(self) -> None:
         self.assertIsNone(await cursor.fetchone())

         # Stats are not deleted
-        await worker.stats(ttl=60)
+        await worker.worker_info(ttl=60)
         await asyncio.sleep(1)
         await self.queue.sweep()
         async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
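Usage sketch (illustrative only, not part of the patch): the diff collapses the old write_stats / write_worker_metadata pair into a single write_worker_info upsert keyed by worker id, with one TTL covering the whole record. The minimal model below mirrors that shape; InMemoryInfoStore is a hypothetical stand-in for the Redis/Postgres/HTTP backends, and the TypedDicts approximate those in saq.types.

    import time
    import typing as t


    class WorkerStats(t.TypedDict):
        complete: int
        failed: int
        retried: int
        aborted: int
        uptime: float


    class WorkerInfo(t.TypedDict):
        stats: t.Optional[WorkerStats]
        queue_key: t.Optional[str]
        metadata: t.Optional[dict]


    class InMemoryInfoStore:
        """Hypothetical stand-in for a queue backend: one record per worker, expiring after a TTL."""

        def __init__(self) -> None:
            self._records: t.Dict[str, t.Tuple[WorkerInfo, float]] = {}

        def write_worker_info(self, worker_id: str, info: WorkerInfo, ttl: int) -> None:
            # One upsert carries stats, queue_key, and metadata together, so a worker's
            # stats and metadata can no longer expire on different schedules.
            self._records[worker_id] = (info, time.monotonic() + ttl)

        def workers(self) -> t.Dict[str, WorkerInfo]:
            # Expired records simply drop out, matching the new expire_at filter in info().
            now = time.monotonic()
            return {wid: info for wid, (info, expires) in self._records.items() if expires > now}


    store = InMemoryInfoStore()
    store.write_worker_info(
        "worker-1",
        {
            "stats": {"complete": 10, "failed": 0, "retried": 1, "aborted": 0, "uptime": 12.5},
            "queue_key": "default",
            "metadata": {"host": "app-1"},
        },
        ttl=60,
    )
    assert "worker-1" in store.workers()

Because a single TTL now governs both stats and metadata, the separate saq_worker_metadata table, its sweep step, and the extra metadata timer all become unnecessary, which is exactly what the diff removes.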