From 743a9bf9fc6b8523c68013725327ac65608bcceb Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 29 Sep 2024 17:07:12 -0500 Subject: [PATCH 01/13] feat: add `asyncpg` adapter and tests --- saq/queue/base.py | 4 +- saq/queue/postgres_asyncpg.py | 681 ++++++++++++++++++++++++++++++++++ setup.py | 3 + tests/helpers.py | 29 ++ tests/test_queue.py | 299 ++++++++++++++- 5 files changed, 1014 insertions(+), 2 deletions(-) create mode 100644 saq/queue/postgres_asyncpg.py diff --git a/saq/queue/base.py b/saq/queue/base.py index 33948d5..f01db1e 100644 --- a/saq/queue/base.py +++ b/saq/queue/base.py @@ -156,12 +156,14 @@ def from_url(url: str, **kwargs: t.Any) -> Queue: from saq.queue.redis import RedisQueue return RedisQueue.from_url(url, **kwargs) + if url.startswith("postgres+asyncpg"): + from saq.queue.postgres_asyncpg import PostgresQueue + return PostgresQueue.from_url(url.replace("postgres+asyncpg", "postgres"), **kwargs) if url.startswith("postgres"): from saq.queue.postgres import PostgresQueue return PostgresQueue.from_url(url, **kwargs) - from saq.queue.http import HttpQueue return HttpQueue.from_url(url, **kwargs) diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py new file mode 100644 index 0000000..43b4a80 --- /dev/null +++ b/saq/queue/postgres_asyncpg.py @@ -0,0 +1,681 @@ +""" +Postgres Queue using asyncpg +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import math +import time +import typing as t +from contextlib import asynccontextmanager +from textwrap import dedent + +from saq.errors import MissingDependencyError +from saq.job import ( + Job, + Status, +) +from saq.queue.base import Queue, logger +from saq.queue.postgres_ddl import DDL_STATEMENTS +from saq.utils import now, seconds + +if t.TYPE_CHECKING: + from collections.abc import Iterable + from asyncpg.pool import PoolConnectionProxy + from saq.types import ( + CountKind, + ListenCallback, + DumpType, + LoadType, + QueueInfo, + QueueStats, + ) + +try: + import asyncpg + from asyncpg import Connection, Pool +except ModuleNotFoundError as e: + raise MissingDependencyError( + "Missing dependencies for Postgres. Install them with `pip install saq[asyncpg]`." + ) from e + +ENQUEUE_CHANNEL = "saq:enqueue" +JOBS_TABLE = "saq_jobs" +STATS_TABLE = "saq_stats" + + +class PostgresQueue(Queue): + """ + Queue is used to interact with Postgres using asyncpg. + + Args: + pool: instance of asyncpg.Pool + name: name of the queue (default "default") + jobs_table: name of the Postgres table SAQ will write jobs to (default "saq_jobs") + stats_table: name of the Postgres table SAQ will write stats to (default "saq_stats") + dump: lambda that takes a dictionary and outputs bytes (default `json.dumps`) + load: lambda that takes str or bytes and outputs a python dictionary (default `json.loads`) + min_size: minimum pool size. (default 4) + The minimum number of Postgres connections. + max_size: maximum pool size. (default 20) + If greater than 0, this limits the maximum number of connections to Postgres. + Otherwise, maintain `min_size` number of connections. + poll_interval: how often to poll for jobs. (default 1) + If 0, the queue will not poll for jobs and will only rely on notifications from the server. + This mean cron jobs will not be picked up in a timely fashion. + saq_lock_keyspace: The first of two advisory lock keys used by SAQ. (default 0) + SAQ uses advisory locks for coordinating tasks between its workers, e.g. sweeping. 
+ job_lock_keyspace: The first of two advisory lock keys used for jobs. (default 1) + """ + + @classmethod + def from_url(cls: type[PostgresQueue], url: str, **kwargs: t.Any) -> PostgresQueue: + """Create a queue from a postgres url.""" + pool = asyncpg.create_pool(dsn=url, **kwargs) + return cls(pool, **kwargs) + + def __init__( + self, + pool: Pool, + name: str = "default", + jobs_table: str = JOBS_TABLE, + stats_table: str = STATS_TABLE, + dump: DumpType | None = None, + load: LoadType | None = None, + min_size: int = 4, + max_size: int = 20, + poll_interval: int = 1, + saq_lock_keyspace: int = 0, + job_lock_keyspace: int = 1, + ) -> None: + super().__init__(name=name, dump=dump, load=load) + + self.jobs_table = jobs_table + self.stats_table = stats_table + self.pool = pool + self.min_size = min_size + self.max_size = max_size + self.poll_interval = poll_interval + self.saq_lock_keyspace = saq_lock_keyspace + self.job_lock_keyspace = job_lock_keyspace + + self.cond = asyncio.Condition() + self.queue: asyncio.Queue = asyncio.Queue() + self.waiting = 0 # Internal counter of worker tasks waiting for dequeue + self.connection: PoolConnectionProxy | None = None + self.connection_lock = asyncio.Lock() + self.released: list[str] = [] + self.has_sweep_lock = False + + async def init_db(self) -> None: + async with self.pool.acquire() as conn: + for statement in DDL_STATEMENTS: + await conn.execute(statement.format(jobs_table=self.jobs_table, stats_table=self.stats_table)) + + async def upkeep(self) -> None: + await self.init_db() + + self.tasks.add(asyncio.create_task(self.wait_for_job())) + self.tasks.add(asyncio.create_task(self.listen_for_enqueues())) + if self.poll_interval > 0: + self.tasks.add(asyncio.create_task(self.dequeue_timer(self.poll_interval))) + + async def connect(self) -> None: + await self.pool # ensure the pool is created sync `from_url` does not await the `create_pool` return` + if self.connection: + # If connection exists, connect() was already called + return + + # Reserve a connection for dequeue and advisory locks + self.connection = await self.pool.acquire() + + def serialize(self, job: Job) -> bytes | str: + """Ensure serialized job is in bytes because the job column is of type BYTEA.""" + serialized = self._dump(job.to_dict()) + if isinstance(serialized, str): + return serialized.encode("utf-8") + return serialized + + async def disconnect(self) -> None: + async with self.connection_lock: + if self.connection: + await self.pool.release(self.connection) + self.connection = None + await self.pool.close() + self.has_sweep_lock = False + + async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: + async with self.pool.acquire() as conn: + results = await conn.fetch( + f""" + SELECT worker_id, stats FROM {self.stats_table} + WHERE $1 <= expire_at + """, + seconds(now()), + ) + workers: dict[str, dict[str, t.Any]] = {row['worker_id']: json.loads(row['stats']) for row in results} + + queued = await self.count("queued") + active = await self.count("active") + incomplete = await self.count("incomplete") + + if jobs: + async with self.pool.acquire() as conn: + results = await conn.fetch( + f""" + SELECT job FROM {self.jobs_table} + WHERE status IN ('new', 'deferred', 'queued', 'active') + """ + ) + deserialized_jobs = (self.deserialize(result['job']) for result in results) + jobs_info = [job.to_dict() for job in deserialized_jobs if job] + else: + jobs_info = [] + + return { + "workers": workers, + "name": self.name, + "queued": queued, + 
"active": active, + "scheduled": incomplete - queued - active, + "jobs": jobs_info, + } + + async def count(self, kind: CountKind) -> int: + async with self.pool.acquire() as conn: + if kind == "queued": + result = await conn.fetchval( + f""" + SELECT count(*) FROM {self.jobs_table} + WHERE status = 'queued' + AND queue = $1 + AND $2 >= scheduled + """, + self.name, + math.ceil(seconds(now())), + ) + elif kind == "active": + result = await conn.fetchval( + f""" + SELECT count(*) FROM {self.jobs_table} + WHERE status = 'active' + AND queue = $1 + """, + self.name, + ) + elif kind == "incomplete": + result = await conn.fetchval( + f""" + SELECT count(*) FROM {self.jobs_table} + WHERE status IN ('new', 'deferred', 'queued', 'active') + AND queue = $1 + """, + self.name, + ) + else: + raise ValueError(f"Can't count unknown type {kind}") + + return result + + async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: + """Delete jobs and stats past their expiration and sweep stuck jobs""" + swept = [] + + if not self.has_sweep_lock: + # Attempt to get the sweep lock and hold on to it + async with self._get_connection() as conn: + result = await conn.fetchval( + "SELECT pg_try_advisory_lock($1, hashtext($2))", + self.saq_lock_keyspace, + self.name, + ) + if not result: + # Could not acquire the sweep lock so another worker must already have it + return [] + self.has_sweep_lock = True + + async with self.pool.acquire() as conn: + await conn.execute( + f""" + -- Delete expired jobs + DELETE FROM {self.jobs_table} + WHERE queue = $1 + AND status IN ('aborted', 'complete', 'failed') + AND $2 >= expire_at; + """, + self.name, + math.ceil(seconds(now())), + ) + + await conn.execute( + f""" + -- Delete expired stats + DELETE FROM {self.stats_table} + WHERE $1 >= expire_at; + """, + math.ceil(seconds(now())), + ) + + results = await conn.fetch( + f""" + -- Fetch active and aborting jobs without advisory locks + WITH locks AS ( + SELECT objid + FROM pg_locks + WHERE locktype = 'advisory' + AND classid = $1 + AND objsubid = 2 -- key is int pair, not single bigint + ) + SELECT key, job, objid + FROM {self.jobs_table} LEFT OUTER JOIN locks ON lock_key = objid + WHERE queue = $2 + AND status IN ('active', 'aborting'); + """, + self.job_lock_keyspace, + self.name, + ) + for row in results: + key, job_bytes, objid = row['key'], row['job'], row['objid'] + job = self.deserialize(job_bytes) + assert job + if objid and not job.stuck: + continue + + swept.append(key) + await self.abort(job, error="swept") + + try: + await job.refresh(abort) + except asyncio.TimeoutError: + logger.info("Could not abort job %s", key) + + logger.info("Sweeping job %s", job.info(logger.isEnabledFor(logging.DEBUG))) + if job.retryable: + await self.retry(job, error="swept") + else: + await self.finish(job, Status.ABORTED, error="swept") + return swept + + async def listen( + self, + job_keys: Iterable[str], + callback: ListenCallback, + timeout: float | None = 10, + ) -> None: + if not job_keys: + return + + async def _listen(conn: PoolConnectionProxy | Connection, pid: int, channel: str, payload: str) -> None: + payload_data = json.loads(payload) + key = payload_data["key"] + status = Status[payload_data["status"].upper()] + if asyncio.iscoroutinefunction(callback): + _ = await callback(key, status) + else: + _ = callback(key, status) + + async with self.pool.acquire() as conn: + for key in job_keys: + await conn.add_listener(key, _listen) + + if timeout: + try: + await asyncio.sleep(timeout) + finally: + for key in 
job_keys: + await conn.remove_listener(key, _listen) + else: + # If no timeout, we'll keep listening indefinitely + while True: + await asyncio.sleep(30) # Sleep for 30 seconds and then check again + + async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) -> None: + payload = json.dumps({"key": job.key, "status": job.status}) + await self._notify(job.key, payload, connection) + + async def update( + self, + job: Job, + connection: PoolConnectionProxy | None = None, + expire_at: float | None = -1, + **kwargs: t.Any, + ) -> None: + job.touched = now() + + for k, v in kwargs.items(): + setattr(job, k, v) + + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + if expire_at != -1: + await conn.execute( + f""" + UPDATE {self.jobs_table} SET + job = $1, + status = $2, + scheduled = $3, + expire_at = $4 + WHERE key = $5 + """, + self.serialize(job), + job.status, + job.scheduled, + expire_at, + job.key, + ) + else: + await conn.execute( + f""" + UPDATE {self.jobs_table} SET + job = $1, + status = $2, + scheduled = $3 + WHERE key = $4 + """, + self.serialize(job), + job.status, + job.scheduled, + job.key, + ) + await self.notify(job, conn) + + async def job(self, job_key: str) -> Job | None: + async with self.pool.acquire() as conn: + result = await conn.fetchrow( + f""" + SELECT job + FROM {self.jobs_table} + WHERE key = $1 + """, + job_key, + ) + if result: + return self.deserialize(result['job']) + return None + + async def jobs(self, job_keys: Iterable[str]) -> t.List[Job | None]: + keys = list(job_keys) + + async with self.pool.acquire() as conn: + results = await conn.fetch( + f""" + SELECT key, job + FROM {self.jobs_table} + WHERE key = ANY($1) + """, + keys, + ) + job_dict = {row['key']: row['job'] for row in results} + return [self.deserialize(job_dict.get(key)) for key in keys] + + async def iter_jobs( + self, + statuses: t.List[Status] = list(Status), + batch_size: int = 100, + ) -> t.AsyncIterator[Job]: + async with self.pool.acquire() as conn: + last_key = "" + + while True: + results = await conn.fetch( + f""" + SELECT key, job + FROM {self.jobs_table} + WHERE + status = ANY($1) + AND queue = $2 + AND key > $3 + ORDER BY key + LIMIT $4 + """, + [status.value for status in statuses], + self.name, + last_key, + batch_size, + ) + + if not results: + break + + for row in results: + last_key = row['key'] + job = self.deserialize(row['job']) + if job: + yield job + + async def abort(self, job: Job, error: str, ttl: float = 5) -> None: + async with self.pool.acquire() as conn: + status = await self.get_job_status(job.key, for_update=True, connection=conn) + if status == Status.QUEUED: + await self.finish(job, Status.ABORTED, error=error, connection=conn) + else: + await self.update(job, status=Status.ABORTING, error=error, connection=conn) + + async def dequeue(self, timeout: float = 0) -> Job | None: + """Wait on `self.cond` to dequeue. + + Retries indefinitely until a job is available or times out. 
+ """ + self.waiting += 1 + async with self.cond: + self.cond.notify(1) + + try: + return await asyncio.wait_for(self.queue.get(), timeout or None) + except asyncio.exceptions.TimeoutError: + return None + finally: + self.waiting -= 1 + + async def wait_for_job(self) -> None: + while True: + async with self.cond: + await self.cond.wait() + + for job in await self._dequeue(): + await self.queue.put(job) + + async def _enqueue(self, job: Job) -> Job | None: + async with self.pool.acquire() as conn: + result = await conn.fetchrow( + f""" + INSERT INTO {self.jobs_table} (key, job, queue, status, scheduled) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (key) DO UPDATE + SET job = $2, queue = $3, status = $4, scheduled = $5, expire_at = null + WHERE {self.jobs_table}.status IN ('aborted', 'complete', 'failed') + AND $5 > {self.jobs_table}.scheduled + RETURNING job + """, + job.key, + self.serialize(job), + self.name, + job.status, + job.scheduled or seconds(now()), + ) + + if not result: + return None + + await self._notify(ENQUEUE_CHANNEL, job.key, conn) + + logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) + return job + + async def write_stats(self, stats: QueueStats, ttl: int) -> None: + async with self.pool.acquire() as conn: + await conn.execute( + f""" + INSERT INTO {self.stats_table} (worker_id, stats, expire_at) + VALUES ($1, $2, $3) + ON CONFLICT (worker_id) DO UPDATE + SET stats = $2, expire_at = $3 + """, + self.uuid, + json.dumps(stats), + seconds(now()) + ttl, + ) + + async def dequeue_timer(self, poll_interval: int) -> None: + """Wakes up a single dequeue task every `poll_interval` seconds.""" + while True: + async with self.cond: + self.cond.notify(1) + await asyncio.sleep(poll_interval) + + async def listen_for_enqueues(self, timeout: float | None = None) -> None: + """Wakes up a single dequeue task when a Postgres enqueue notification is received.""" + async def _listen(conn: PoolConnectionProxy | Connection, pid: int, channel: str, payload: str) -> None: + async with self.cond: + self.cond.notify(1) + async with self.pool.acquire() as conn: + await conn.add_listener(ENQUEUE_CHANNEL, _listen) + if timeout: + try: + await asyncio.sleep(timeout) + finally: + await conn.remove_listener(ENQUEUE_CHANNEL, _listen) + else: + # If no timeout, we'll keep listening indefinitely + while True: + await asyncio.sleep(30) # Sleep for 30 seconds and then check again + + async def get_job_status( + self, + key: str, + for_update: bool = False, + connection: PoolConnectionProxy | None = None, + ) -> Status: + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + result = await conn.fetchval( + f""" + SELECT status + FROM {self.jobs_table} + WHERE key = $1 + {"FOR UPDATE" if for_update else ""} + """, + key, + ) + assert result + return result + + async def _retry(self, job: Job, error: str | None) -> None: + next_retry_delay = job.next_retry_delay() + if next_retry_delay: + scheduled = time.time() + next_retry_delay + else: + scheduled = job.scheduled or seconds(now()) + + await self.update(job, scheduled=int(scheduled), expire_at=None) + + async def _finish( + self, + job: Job, + status: Status, + *, + result: t.Any = None, + error: str | None = None, + connection: PoolConnectionProxy | None = None, + ) -> None: + key = job.key + + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + if job.ttl >= 0: + expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None + await self.update(job, 
status=status, expire_at=expire_at, connection=conn) + else: + await conn.execute( + f""" + DELETE FROM {self.jobs_table} + WHERE key = $1 + """, + key, + ) + await self.notify(job, conn) + await self._release_job(key) + try: + self.queue.task_done() + except ValueError: + # Error because task_done() called too many times, which happens in unit tests + pass + + async def _dequeue(self) -> list[Job]: + if not self.waiting: + return [] + jobs = [] + async with self._get_connection() as conn: + async with conn.transaction(): + results = await conn.fetch( + f""" + WITH locked_job AS ( + SELECT key, lock_key + FROM {self.jobs_table} + WHERE status = 'queued' + AND queue = $1 + AND $2 >= scheduled + ORDER BY scheduled + LIMIT $3 + FOR UPDATE SKIP LOCKED + ) + UPDATE {self.jobs_table} SET status = 'active' + FROM locked_job + WHERE {self.jobs_table}.key = locked_job.key + AND pg_try_advisory_lock($4, locked_job.lock_key) + RETURNING job + """, + self.name, + math.ceil(seconds(now())), + self.waiting, + self.job_lock_keyspace, + ) + for result in results: + job = self.deserialize(result['job']) + if job: + await self.update(job, status=Status.ACTIVE, connection=conn) + jobs.append(job) + return jobs + + async def _notify( + self, channel: str, payload: t.Any, connection: PoolConnectionProxy | None = None + ) -> None: + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + await conn.execute(f"NOTIFY \"{channel}\", '{payload}'") + + @asynccontextmanager + async def _get_connection(self) -> t.AsyncGenerator: + assert self.connection + async with self.connection_lock: + try: + # Pool normally performs this check when getting a connection. + await self.connection.execute("SELECT 1") + except asyncpg.exceptions.ConnectionDoesNotExistError: + # The connection is bad so return it to the pool and get a new one. + await self.pool.release(self.connection) + self.connection = await self.pool.acquire() + yield self.connection + + @asynccontextmanager + async def nullcontext(self, enter_result: t.Any | None = None) -> t.AsyncGenerator: + """Async version of contextlib.nullcontext + + Async support has been added to contextlib.nullcontext in Python 3.10. 
+ """ + yield enter_result + + async def _release_job(self, key: str) -> None: + self.released.append(key) + if self.connection_lock.locked(): + return + async with self._get_connection() as conn: + await conn.execute( + f""" + SELECT pg_advisory_unlock($1, lock_key) + FROM {self.jobs_table} + WHERE key = ANY($2) + """, + self.job_lock_keyspace, + self.released, + ) + self.released.clear() \ No newline at end of file diff --git a/setup.py b/setup.py index 093645f..cd7b2a8 100644 --- a/setup.py +++ b/setup.py @@ -37,10 +37,12 @@ "hiredis": ["redis[hiredis]>=4.2.0"], "http": ["aiohttp"], "postgres": ["psycopg[pool]>=3.2.0"], + "asyncpg": ["asyncpg"], "redis": ["redis>=4.2,<6.0"], "web": ["aiohttp", "aiohttp_basicauth"], "dev": [ "aiohttp", + "asyncpg", "aiohttp_basicauth", "coverage", "mypy", @@ -50,6 +52,7 @@ "ruff", "types-croniter", "types-redis", + "asyncpg_stubs", "types-setuptools", "starlette", "httpx", diff --git a/tests/helpers.py b/tests/helpers.py index 2cad925..1fc2b18 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,10 +1,12 @@ import asyncio import typing as t +import asyncpg import psycopg from saq.queue import Queue from saq.queue.postgres import PostgresQueue +from saq.queue.postgres_asyncpg import PostgresQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue POSTGRES_TEST_SCHEMA = "test_saq" @@ -49,3 +51,30 @@ async def teardown_postgres() -> None: "postgres://postgres@localhost", autocommit=True ) as conn: await conn.execute(f"DROP SCHEMA {POSTGRES_TEST_SCHEMA} CASCADE") + + + + +async def create_postgres_asyncpg_queue(**kwargs: t.Any) -> AsyncpgPostgresQueue: + queue = t.cast( + AsyncpgPostgresQueue, + Queue.from_url( + f"postgres+asyncpg://postgres@localhost?options=--search_path%3D{POSTGRES_TEST_SCHEMA}", + **kwargs, + ), + ) + await queue.connect() + await queue.upkeep() + await asyncio.sleep(0.1) # Give some time for the tasks to start + return queue + +async def setup_postgres_asyncpg() -> None: + async with asyncpg.create_pool( + "postgres://postgres@localhost", min_size=1, max_size=10, command_timeout=60 + ) as pool: + await pool.execute(f"CREATE SCHEMA IF NOT EXISTS {POSTGRES_TEST_SCHEMA}") + + +async def teardown_postgres_asyncpg() -> None: + async with asyncpg.create_pool("postgres://postgres@localhost") as pool: + await pool.execute(f"DROP SCHEMA {POSTGRES_TEST_SCHEMA} CASCADE") diff --git a/tests/test_queue.py b/tests/test_queue.py index 48806bf..17e1cdc 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -19,7 +19,10 @@ create_postgres_queue, create_redis_queue, setup_postgres, + setup_postgres_asyncpg, teardown_postgres, + teardown_postgres_asyncpg, + create_postgres_asyncpg_queue ) @@ -27,6 +30,7 @@ from unittest.mock import MagicMock from saq.queue.postgres import PostgresQueue + from saq.queue.postgres_asyncpg import PostgresQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue from saq.types import Context, CountKind, Function @@ -88,7 +92,7 @@ async def test_enqueue_job_str(self) -> None: async def test_enqueue_dup(self) -> None: job = await self.enqueue("test", key="1") - self.assertEqual(job.id, "saq:job:default:1") + self.assertEqual(job.id, "saq:job:default$1") self.assertIsNone(await self.queue.enqueue("test", key="1")) self.assertIsNone(await self.queue.enqueue(job)) @@ -754,3 +758,296 @@ async def test_priority(self) -> None: assert await self.enqueue("test", priority=-1) self.assertEqual(await self.count("queued"), 1) assert not await self.queue.dequeue(0.01) + + +class 
TestAsyncpgPostgresQueue(TestQueue): + async def asyncSetUp(self) -> None: + await setup_postgres_asyncpg() + self.create_queue = create_postgres_asyncpg_queue + self.queue: AsyncpgPostgresQueue = await self.create_queue() + + async def asyncTearDown(self) -> None: + await super().asyncTearDown() + await teardown_postgres_asyncpg() + + @unittest.skip("Not implemented") + async def test_job_key(self) -> None: + pass + + @unittest.skip("Not implemented") + @mock.patch("saq.utils.time") + async def test_schedule(self, mock_time: MagicMock) -> None: + pass + + async def test_enqueue_dup(self) -> None: + job = await self.enqueue("test", key="1") + self.assertEqual(job.id, "1") + self.assertIsNone(await self.queue.enqueue("test", key="1")) + self.assertIsNone(await self.queue.enqueue(job)) + + async def test_abort(self) -> None: + job = await self.enqueue("test", retries=2) + self.assertEqual(await self.count("queued"), 1) + self.assertEqual(await self.count("incomplete"), 1) + await self.queue.abort(job, "test") + self.assertEqual(await self.count("queued"), 0) + self.assertEqual(await self.count("incomplete"), 0) + await job.refresh() + self.assertEqual(job.status, Status.ABORTED) + self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTED) + + job = await self.enqueue("test", retries=2) + await self.dequeue() + self.assertEqual(await self.count("queued"), 0) + self.assertEqual(await self.count("incomplete"), 1) + self.assertEqual(await self.count("active"), 1) + await self.queue.abort(job, "test") + self.assertEqual(await self.count("queued"), 0) + self.assertEqual(await self.count("incomplete"), 0) + self.assertEqual(await self.count("active"), 0) + self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTING) + + @mock.patch("saq.utils.time") + async def test_sweep(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 1 + job1 = await self.enqueue("test", heartbeat=1, retries=0) + job2 = await self.enqueue("test", timeout=1) + await self.enqueue("test", timeout=2) + await self.enqueue("test", heartbeat=2) + job3 = await self.enqueue("test", timeout=1) + for _ in range(4): + job = await self.dequeue() + job.status = Status.ACTIVE + job.started = 1000 + await self.queue.update(job) + await self.dequeue() + + mock_time.time.return_value = 3 + self.assertEqual(await self.count("active"), 5) + swept = await self.queue.sweep(abort=0.01) + self.assertEqual( + set(swept), + { + job1.key, + job2.key, + job3.key, + }, + ) + await job1.refresh() + await job2.refresh() + await job3.refresh() + self.assertEqual(job1.status, Status.ABORTED) + self.assertEqual(job2.status, Status.QUEUED) + self.assertEqual(job3.status, Status.QUEUED) + self.assertEqual(await self.count("active"), 2) + + @mock.patch("saq.utils.time") + async def test_sweep_stuck(self, mock_time: MagicMock) -> None: + job1 = await self.queue.enqueue("test") + assert job1 + job = await self.dequeue() + job.status = Status.ACTIVE + job.started = 1000 + await self.queue.update(job) + + # Enqueue 2 more jobs that will become stuck + job2 = await self.queue.enqueue("test", retries=0) + assert job2 + job3 = await self.queue.enqueue("test") + assert job3 + + another_queue = await self.create_queue() + for _ in range(2): + job = await another_queue.dequeue() + job.status = Status.ACTIVE + job.started = 1000 + await another_queue.update(job) + + # Disconnect another_queue to simulate worker going down + await another_queue.disconnect() + + mock_time.time.return_value = 3 + self.assertEqual(await 
self.count("active"), 3) + swept = await self.queue.sweep(abort=0.01) + self.assertEqual( + set(swept), + { + job2.id, + job3.id, + }, + ) + await job1.refresh() + await job2.refresh() + await job3.refresh() + self.assertEqual(job1.status, Status.ACTIVE) + self.assertEqual(job2.status, Status.ABORTED) + self.assertEqual(job3.status, Status.QUEUED) + self.assertEqual(await self.count("active"), 1) + + async def test_sweep_jobs(self) -> None: + job1 = await self.enqueue("test", ttl=1) + job2 = await self.enqueue("test", ttl=60) + await self.queue.finish(job1, Status.COMPLETE) + await self.queue.finish(job2, Status.COMPLETE) + await asyncio.sleep(1) + + await self.queue.sweep() + with self.assertRaisesRegex(RuntimeError, "doesn't exist"): + await job1.refresh() + await job2.refresh() + self.assertEqual(job2.status, Status.COMPLETE) + + async def test_sweep_stats(self) -> None: + # Stats are deleted + await self.queue.stats(ttl=1) + await asyncio.sleep(1) + await self.queue.sweep() + async with self.queue.pool.acquire() as conn: + result = await conn.fetchrow( + + """ + SELECT stats + FROM {} + WHERE worker_id = %s + """.format(self.queue.stats_table), + (self.queue.uuid,), + ) + self.assertIsNone(result) + + # Stats are not deleted + await self.queue.stats(ttl=60) + await asyncio.sleep(1) + await self.queue.sweep() + async with self.queue.pool.acquire() as conn : + result = await conn.fetchrow( + + """ + SELECT stats + FROM {} + WHERE worker_id = %s + """.format(self.queue.stats_table), + (self.queue.uuid,), + ) + self.assertIsNotNone(result) + + async def test_job_lock(self) -> None: + query = """ + SELECT count(*) + FROM {} JOIN pg_locks ON lock_key = objid + WHERE key = $key + AND classid = {} + AND objsubid = 2 -- key is int pair, not single bigint + """.format(self.queue.jobs_table, self.queue.job_lock_keyspace) + job = await self.enqueue("test") + await self.dequeue() + async with self.queue.pool.acquire() as conn : + result = await conn.fetchval(query, {"key": job.key}) + self.assertEqual(result, 1) + + await self.finish(job, Status.COMPLETE, result=1) + async with self.queue.pool.acquire() as conn : + result = await conn.execute(query, {"key": job.key}) + self.assertEqual(result, (0,)) + + async def test_load_dump_pickle(self) -> None: + self.queue = await self.create_queue(dump=pickle.dumps, load=pickle.loads) + job = await self.enqueue("test") + + async with self.queue.pool.acquire() as conn : + result = await conn.fetchrow( + """ + SELECT job + FROM {} + WHERE key =$1 + """ .format(self.queue.jobs_table), + job.key, + ) + assert result + fetched_job = pickle.loads(result[0]) + self.assertIsInstance(fetched_job, dict) + self.assertEqual(fetched_job["key"], job.key) + + dequeued_job = await self.dequeue() + self.assertEqual(dequeued_job, job) + + @mock.patch("saq.utils.time") + async def test_finish_ttl_positive(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 0 + job = await self.enqueue("test", ttl=5) + await self.dequeue() + await self.finish(job, Status.COMPLETE) + async with self.queue.pool.acquire() as conn : + result = await conn.fetchval( + + """ + SELECT expire_at + FROM {} + WHERE key = $1 + """ .format(self.queue.jobs_table), + job.key, + ) + self.assertEqual(result,5) + + @mock.patch("saq.utils.time") + async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 0 + job = await self.enqueue("test", ttl=0) + await self.dequeue() + await self.finish(job, Status.COMPLETE) + async with self.queue.pool.acquire() as 
conn : + result = await conn.fetchval( + + """ + SELECT expire_at + FROM {} + WHERE key = $1 + """ .format(self.queue.jobs_table), + job.key, + ) + self.assertEqual(result,None) + + @mock.patch("saq.utils.time") + async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 0 + job = await self.enqueue("test", ttl=-1) + await self.dequeue() + await self.finish(job, Status.COMPLETE) + async with self.queue.pool.acquire() as conn : + result = await conn.fetchval( + """ + SELECT expire_at + FROM {} + WHERE key = $1 + """ .format(self.queue.jobs_table), + job.key, + ) + self.assertIsNone(result) + + async def test_bad_connection(self) -> None: + job = await self.enqueue("test") + original_connection = self.queue.connection + await self.queue.connection.close() + # Test dequeue still works + self.assertEqual((await self.dequeue()), job) + # Check queue has a new connection + self.assertNotEqual(original_connection, self.queue.connection) + + + async def test_group_key(self) -> None: + job1 = await self.enqueue("test", group_key=1) + assert job1 + job2 = await self.enqueue("test", group_key=1) + assert job2 + self.assertEqual(await self.count("queued"), 2) + + assert await self.dequeue() + self.assertEqual(await self.count("queued"), 1) + assert not await self.queue.dequeue(0.01) + await job1.update(status="finished") + assert await self.dequeue() + + async def test_priority(self) -> None: + assert await self.enqueue("test", priority=-1) + self.assertEqual(await self.count("queued"), 1) + assert not await self.queue.dequeue(0.01) + From b135fc33acc0793cc7482c579cef0af5c2509928 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 6 Oct 2024 10:25:00 -0500 Subject: [PATCH 02/13] feat: merge latest changes --- saq/queue/base.py | 5 +- saq/queue/postgres_asyncpg.py | 580 ++++++++++++++-------------------- tests/helpers.py | 5 +- tests/test_queue.py | 2 +- 4 files changed, 235 insertions(+), 357 deletions(-) diff --git a/saq/queue/base.py b/saq/queue/base.py index f01db1e..542cb03 100644 --- a/saq/queue/base.py +++ b/saq/queue/base.py @@ -157,9 +157,8 @@ def from_url(url: str, **kwargs: t.Any) -> Queue: return RedisQueue.from_url(url, **kwargs) if url.startswith("postgres+asyncpg"): - from saq.queue.postgres_asyncpg import PostgresQueue - - return PostgresQueue.from_url(url.replace("postgres+asyncpg", "postgres"), **kwargs) + from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue + return PostgresAsyncpgQueue.from_url(url.replace("postgres+asyncpg", "postgres"), **kwargs) if url.startswith("postgres"): from saq.queue.postgres import PostgresQueue diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py index 43b4a80..436dda2 100644 --- a/saq/queue/postgres_asyncpg.py +++ b/saq/queue/postgres_asyncpg.py @@ -14,72 +14,45 @@ from textwrap import dedent from saq.errors import MissingDependencyError -from saq.job import ( - Job, - Status, -) +from saq.job import Job, Status +from saq.multiplexer import Multiplexer from saq.queue.base import Queue, logger from saq.queue.postgres_ddl import DDL_STATEMENTS from saq.utils import now, seconds if t.TYPE_CHECKING: from collections.abc import Iterable - from asyncpg.pool import PoolConnectionProxy - from saq.types import ( - CountKind, - ListenCallback, - DumpType, - LoadType, - QueueInfo, - QueueStats, - ) - + from saq.types import CountKind, DumpType, LoadType, QueueInfo, QueueStats try: - import asyncpg - from asyncpg import Connection, Pool + from asyncpg.pool import PoolConnectionProxy + 
from asyncpg import Pool, create_pool + from asyncpg.exceptions import ConnectionDoesNotExistError + except ModuleNotFoundError as e: raise MissingDependencyError( "Missing dependencies for Postgres. Install them with `pip install saq[asyncpg]`." ) from e -ENQUEUE_CHANNEL = "saq:enqueue" +CHANNEL = "saq:{}" +ENQUEUE = "saq:enqueue" JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" -class PostgresQueue(Queue): +class PostgresAsyncpgQueue(Queue): """ Queue is used to interact with Postgres using asyncpg. - - Args: - pool: instance of asyncpg.Pool - name: name of the queue (default "default") - jobs_table: name of the Postgres table SAQ will write jobs to (default "saq_jobs") - stats_table: name of the Postgres table SAQ will write stats to (default "saq_stats") - dump: lambda that takes a dictionary and outputs bytes (default `json.dumps`) - load: lambda that takes str or bytes and outputs a python dictionary (default `json.loads`) - min_size: minimum pool size. (default 4) - The minimum number of Postgres connections. - max_size: maximum pool size. (default 20) - If greater than 0, this limits the maximum number of connections to Postgres. - Otherwise, maintain `min_size` number of connections. - poll_interval: how often to poll for jobs. (default 1) - If 0, the queue will not poll for jobs and will only rely on notifications from the server. - This mean cron jobs will not be picked up in a timely fashion. - saq_lock_keyspace: The first of two advisory lock keys used by SAQ. (default 0) - SAQ uses advisory locks for coordinating tasks between its workers, e.g. sweeping. - job_lock_keyspace: The first of two advisory lock keys used for jobs. (default 1) """ @classmethod - def from_url(cls: type[PostgresQueue], url: str, **kwargs: t.Any) -> PostgresQueue: + def from_url(cls: type[PostgresAsyncpgQueue], url: str, **kwargs: t.Any) -> PostgresAsyncpgQueue: # pyright: ignore[reportIncompatibleMethodOverride] """Create a queue from a postgres url.""" - pool = asyncpg.create_pool(dsn=url, **kwargs) - return cls(pool, **kwargs) + pool = create_pool(dsn=url, **kwargs, ) + return cls(t.cast("Pool[t.Any]", pool), **kwargs) def __init__( self, - pool: Pool, + pool: Pool[t.Any], name: str = "default", jobs_table: str = JOBS_TABLE, stats_table: str = STATS_TABLE, @@ -102,50 +75,42 @@ def __init__( self.saq_lock_keyspace = saq_lock_keyspace self.job_lock_keyspace = job_lock_keyspace - self.cond = asyncio.Condition() - self.queue: asyncio.Queue = asyncio.Queue() - self.waiting = 0 # Internal counter of worker tasks waiting for dequeue - self.connection: PoolConnectionProxy | None = None - self.connection_lock = asyncio.Lock() - self.released: list[str] = [] - self.has_sweep_lock = False + self._job_queue: asyncio.Queue = asyncio.Queue() + self._waiting = 0 + self._dequeue_conn: PoolConnectionProxy | None = None + self._connection_lock = asyncio.Lock() + self._releasing: list[str] = [] + self._has_sweep_lock = False + self._channel = CHANNEL.format(self.name) + self._listener = ListenMultiplexer(self.pool, self._channel) + self._dequeue_lock = asyncio.Lock() + self._listen_lock = asyncio.Lock() async def init_db(self) -> None: async with self.pool.acquire() as conn: for statement in DDL_STATEMENTS: await conn.execute(statement.format(jobs_table=self.jobs_table, stats_table=self.stats_table)) - async def upkeep(self) -> None: - await self.init_db() - - self.tasks.add(asyncio.create_task(self.wait_for_job())) - self.tasks.add(asyncio.create_task(self.listen_for_enqueues())) - if self.poll_interval > 0: - 
self.tasks.add(asyncio.create_task(self.dequeue_timer(self.poll_interval))) - async def connect(self) -> None: - await self.pool # ensure the pool is created sync `from_url` does not await the `create_pool` return` - if self.connection: - # If connection exists, connect() was already called + if self._dequeue_conn: return - - # Reserve a connection for dequeue and advisory locks - self.connection = await self.pool.acquire() + await self.pool + self._dequeue_conn = await self.pool.acquire() + await self.init_db() def serialize(self, job: Job) -> bytes | str: - """Ensure serialized job is in bytes because the job column is of type BYTEA.""" serialized = self._dump(job.to_dict()) if isinstance(serialized, str): return serialized.encode("utf-8") return serialized async def disconnect(self) -> None: - async with self.connection_lock: - if self.connection: - await self.pool.release(self.connection) - self.connection = None + async with self._connection_lock: + if self._dequeue_conn: + await self.pool.release(self._dequeue_conn) + self._dequeue_conn = None await self.pool.close() - self.has_sweep_lock = False + self._has_sweep_lock = False async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: async with self.pool.acquire() as conn: @@ -194,8 +159,7 @@ async def count(self, kind: CountKind) -> int: AND queue = $1 AND $2 >= scheduled """, - self.name, - math.ceil(seconds(now())), + self.name, math.ceil(seconds(now())), ) elif kind == "active": result = await conn.fetchval( @@ -220,83 +184,80 @@ async def count(self, kind: CountKind) -> int: return result + async def schedule(self, lock: int = 1) -> t.List[str]: + await self._dequeue() + return [] + async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: - """Delete jobs and stats past their expiration and sweep stuck jobs""" swept = [] - if not self.has_sweep_lock: - # Attempt to get the sweep lock and hold on to it - async with self._get_connection() as conn: + if not self._has_sweep_lock: + async with self._get_dequeue_conn() as conn: result = await conn.fetchval( "SELECT pg_try_advisory_lock($1, hashtext($2))", - self.saq_lock_keyspace, - self.name, + self.saq_lock_keyspace, self.name, ) if not result: - # Could not acquire the sweep lock so another worker must already have it return [] - self.has_sweep_lock = True + self._has_sweep_lock = True async with self.pool.acquire() as conn: await conn.execute( - f""" - -- Delete expired jobs + dedent(f""" DELETE FROM {self.jobs_table} WHERE queue = $1 - AND status IN ('aborted', 'complete', 'failed') - AND $2 >= expire_at; - """, - self.name, - math.ceil(seconds(now())), + AND status in ('aborted', 'complete', 'failed') + AND $2 >= expire_at + """), + math.ceil(seconds(now())),self.name ) - await conn.execute( - f""" - -- Delete expired stats + dedent(f""" DELETE FROM {self.stats_table} WHERE $1 >= expire_at; - """, + """), math.ceil(seconds(now())), ) - results = await conn.fetch( - f""" - -- Fetch active and aborting jobs without advisory locks - WITH locks AS ( - SELECT objid - FROM pg_locks - WHERE locktype = 'advisory' - AND classid = $1 - AND objsubid = 2 -- key is int pair, not single bigint - ) - SELECT key, job, objid - FROM {self.jobs_table} LEFT OUTER JOIN locks ON lock_key = objid - WHERE queue = $2 - AND status IN ('active', 'aborting'); - """, - self.job_lock_keyspace, - self.name, - ) - for row in results: - key, job_bytes, objid = row['key'], row['job'], row['objid'] - job = self.deserialize(job_bytes) - assert job - if objid and not 
job.stuck: - continue - - swept.append(key) - await self.abort(job, error="swept") - - try: - await job.refresh(abort) - except asyncio.TimeoutError: - logger.info("Could not abort job %s", key) - - logger.info("Sweeping job %s", job.info(logger.isEnabledFor(logging.DEBUG))) - if job.retryable: - await self.retry(job, error="swept") - else: - await self.finish(job, Status.ABORTED, error="swept") + + dedent( + f""" + WITH locks AS ( + SELECT objid + FROM pg_locks + WHERE locktype = 'advisory' + AND classid = $1 + AND objsubid = 2 -- key is int pair, not single bigint + ) + SELECT key, job, objid, status + FROM {self.jobs_table} + LEFT OUTER JOIN locks + ON lock_key = objid + WHERE queue = $2 + AND status IN ('active', 'aborting'); + """) , self.name, self.job_lock_keyspace, + + ) + + for key, job_bytes, objid, status in results: + job = self.deserialize(job_bytes) + assert job + if objid and not job.stuck: + continue + + swept.append(key) + await self.abort(job, error="swept") + + try: + await job.refresh(abort) + except asyncio.TimeoutError: + logger.info("Could not abort job %s", key) + + logger.info("Sweeping job %s", job.info(logger.isEnabledFor(logging.DEBUG))) + if job.retryable: + await self.retry(job, error="swept") + else: + await self.finish(job, Status.ABORTED, error="swept") return swept async def listen( @@ -308,143 +269,69 @@ async def listen( if not job_keys: return - async def _listen(conn: PoolConnectionProxy | Connection, pid: int, channel: str, payload: str) -> None: - payload_data = json.loads(payload) - key = payload_data["key"] - status = Status[payload_data["status"].upper()] + async for message in self._listener.listen(*job_keys, timeout=timeout): + job_key = message["key"] + status = Status[message["data"].upper()] if asyncio.iscoroutinefunction(callback): - _ = await callback(key, status) + stop = await callback(job_key, status) else: - _ = callback(key, status) - - async with self.pool.acquire() as conn: - for key in job_keys: - await conn.add_listener(key, _listen) - - if timeout: - try: - await asyncio.sleep(timeout) - finally: - for key in job_keys: - await conn.remove_listener(key, _listen) - else: - # If no timeout, we'll keep listening indefinitely - while True: - await asyncio.sleep(30) # Sleep for 30 seconds and then check again + stop = callback(job_key, status) + if stop: + break async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) -> None: - payload = json.dumps({"key": job.key, "status": job.status}) - await self._notify(job.key, payload, connection) + await self._notify(job.key, job.status, connection) async def update( - self, - job: Job, + self, job: Job, connection: PoolConnectionProxy | None = None, expire_at: float | None = -1, - **kwargs: t.Any, + **kwargs: t.Any ) -> None: job.touched = now() for k, v in kwargs.items(): setattr(job, k, v) - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: - if expire_at != -1: + if expire_at != -1 : await conn.execute( - f""" - UPDATE {self.jobs_table} SET - job = $1, - status = $2, - scheduled = $3, - expire_at = $4 - WHERE key = $5 - """, - self.serialize(job), - job.status, - job.scheduled, - expire_at, - job.key, + dedent(f""" + UPDATE {self.jobs_table} + SET job=$1, status = $2, expire_at = $3 + WHERE key = $4 + """), + self.serialize(job), job.status, expire_at, job.key ) else: await conn.execute( - f""" - UPDATE {self.jobs_table} SET - job = $1, - status = $2, - scheduled = $3 - WHERE key = $4 - """, - self.serialize(job), - job.status, - 
job.scheduled, - job.key, + dedent(f""" + UPDATE {self.jobs_table} + SET job=$1, status = $2 + WHERE key = $3 + """), + self.serialize(job), job.status, job.key ) await self.notify(job, conn) - async def job(self, job_key: str) -> Job | None: + async def job(self, key: str) -> Job | None: async with self.pool.acquire() as conn: - result = await conn.fetchrow( - f""" - SELECT job - FROM {self.jobs_table} - WHERE key = $1 - """, - job_key, + record = await conn.fetchrow( + f"SELECT job FROM {self.jobs_table} WHERE key = $1", key ) - if result: - return self.deserialize(result['job']) - return None - - async def jobs(self, job_keys: Iterable[str]) -> t.List[Job | None]: - keys = list(job_keys) + return self.deserialize(record['job']) if record else None + async def jobs(self, keys: t.Iterable[str]) -> t.List[Job]: async with self.pool.acquire() as conn: - results = await conn.fetch( - f""" - SELECT key, job - FROM {self.jobs_table} - WHERE key = ANY($1) - """, - keys, + records = await conn.fetch( + f"SELECT job FROM {self.jobs_table} WHERE key = ANY($1::text[])", list(keys) ) - job_dict = {row['key']: row['job'] for row in results} - return [self.deserialize(job_dict.get(key)) for key in keys] - - async def iter_jobs( - self, - statuses: t.List[Status] = list(Status), - batch_size: int = 100, - ) -> t.AsyncIterator[Job]: + return [self.deserialize(record['job']) for record in records] + async def iter_jobs(self) -> t.AsyncIterator[Job]: async with self.pool.acquire() as conn: - last_key = "" - - while True: - results = await conn.fetch( - f""" - SELECT key, job - FROM {self.jobs_table} - WHERE - status = ANY($1) - AND queue = $2 - AND key > $3 - ORDER BY key - LIMIT $4 - """, - [status.value for status in statuses], - self.name, - last_key, - batch_size, - ) - - if not results: - break - - for row in results: - last_key = row['key'] - job = self.deserialize(row['job']) - if job: - yield job - + async for record in conn.cursor( + f"SELECT job FROM {self.jobs_table} WHERE queue = $1", self.name + ): + yield self.deserialize(record['job']) async def abort(self, job: Job, error: str, ttl: float = 5) -> None: async with self.pool.acquire() as conn: status = await self.get_job_status(job.key, for_update=True, connection=conn) @@ -454,53 +341,63 @@ async def abort(self, job: Job, error: str, ttl: float = 5) -> None: await self.update(job, status=Status.ABORTING, error=error, connection=conn) async def dequeue(self, timeout: float = 0) -> Job | None: - """Wait on `self.cond` to dequeue. - - Retries indefinitely until a job is available or times out. 
- """ - self.waiting += 1 - async with self.cond: - self.cond.notify(1) + job = None try: - return await asyncio.wait_for(self.queue.get(), timeout or None) - except asyncio.exceptions.TimeoutError: - return None + self._waiting += 1 + + if self._job_queue.empty(): + await self._dequeue() + + if not self._job_queue.empty(): + job = self._job_queue.get_nowait() + elif self._listen_lock.locked(): + job = await ( + asyncio.wait_for(self._job_queue.get(), timeout) + if timeout > 0 + else self._job_queue.get() + ) + else: + async with self._listen_lock: + async for _ in self._listener.listen(ENQUEUE, timeout=timeout): + await self._dequeue() + + if not self._job_queue.empty(): + job = self._job_queue.get_nowait() + break + except (asyncio.TimeoutError, asyncio.CancelledError): + pass finally: - self.waiting -= 1 + self._waiting -= 1 - async def wait_for_job(self) -> None: - while True: - async with self.cond: - await self.cond.wait() - - for job in await self._dequeue(): - await self.queue.put(job) + if job: + self._job_queue.task_done() + return job async def _enqueue(self, job: Job) -> Job | None: async with self.pool.acquire() as conn: - result = await conn.fetchrow( + result = await conn.execute( f""" INSERT INTO {self.jobs_table} (key, job, queue, status, scheduled) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (key) DO UPDATE - SET job = $2, queue = $3, status = $4, scheduled = $5, expire_at = null - WHERE {self.jobs_table}.status IN ('aborted', 'complete', 'failed') + SET + job = $2, + queue = $3, + status = $4, + scheduled = $5, + expire_at = null + WHERE + {self.jobs_table}.status IN ('aborted', 'complete', 'failed') AND $5 > {self.jobs_table}.scheduled - RETURNING job + RETURNING 1 """, - job.key, - self.serialize(job), - self.name, - job.status, - job.scheduled or seconds(now()), + job.key, self.serialize(job), self.name, job.status, job.scheduled or seconds(now()), ) if not result: return None - - await self._notify(ENQUEUE_CHANNEL, job.key, conn) - + await self._notify(ENQUEUE, connection=conn) logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) return job @@ -513,53 +410,29 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None: ON CONFLICT (worker_id) DO UPDATE SET stats = $2, expire_at = $3 """, - self.uuid, - json.dumps(stats), - seconds(now()) + ttl, + self.uuid, json.dumps(stats), seconds(now()) + ttl, ) - async def dequeue_timer(self, poll_interval: int) -> None: - """Wakes up a single dequeue task every `poll_interval` seconds.""" - while True: - async with self.cond: - self.cond.notify(1) - await asyncio.sleep(poll_interval) - - async def listen_for_enqueues(self, timeout: float | None = None) -> None: - """Wakes up a single dequeue task when a Postgres enqueue notification is received.""" - async def _listen(conn: PoolConnectionProxy | Connection, pid: int, channel: str, payload: str) -> None: - async with self.cond: - self.cond.notify(1) - async with self.pool.acquire() as conn: - await conn.add_listener(ENQUEUE_CHANNEL, _listen) - if timeout: - try: - await asyncio.sleep(timeout) - finally: - await conn.remove_listener(ENQUEUE_CHANNEL, _listen) - else: - # If no timeout, we'll keep listening indefinitely - while True: - await asyncio.sleep(30) # Sleep for 30 seconds and then check again - async def get_job_status( self, key: str, for_update: bool = False, connection: PoolConnectionProxy | None = None, ) -> Status: - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: - result = await conn.fetchval( + async 
with self.nullcontext( + connection + ) if connection else self.pool.acquire() as conn: + result = await conn.fetchrow( f""" SELECT status FROM {self.jobs_table} WHERE key = $1 - {"FOR UPDATE" if for_update else ""} + {('FOR UPDATE' if for_update else '')} """, key, ) assert result - return result + return result['status'] async def _retry(self, job: Job, error: str | None) -> None: next_retry_delay = job.next_retry_delay() @@ -581,34 +454,33 @@ async def _finish( ) -> None: key = job.key - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + async with self.nullcontext( + connection + ) if connection else self.pool.acquire() as conn: if job.ttl >= 0: expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None await self.update(job, status=status, expire_at=expire_at, connection=conn) else: await conn.execute( - f""" + dedent(f""" DELETE FROM {self.jobs_table} WHERE key = $1 - """, + """), key, ) await self.notify(job, conn) await self._release_job(key) - try: - self.queue.task_done() - except ValueError: - # Error because task_done() called too many times, which happens in unit tests - pass - - async def _dequeue(self) -> list[Job]: - if not self.waiting: - return [] - jobs = [] - async with self._get_connection() as conn: - async with conn.transaction(): + + async def _dequeue(self) -> None: + if self._dequeue_lock.locked(): + return + + async with self._dequeue_lock: + async with self._get_dequeue_conn() as conn: + if not self._waiting: + return results = await conn.fetch( - f""" + dedent(f""" WITH locked_job AS ( SELECT key, lock_key FROM {self.jobs_table} @@ -622,60 +494,68 @@ async def _dequeue(self) -> list[Job]: UPDATE {self.jobs_table} SET status = 'active' FROM locked_job WHERE {self.jobs_table}.key = locked_job.key - AND pg_try_advisory_lock($4, locked_job.lock_key) + AND pg_try_advisory_lock({self.job_lock_keyspace}, locked_job.lock_key) RETURNING job - """, - self.name, - math.ceil(seconds(now())), - self.waiting, - self.job_lock_keyspace, + """), + self.name, math.ceil(seconds(now())), self._waiting, ) - for result in results: - job = self.deserialize(result['job']) - if job: - await self.update(job, status=Status.ACTIVE, connection=conn) - jobs.append(job) - return jobs + for result in results: + self._job_queue.put_nowait(self.deserialize(result['job'])) async def _notify( - self, channel: str, payload: t.Any, connection: PoolConnectionProxy | None = None + self, key: str, data: t.Any | None = None, connection: PoolConnectionProxy | None = None ) -> None: + payload = {"key": key} + + if data is not None: + payload["data"] = data + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: - await conn.execute(f"NOTIFY \"{channel}\", '{payload}'") + await conn.execute( + f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'" + ) @asynccontextmanager - async def _get_connection(self) -> t.AsyncGenerator: - assert self.connection - async with self.connection_lock: + async def _get_dequeue_conn(self) -> t.AsyncGenerator: + assert self._dequeue_conn + async with self._connection_lock: try: - # Pool normally performs this check when getting a connection. - await self.connection.execute("SELECT 1") - except asyncpg.exceptions.ConnectionDoesNotExistError: - # The connection is bad so return it to the pool and get a new one. 
- await self.pool.release(self.connection) - self.connection = await self.pool.acquire() - yield self.connection + await self._dequeue_conn.execute("SELECT 1") + except ConnectionDoesNotExistError: + await self.pool.release(self._dequeue_conn) + self._dequeue_conn = await self.pool.acquire() + yield self._dequeue_conn @asynccontextmanager async def nullcontext(self, enter_result: t.Any | None = None) -> t.AsyncGenerator: - """Async version of contextlib.nullcontext - - Async support has been added to contextlib.nullcontext in Python 3.10. - """ yield enter_result async def _release_job(self, key: str) -> None: - self.released.append(key) - if self.connection_lock.locked(): + self._releasing.append(key) + if self._connection_lock.locked(): return - async with self._get_connection() as conn: + async with self._get_dequeue_conn() as conn: await conn.execute( f""" - SELECT pg_advisory_unlock($1, lock_key) + SELECT pg_advisory_unlock({self.job_lock_keyspace}, lock_key) FROM {self.jobs_table} - WHERE key = ANY($2) + WHERE key = ANY($1) """, - self.job_lock_keyspace, - self.released, + self._releasing, ) - self.released.clear() \ No newline at end of file + self._releasing.clear() + + +class ListenMultiplexer(Multiplexer): + def __init__(self, pool: Pool, key: str) -> None: + super().__init__() + self.pool = pool + self.key = key + + async def _start(self) -> None: + async with self.pool.acquire() as conn: + await conn.add_listener(self.key, self._notify_callback) + + async def _notify_callback(self, connection, pid, channel, payload): + payload_data = json.loads(payload) + self.publish(payload_data["key"], payload_data) \ No newline at end of file diff --git a/tests/helpers.py b/tests/helpers.py index 1fc2b18..b10dd7c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from saq.queue import Queue from saq.queue.postgres import PostgresQueue -from saq.queue.postgres_asyncpg import PostgresQueue as AsyncpgPostgresQueue +from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue POSTGRES_TEST_SCHEMA = "test_saq" @@ -62,9 +62,8 @@ async def create_postgres_asyncpg_queue(**kwargs: t.Any) -> AsyncpgPostgresQueue f"postgres+asyncpg://postgres@localhost?options=--search_path%3D{POSTGRES_TEST_SCHEMA}", **kwargs, ), - ) + ) await queue.connect() - await queue.upkeep() await asyncio.sleep(0.1) # Give some time for the tasks to start return queue diff --git a/tests/test_queue.py b/tests/test_queue.py index 17e1cdc..b151e7d 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -30,7 +30,7 @@ from unittest.mock import MagicMock from saq.queue.postgres import PostgresQueue - from saq.queue.postgres_asyncpg import PostgresQueue as AsyncpgPostgresQueue + from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue from saq.types import Context, CountKind, Function From 9516959b2c031b82b313de61e099f73e15d51d40 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 6 Oct 2024 14:51:52 -0500 Subject: [PATCH 03/13] feat: current progress --- saq/queue/postgres.py | 2 +- saq/queue/postgres_asyncpg.py | 283 +++++++++++++++++++--------------- tests/test_queue.py | 21 +-- 3 files changed, 174 insertions(+), 132 deletions(-) diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 94581b7..3bb6326 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -833,4 +833,4 @@ async def _start(self) -> None: async for notify in conn.notifies(): payload = 
json.loads(notify.payload) - self.publish(payload["key"], payload) + self.publish(payload["key"], payload) \ No newline at end of file diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py index 436dda2..fe61195 100644 --- a/saq/queue/postgres_asyncpg.py +++ b/saq/queue/postgres_asyncpg.py @@ -9,25 +9,28 @@ import logging import math import time -import typing as t +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from textwrap import dedent +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, cast from saq.errors import MissingDependencyError from saq.job import Job, Status from saq.multiplexer import Multiplexer from saq.queue.base import Queue, logger from saq.queue.postgres_ddl import DDL_STATEMENTS +from saq.types import ListenCallback from saq.utils import now, seconds -if t.TYPE_CHECKING: +if TYPE_CHECKING: from collections.abc import Iterable + from saq.types import CountKind, DumpType, LoadType, QueueInfo, QueueStats try: - from asyncpg.pool import PoolConnectionProxy from asyncpg import Pool, create_pool - from asyncpg.exceptions import ConnectionDoesNotExistError - + from asyncpg.exceptions import ConnectionDoesNotExistError, InterfaceError + from asyncpg.pool import PoolConnectionProxy + except ModuleNotFoundError as e: raise MissingDependencyError( "Missing dependencies for Postgres. Install them with `pip install saq[asyncpg]`." @@ -35,6 +38,7 @@ CHANNEL = "saq:{}" ENQUEUE = "saq:enqueue" +DEQUEUE = "saq:dequeue" JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" @@ -45,14 +49,15 @@ class PostgresAsyncpgQueue(Queue): """ @classmethod - def from_url(cls: type[PostgresAsyncpgQueue], url: str, **kwargs: t.Any) -> PostgresAsyncpgQueue: # pyright: ignore[reportIncompatibleMethodOverride] + def from_url(cls: type[PostgresAsyncpgQueue], url: str, **kwargs: Any) -> PostgresAsyncpgQueue: # pyright: ignore[reportIncompatibleMethodOverride] """Create a queue from a postgres url.""" - pool = create_pool(dsn=url, **kwargs, ) - return cls(t.cast("Pool[t.Any]", pool), **kwargs) + pool_kwargs = {k: v for k, v in kwargs.items() if k in {"min_size", "max_size"}} + pool = create_pool(dsn=url, **pool_kwargs) + return cls(cast("Pool[Any]", pool), **kwargs) def __init__( self, - pool: Pool[t.Any], + pool: Pool[Any], name: str = "default", jobs_table: str = JOBS_TABLE, stats_table: str = STATS_TABLE, @@ -94,6 +99,8 @@ async def init_db(self) -> None: async def connect(self) -> None: if self._dequeue_conn: return + # the return from the `from_url` call must be awaited. 
The loop isn't running at the time `from_url` is called, so this seemed to make the most sense + self.pool._loop = asyncio.get_event_loop() await self.pool self._dequeue_conn = await self.pool.acquire() await self.init_db() @@ -115,13 +122,13 @@ async def disconnect(self) -> None: async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: async with self.pool.acquire() as conn: results = await conn.fetch( - f""" + dedent(f""" SELECT worker_id, stats FROM {self.stats_table} WHERE $1 <= expire_at - """, + """), seconds(now()), ) - workers: dict[str, dict[str, t.Any]] = {row['worker_id']: json.loads(row['stats']) for row in results} + workers: dict[str, dict[str, Any]] = {row["worker_id"]: json.loads(row["stats"]) for row in results} queued = await self.count("queued") active = await self.count("active") @@ -130,12 +137,12 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> Qu if jobs: async with self.pool.acquire() as conn: results = await conn.fetch( - f""" + dedent(f""" SELECT job FROM {self.jobs_table} WHERE status IN ('new', 'deferred', 'queued', 'active') - """ + """) ) - deserialized_jobs = (self.deserialize(result['job']) for result in results) + deserialized_jobs = (self.deserialize(result["job"]) for result in results) jobs_info = [job.to_dict() for job in deserialized_jobs if job] else: jobs_info = [] @@ -153,30 +160,31 @@ async def count(self, kind: CountKind) -> int: async with self.pool.acquire() as conn: if kind == "queued": result = await conn.fetchval( - f""" + dedent(f""" SELECT count(*) FROM {self.jobs_table} WHERE status = 'queued' AND queue = $1 AND $2 >= scheduled - """, - self.name, math.ceil(seconds(now())), + """), + self.name, + math.ceil(seconds(now())), ) elif kind == "active": result = await conn.fetchval( - f""" + dedent(f""" SELECT count(*) FROM {self.jobs_table} WHERE status = 'active' AND queue = $1 - """, + """), self.name, ) elif kind == "incomplete": result = await conn.fetchval( - f""" + dedent(f""" SELECT count(*) FROM {self.jobs_table} WHERE status IN ('new', 'deferred', 'queued', 'active') AND queue = $1 - """, + """), self.name, ) else: @@ -184,7 +192,7 @@ async def count(self, kind: CountKind) -> int: return result - async def schedule(self, lock: int = 1) -> t.List[str]: + async def schedule(self, lock: int = 1) -> List[str]: await self._dequeue() return [] @@ -194,8 +202,9 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: if not self._has_sweep_lock: async with self._get_dequeue_conn() as conn: result = await conn.fetchval( - "SELECT pg_try_advisory_lock($1, hashtext($2))", - self.saq_lock_keyspace, self.name, + dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), + self.saq_lock_keyspace, + self.name, ) if not result: return [] @@ -206,10 +215,11 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: dedent(f""" DELETE FROM {self.jobs_table} WHERE queue = $1 - AND status in ('aborted', 'complete', 'failed') + AND status IN ('aborted', 'complete', 'failed') AND $2 >= expire_at """), - math.ceil(seconds(now())),self.name + math.ceil(seconds(now())), + self.name, ) await conn.execute( dedent(f""" @@ -219,9 +229,8 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: math.ceil(seconds(now())), ) results = await conn.fetch( - - dedent( - f""" + dedent( + f""" WITH locks AS ( SELECT objid FROM pg_locks @@ -235,9 +244,11 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: ON lock_key = objid WHERE queue = 
$2 AND status IN ('active', 'aborting'); - """) , self.name, self.job_lock_keyspace, - - ) + """ + ), + self.name, + self.job_lock_keyspace, + ) for key, job_bytes, objid, status in results: job = self.deserialize(job_bytes) @@ -283,24 +294,24 @@ async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) await self._notify(job.key, job.status, connection) async def update( - self, job: Job, - connection: PoolConnectionProxy | None = None, - expire_at: float | None = -1, - **kwargs: t.Any + self, job: Job, connection: PoolConnectionProxy | None = None, expire_at: float | None = -1, **kwargs: Any ) -> None: job.touched = now() for k, v in kwargs.items(): setattr(job, k, v) async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: - if expire_at != -1 : + if expire_at != -1: await conn.execute( dedent(f""" UPDATE {self.jobs_table} SET job=$1, status = $2, expire_at = $3 WHERE key = $4 """), - self.serialize(job), job.status, expire_at, job.key + self.serialize(job), + job.status, + expire_at, + job.key, ) else: await conn.execute( @@ -309,29 +320,47 @@ async def update( SET job=$1, status = $2 WHERE key = $3 """), - self.serialize(job), job.status, job.key + self.serialize(job), + job.status, + job.key, ) await self.notify(job, conn) - async def job(self, key: str) -> Job | None: + async def job(self, job_key: str) -> Job | None: async with self.pool.acquire() as conn: - record = await conn.fetchrow( - f"SELECT job FROM {self.jobs_table} WHERE key = $1", key - ) - return self.deserialize(record['job']) if record else None + record = await conn.fetchrow(f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key) + return self.deserialize(record["job"]) if record else None - async def jobs(self, keys: t.Iterable[str]) -> t.List[Job]: + async def jobs(self, job_keys: Iterable[str]) -> List[Job | None]: + keys = list(job_keys) async with self.pool.acquire() as conn: - records = await conn.fetch( - f"SELECT job FROM {self.jobs_table} WHERE key = ANY($1::text[])", list(keys) - ) - return [self.deserialize(record['job']) for record in records] - async def iter_jobs(self) -> t.AsyncIterator[Job]: + records = await conn.fetch(f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys) + return [self.deserialize(record["job"]) for record in records] + + async def iter_jobs( + self, + statuses: List[Status] = list(Status), + batch_size: int = 100, + ) -> AsyncIterator[Job]: async with self.pool.acquire() as conn: - async for record in conn.cursor( - f"SELECT job FROM {self.jobs_table} WHERE queue = $1", self.name - ): - yield self.deserialize(record['job']) + last_key = "" + while True: + rows = await conn.fetch( + f"SELECT key, job FROM {self.jobs_table} WHERE status = ANY($1) AND queue = $2 AND key > $3 ORDER BY key LIMIT $4", + statuses, + self.name, + last_key, + batch_size, + ) + if rows: + for row in rows: + last_key = row["key"] + job = self.deserialize(row["job"]) + if job: + yield job + else: + break + async def abort(self, job: Job, error: str, ttl: float = 5) -> None: async with self.pool.acquire() as conn: status = await self.get_job_status(job.key, for_update=True, connection=conn) @@ -352,15 +381,12 @@ async def dequeue(self, timeout: float = 0) -> Job | None: if not self._job_queue.empty(): job = self._job_queue.get_nowait() elif self._listen_lock.locked(): - job = await ( - asyncio.wait_for(self._job_queue.get(), timeout) - if timeout > 0 - else self._job_queue.get() - ) + job = await (asyncio.wait_for(self._job_queue.get(), 
timeout) if timeout > 0 else self._job_queue.get()) else: async with self._listen_lock: - async for _ in self._listener.listen(ENQUEUE, timeout=timeout): - await self._dequeue() + async for payload in self._listener.listen(ENQUEUE, DEQUEUE, timeout=timeout): + if payload["key"] == ENQUEUE: + await self._dequeue() if not self._job_queue.empty(): job = self._job_queue.get_nowait() @@ -374,6 +400,42 @@ async def dequeue(self, timeout: float = 0) -> Job | None: self._job_queue.task_done() return job + + async def _dequeue(self) -> None: + if self._dequeue_lock.locked(): + return + + async with self._dequeue_lock: + async with self._get_dequeue_conn() as conn, conn.transaction(): + if not self._waiting: + return + results = await conn.fetch( + dedent(f""" + WITH locked_job AS ( + SELECT key, lock_key + FROM {self.jobs_table} + WHERE status = 'queued' + AND queue = $1 + AND $2 >= scheduled + ORDER BY scheduled + LIMIT $3 + FOR UPDATE SKIP LOCKED + ) + UPDATE {self.jobs_table} SET status = 'active' + FROM locked_job + WHERE {self.jobs_table}.key = locked_job.key + AND pg_try_advisory_lock({self.job_lock_keyspace}, locked_job.lock_key) + RETURNING job + """), + self.name, + math.ceil(seconds(now())), + self._waiting, + ) + for result in results: + self._job_queue.put_nowait(self.deserialize(result[0])) + if results: + await self._notify(DEQUEUE) + async def _enqueue(self, job: Job) -> Job | None: async with self.pool.acquire() as conn: result = await conn.execute( @@ -392,7 +454,11 @@ async def _enqueue(self, job: Job) -> Job | None: AND $5 > {self.jobs_table}.scheduled RETURNING 1 """, - job.key, self.serialize(job), self.name, job.status, job.scheduled or seconds(now()), + job.key, + self.serialize(job), + self.name, + job.status, + job.scheduled or seconds(now()), ) if not result: @@ -410,7 +476,9 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None: ON CONFLICT (worker_id) DO UPDATE SET stats = $2, expire_at = $3 """, - self.uuid, json.dumps(stats), seconds(now()) + ttl, + self.uuid, + json.dumps(stats), + seconds(now()) + ttl, ) async def get_job_status( @@ -419,20 +487,18 @@ async def get_job_status( for_update: bool = False, connection: PoolConnectionProxy | None = None, ) -> Status: - async with self.nullcontext( - connection - ) if connection else self.pool.acquire() as conn: - result = await conn.fetchrow( + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + result = await conn.fetchval( f""" SELECT status FROM {self.jobs_table} WHERE key = $1 - {('FOR UPDATE' if for_update else '')} + {"FOR UPDATE" if for_update else ""} """, key, ) assert result - return result['status'] + return result async def _retry(self, job: Job, error: str | None) -> None: next_retry_delay = job.next_retry_delay() @@ -448,15 +514,13 @@ async def _finish( job: Job, status: Status, *, - result: t.Any = None, + result: Any = None, error: str | None = None, connection: PoolConnectionProxy | None = None, ) -> None: key = job.key - async with self.nullcontext( - connection - ) if connection else self.pool.acquire() as conn: + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: if job.ttl >= 0: expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None await self.update(job, status=status, expire_at=expire_at, connection=conn) @@ -471,63 +535,30 @@ async def _finish( await self.notify(job, conn) await self._release_job(key) - async def _dequeue(self) -> None: - if self._dequeue_lock.locked(): - return - - async with 
self._dequeue_lock: - async with self._get_dequeue_conn() as conn: - if not self._waiting: - return - results = await conn.fetch( - dedent(f""" - WITH locked_job AS ( - SELECT key, lock_key - FROM {self.jobs_table} - WHERE status = 'queued' - AND queue = $1 - AND $2 >= scheduled - ORDER BY scheduled - LIMIT $3 - FOR UPDATE SKIP LOCKED - ) - UPDATE {self.jobs_table} SET status = 'active' - FROM locked_job - WHERE {self.jobs_table}.key = locked_job.key - AND pg_try_advisory_lock({self.job_lock_keyspace}, locked_job.lock_key) - RETURNING job - """), - self.name, math.ceil(seconds(now())), self._waiting, - ) - for result in results: - self._job_queue.put_nowait(self.deserialize(result['job'])) - - async def _notify( - self, key: str, data: t.Any | None = None, connection: PoolConnectionProxy | None = None - ) -> None: + async def _notify(self, key: str, data: Any | None = None, connection: PoolConnectionProxy | None = None) -> None: payload = {"key": key} if data is not None: payload["data"] = data async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: - await conn.execute( - f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'" - ) + await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") @asynccontextmanager - async def _get_dequeue_conn(self) -> t.AsyncGenerator: - assert self._dequeue_conn + async def _get_dequeue_conn(self) -> AsyncGenerator: async with self._connection_lock: - try: - await self._dequeue_conn.execute("SELECT 1") - except ConnectionDoesNotExistError: - await self.pool.release(self._dequeue_conn) + if self._dequeue_conn: + try: + await self._dequeue_conn.execute("SELECT 1") + except (ConnectionDoesNotExistError, InterfaceError): + await self.pool.release(self._dequeue_conn) + self._dequeue_conn = await self.pool.acquire() + else: self._dequeue_conn = await self.pool.acquire() yield self._dequeue_conn @asynccontextmanager - async def nullcontext(self, enter_result: t.Any | None = None) -> t.AsyncGenerator: + async def nullcontext(self, enter_result: Any | None = None) -> AsyncGenerator: yield enter_result async def _release_job(self, key: str) -> None: @@ -543,6 +574,7 @@ async def _release_job(self, key: str) -> None: """, self._releasing, ) + await conn.execute("commit;") self._releasing.clear() @@ -551,11 +583,20 @@ def __init__(self, pool: Pool, key: str) -> None: super().__init__() self.pool = pool self.key = key + self._connection: PoolConnectionProxy | None = None async def _start(self) -> None: - async with self.pool.acquire() as conn: - await conn.add_listener(self.key, self._notify_callback) + if self._connection is None: + self._connection = await self.pool.acquire() + await self._connection.add_listener(self.key, self._notify_callback) + await self._connection.execute("COMMIT;") + + async def _close(self) -> None: + if self._connection: + await self._connection.remove_listener(self.key, self._notify_callback) + await self._connection.close() + self._connection = None async def _notify_callback(self, connection, pid, channel, payload): payload_data = json.loads(payload) - self.publish(payload_data["key"], payload_data) \ No newline at end of file + self.publish(payload_data["key"], payload_data) diff --git a/tests/test_queue.py b/tests/test_queue.py index b151e7d..f1d49c4 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -908,9 +908,9 @@ async def test_sweep_stats(self) -> None: """ SELECT stats FROM {} - WHERE worker_id = %s + WHERE worker_id = $1 """.format(self.queue.stats_table), - 
(self.queue.uuid,), + self.queue.uuid ) self.assertIsNone(result) @@ -924,9 +924,9 @@ async def test_sweep_stats(self) -> None: """ SELECT stats FROM {} - WHERE worker_id = %s + WHERE worker_id = $1 """.format(self.queue.stats_table), - (self.queue.uuid,), + self.queue.uuid ) self.assertIsNotNone(result) @@ -934,19 +934,19 @@ async def test_job_lock(self) -> None: query = """ SELECT count(*) FROM {} JOIN pg_locks ON lock_key = objid - WHERE key = $key + WHERE key = $1 AND classid = {} AND objsubid = 2 -- key is int pair, not single bigint """.format(self.queue.jobs_table, self.queue.job_lock_keyspace) job = await self.enqueue("test") await self.dequeue() async with self.queue.pool.acquire() as conn : - result = await conn.fetchval(query, {"key": job.key}) + result = await conn.fetchval(query, job.key) self.assertEqual(result, 1) await self.finish(job, Status.COMPLETE, result=1) async with self.queue.pool.acquire() as conn : - result = await conn.execute(query, {"key": job.key}) + result = await conn.execute(query, job.key) self.assertEqual(result, (0,)) async def test_load_dump_pickle(self) -> None: @@ -1025,12 +1025,13 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: async def test_bad_connection(self) -> None: job = await self.enqueue("test") - original_connection = self.queue.connection - await self.queue.connection.close() + original_connection = self.queue._dequeue_conn + if self.queue._dequeue_conn: + await self.queue._dequeue_conn.close() # Test dequeue still works self.assertEqual((await self.dequeue()), job) # Check queue has a new connection - self.assertNotEqual(original_connection, self.queue.connection) + self.assertNotEqual(original_connection,self.queue._dequeue_conn) async def test_group_key(self) -> None: From dfcbb290bb67b4bfe05ac7df691a06f62e8e36ce Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 6 Oct 2024 15:15:39 -0500 Subject: [PATCH 04/13] fix: revert change to file --- saq/queue/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 3bb6326..94581b7 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -833,4 +833,4 @@ async def _start(self) -> None: async for notify in conn.notifies(): payload = json.loads(notify.payload) - self.publish(payload["key"], payload) \ No newline at end of file + self.publish(payload["key"], payload) From 5dbe4be5c92f63aa789fef760f204988cac92f9e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 8 Oct 2024 17:17:16 +0000 Subject: [PATCH 05/13] fix: updated --- saq/queue/postgres_asyncpg.py | 108 +++++++++++++++++++++++++--------- 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py index fe61195..3ad63f2 100644 --- a/saq/queue/postgres_asyncpg.py +++ b/saq/queue/postgres_asyncpg.py @@ -49,7 +49,9 @@ class PostgresAsyncpgQueue(Queue): """ @classmethod - def from_url(cls: type[PostgresAsyncpgQueue], url: str, **kwargs: Any) -> PostgresAsyncpgQueue: # pyright: ignore[reportIncompatibleMethodOverride] + def from_url( # pyright: ignore[reportIncompatibleMethodOverride] + cls: type[PostgresAsyncpgQueue], url: str, **kwargs: Any + ) -> PostgresAsyncpgQueue: """Create a queue from a postgres url.""" pool_kwargs = {k: v for k, v in kwargs.items() if k in {"min_size", "max_size"}} pool = create_pool(dsn=url, **pool_kwargs) @@ -94,13 +96,17 @@ def __init__( async def init_db(self) -> None: async with self.pool.acquire() as conn: for statement in 
DDL_STATEMENTS: - await conn.execute(statement.format(jobs_table=self.jobs_table, stats_table=self.stats_table)) + await conn.execute( + statement.format( + jobs_table=self.jobs_table, stats_table=self.stats_table + ) + ) async def connect(self) -> None: if self._dequeue_conn: return # the return from the `from_url` call must be awaited. The loop isn't running at the time `from_url` is called, so this seemed to make the most sense - self.pool._loop = asyncio.get_event_loop() + self.pool._loop = asyncio.get_event_loop() # type: ignore[attr-defined] await self.pool self._dequeue_conn = await self.pool.acquire() await self.init_db() @@ -119,7 +125,9 @@ async def disconnect(self) -> None: await self.pool.close() self._has_sweep_lock = False - async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: + async def info( + self, jobs: bool = False, offset: int = 0, limit: int = 10 + ) -> QueueInfo: async with self.pool.acquire() as conn: results = await conn.fetch( dedent(f""" @@ -128,7 +136,9 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> Qu """), seconds(now()), ) - workers: dict[str, dict[str, Any]] = {row["worker_id"]: json.loads(row["stats"]) for row in results} + workers: dict[str, dict[str, Any]] = { + row["worker_id"]: json.loads(row["stats"]) for row in results + } queued = await self.count("queued") active = await self.count("active") @@ -200,7 +210,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: swept = [] if not self._has_sweep_lock: - async with self._get_dequeue_conn() as conn: + async with self._get_dequeue_conn() as conn, conn.transaction(): result = await conn.fetchval( dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), self.saq_lock_keyspace, @@ -218,8 +228,8 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: AND status IN ('aborted', 'complete', 'failed') AND $2 >= expire_at """), - math.ceil(seconds(now())), self.name, + math.ceil(seconds(now())), ) await conn.execute( dedent(f""" @@ -246,8 +256,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: AND status IN ('active', 'aborting'); """ ), - self.name, - self.job_lock_keyspace, + self.job_lock_keyspace, self.name, ) for key, job_bytes, objid, status in results: @@ -290,17 +299,25 @@ async def listen( if stop: break - async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) -> None: + async def notify( + self, job: Job, connection: PoolConnectionProxy | None = None + ) -> None: await self._notify(job.key, job.status, connection) async def update( - self, job: Job, connection: PoolConnectionProxy | None = None, expire_at: float | None = -1, **kwargs: Any + self, + job: Job, + connection: PoolConnectionProxy | None = None, + expire_at: float | None = -1, + **kwargs: Any, ) -> None: job.touched = now() for k, v in kwargs.items(): setattr(job, k, v) - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + async with self.nullcontext( # type: ignore[attr-defined] + connection + ) if connection else self.pool.acquire() as conn: if expire_at != -1: await conn.execute( dedent(f""" @@ -328,14 +345,19 @@ async def update( async def job(self, job_key: str) -> Job | None: async with self.pool.acquire() as conn: - record = await conn.fetchrow(f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key) + record = await conn.fetchrow( + f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key + ) return 
self.deserialize(record["job"]) if record else None async def jobs(self, job_keys: Iterable[str]) -> List[Job | None]: keys = list(job_keys) async with self.pool.acquire() as conn: - records = await conn.fetch(f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys) - return [self.deserialize(record["job"]) for record in records] + records = await conn.fetch( + f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys + ) + results = {record.get('key'): record.get('job') for record in records} + return [self.deserialize(results.get(key)) for key in keys] async def iter_jobs( self, @@ -346,16 +368,23 @@ async def iter_jobs( last_key = "" while True: rows = await conn.fetch( - f"SELECT key, job FROM {self.jobs_table} WHERE status = ANY($1) AND queue = $2 AND key > $3 ORDER BY key LIMIT $4", + dedent(f""" + SELECT key, job + FROM {self.jobs_table} + WHERE status = ANY($1) + AND queue = $2 + AND key > $3 + ORDER BY key + LIMIT $4"""), statuses, self.name, last_key, batch_size, ) if rows: - for row in rows: - last_key = row["key"] - job = self.deserialize(row["job"]) + for key, job_bytes in rows: + last_key = key + job = self.deserialize(job_bytes) if job: yield job else: @@ -363,11 +392,15 @@ async def iter_jobs( async def abort(self, job: Job, error: str, ttl: float = 5) -> None: async with self.pool.acquire() as conn: - status = await self.get_job_status(job.key, for_update=True, connection=conn) + status = await self.get_job_status( + job.key, for_update=True, connection=conn + ) if status == Status.QUEUED: await self.finish(job, Status.ABORTED, error=error, connection=conn) else: - await self.update(job, status=Status.ABORTING, error=error, connection=conn) + await self.update( + job, status=Status.ABORTING, error=error, connection=conn + ) async def dequeue(self, timeout: float = 0) -> Job | None: job = None @@ -381,10 +414,16 @@ async def dequeue(self, timeout: float = 0) -> Job | None: if not self._job_queue.empty(): job = self._job_queue.get_nowait() elif self._listen_lock.locked(): - job = await (asyncio.wait_for(self._job_queue.get(), timeout) if timeout > 0 else self._job_queue.get()) + job = await ( + asyncio.wait_for(self._job_queue.get(), timeout) + if timeout > 0 + else self._job_queue.get() + ) else: async with self._listen_lock: - async for payload in self._listener.listen(ENQUEUE, DEQUEUE, timeout=timeout): + async for payload in self._listener.listen( + ENQUEUE, DEQUEUE, timeout=timeout + ): if payload["key"] == ENQUEUE: await self._dequeue() @@ -487,7 +526,9 @@ async def get_job_status( for_update: bool = False, connection: PoolConnectionProxy | None = None, ) -> Status: - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + async with self.nullcontext( # type: ignore[attr-defined] + connection + ) if connection else self.pool.acquire() as conn: result = await conn.fetchval( f""" SELECT status @@ -520,10 +561,14 @@ async def _finish( ) -> None: key = job.key - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + async with self.nullcontext( # type: ignore[attr-defined] + connection + ) if connection else self.pool.acquire() as conn: if job.ttl >= 0: expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None - await self.update(job, status=status, expire_at=expire_at, connection=conn) + await self.update( + job, status=status, expire_at=expire_at, connection=conn + ) else: await conn.execute( dedent(f""" @@ -535,13 +580,20 @@ async def _finish( await self.notify(job, conn) 
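For reference, the `_release_job` call on the next line defers to a batched advisory unlock on the reserved dequeue connection. A minimal sketch of that idea, assuming an asyncpg connection and illustrative `release_batch`, `jobs_table`, `keyspace`, and `pending` names that are not part of this diff:

async def release_batch(conn, jobs_table: str, keyspace: int, pending: list[str]) -> None:
    # Advisory locks are session-scoped, so the unlock must run on the same
    # connection that acquired them (here, the reserved dequeue connection).
    if not pending:
        return
    await conn.execute(
        f"""
        SELECT pg_advisory_unlock({keyspace}, lock_key)
        FROM {jobs_table}
        WHERE key = ANY($1)
        """,
        pending,
    )
    pending.clear()
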
await self._release_job(key) - async def _notify(self, key: str, data: Any | None = None, connection: PoolConnectionProxy | None = None) -> None: + async def _notify( + self, + key: str, + data: Any | None = None, + connection: PoolConnectionProxy | None = None, + ) -> None: payload = {"key": key} if data is not None: payload["data"] = data - async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: + async with self.nullcontext( # type: ignore[attr-defined] + connection + ) if connection else self.pool.acquire() as conn: await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") @asynccontextmanager From fb4bfef1f0f0060532cb87890d70e55783e89ec4 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 12 Oct 2024 18:55:55 +0000 Subject: [PATCH 06/13] feat: upstream changes --- saq/queue/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 94581b7..3bb6326 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -833,4 +833,4 @@ async def _start(self) -> None: async for notify in conn.notifies(): payload = json.loads(notify.payload) - self.publish(payload["key"], payload) + self.publish(payload["key"], payload) \ No newline at end of file From 8652a4c080f1f0a091788d1372c43ef2bbf13a05 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 16:39:59 +0000 Subject: [PATCH 07/13] feat: passing tests --- saq/queue/postgres_asyncpg.py | 226 +++++++++++++++++++--------------- tests/test_queue.py | 84 +++++++------ 2 files changed, 173 insertions(+), 137 deletions(-) diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py index 3ad63f2..de71e2c 100644 --- a/saq/queue/postgres_asyncpg.py +++ b/saq/queue/postgres_asyncpg.py @@ -7,12 +7,11 @@ import asyncio import json import logging -import math import time from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from textwrap import dedent -from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, cast +from typing import TYPE_CHECKING, Any, AsyncIterator, List, TypeVar from saq.errors import MissingDependencyError from saq.job import Job, Status @@ -27,7 +26,7 @@ from saq.types import CountKind, DumpType, LoadType, QueueInfo, QueueStats try: - from asyncpg import Pool, create_pool + from asyncpg import Pool, create_pool, Connection from asyncpg.exceptions import ConnectionDoesNotExistError, InterfaceError from asyncpg.pool import PoolConnectionProxy @@ -42,6 +41,8 @@ JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" +ContextT = TypeVar("ContextT") + class PostgresAsyncpgQueue(Queue): """ @@ -49,13 +50,16 @@ class PostgresAsyncpgQueue(Queue): """ @classmethod - def from_url( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type[PostgresAsyncpgQueue], url: str, **kwargs: Any + def from_url( # pyright: ignore[reportIncompatibleMethodOverride] + cls: type[PostgresAsyncpgQueue], + url: str, + min_size: int = 4, + max_size: int = 20, + **kwargs: Any, ) -> PostgresAsyncpgQueue: """Create a queue from a postgres url.""" - pool_kwargs = {k: v for k, v in kwargs.items() if k in {"min_size", "max_size"}} - pool = create_pool(dsn=url, **pool_kwargs) - return cls(cast("Pool[Any]", pool), **kwargs) + pool = create_pool(dsn=url, min_size=min_size, max_size=max_size) + return cls(pool, **kwargs) def __init__( self, @@ -65,8 +69,6 @@ def __init__( stats_table: str = STATS_TABLE, dump: DumpType | None = None, load: LoadType | None = None, - min_size: int = 4, - 
max_size: int = 20, poll_interval: int = 1, saq_lock_keyspace: int = 0, job_lock_keyspace: int = 1, @@ -76,8 +78,6 @@ def __init__( self.jobs_table = jobs_table self.stats_table = stats_table self.pool = pool - self.min_size = min_size - self.max_size = max_size self.poll_interval = poll_interval self.saq_lock_keyspace = saq_lock_keyspace self.job_lock_keyspace = job_lock_keyspace @@ -94,14 +94,13 @@ def __init__( self._listen_lock = asyncio.Lock() async def init_db(self) -> None: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn, conn.transaction(): for statement in DDL_STATEMENTS: await conn.execute( statement.format( jobs_table=self.jobs_table, stats_table=self.stats_table ) - ) - + ) async def connect(self) -> None: if self._dequeue_conn: return @@ -116,6 +115,15 @@ def serialize(self, job: Job) -> bytes | str: if isinstance(serialized, str): return serialized.encode("utf-8") return serialized + + @asynccontextmanager + async def with_connection( + self, connection: PoolConnectionProxy | None = None + ) -> AsyncGenerator[PoolConnectionProxy]: + async with self._nullcontext( + connection + ) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] + yield conn async def disconnect(self) -> None: async with self._connection_lock: @@ -128,24 +136,23 @@ async def disconnect(self) -> None: async def info( self, jobs: bool = False, offset: int = 0, limit: int = 10 ) -> QueueInfo: - async with self.pool.acquire() as conn: + workers: dict[str, dict[str, Any]] = {} + async with self.with_connection() as conn: results = await conn.fetch( dedent(f""" SELECT worker_id, stats FROM {self.stats_table} WHERE $1 <= expire_at """), - seconds(now()), + int(seconds(now())), ) - workers: dict[str, dict[str, Any]] = { - row["worker_id"]: json.loads(row["stats"]) for row in results - } - + for record in results: + workers[record.get("worker_id")] = record.get("stats") queued = await self.count("queued") active = await self.count("active") incomplete = await self.count("incomplete") if jobs: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn: results = await conn.fetch( dedent(f""" SELECT job FROM {self.jobs_table} @@ -167,7 +174,7 @@ async def info( } async def count(self, kind: CountKind) -> int: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn: if kind == "queued": result = await conn.fetchval( dedent(f""" @@ -177,7 +184,7 @@ async def count(self, kind: CountKind) -> int: AND $2 >= scheduled """), self.name, - math.ceil(seconds(now())), + int(seconds(now())), ) elif kind == "active": result = await conn.fetchval( @@ -202,15 +209,17 @@ async def count(self, kind: CountKind) -> int: return result - async def schedule(self, lock: int = 1) -> List[str]: + async def schedule(self, lock: int = 1) -> list[str]: await self._dequeue() return [] async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: - swept = [] + """Delete jobs and stats past their expiration and sweep stuck jobs""" + swept = [] if not self._has_sweep_lock: - async with self._get_dequeue_conn() as conn, conn.transaction(): + # Attempt to get the sweep lock and hold on to it + async with self._get_dequeue_conn() as conn: result = await conn.fetchval( dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), self.saq_lock_keyspace, @@ -220,7 +229,8 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: return [] self._has_sweep_lock = True - async with self.pool.acquire() as conn: + async 
with self.with_connection() as conn, conn.transaction(): + expired_at = int(seconds(now())) await conn.execute( dedent(f""" DELETE FROM {self.jobs_table} @@ -229,14 +239,14 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: AND $2 >= expire_at """), self.name, - math.ceil(seconds(now())), + expired_at, ) await conn.execute( dedent(f""" DELETE FROM {self.stats_table} WHERE $1 >= expire_at; """), - math.ceil(seconds(now())), + expired_at, ) results = await conn.fetch( dedent( @@ -256,10 +266,12 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: AND status IN ('active', 'aborting'); """ ), - self.job_lock_keyspace, self.name, + self.job_lock_keyspace, + self.name, ) - for key, job_bytes, objid, status in results: + for row in results: + key, job_bytes, objid, _status = row.values() job = self.deserialize(job_bytes) assert job if objid and not job.stuck: @@ -315,18 +327,17 @@ async def update( for k, v in kwargs.items(): setattr(job, k, v) - async with self.nullcontext( # type: ignore[attr-defined] - connection - ) if connection else self.pool.acquire() as conn: + async with self.with_connection(connection) as conn, conn.transaction(): if expire_at != -1: await conn.execute( dedent(f""" UPDATE {self.jobs_table} - SET job=$1, status = $2, expire_at = $3 - WHERE key = $4 + SET job=$1, status = $2, scheduled = $3, expire_at = $4 + WHERE key = $5 """), self.serialize(job), job.status, + job.scheduled, expire_at, job.key, ) @@ -334,37 +345,40 @@ async def update( await conn.execute( dedent(f""" UPDATE {self.jobs_table} - SET job=$1, status = $2 - WHERE key = $3 + SET job=$1, status = $2, scheduled=$3 + WHERE key = $4 """), self.serialize(job), job.status, + job.scheduled, job.key, ) await self.notify(job, conn) async def job(self, job_key: str) -> Job | None: - async with self.pool.acquire() as conn: - record = await conn.fetchrow( + async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key ) - return self.deserialize(record["job"]) if record else None + record = await cursor.fetchrow() + return self.deserialize(record.get("job")) if record else None async def jobs(self, job_keys: Iterable[str]) -> List[Job | None]: keys = list(job_keys) - async with self.pool.acquire() as conn: - records = await conn.fetch( + results = {} + async with self.with_connection() as conn, conn.transaction(): + async for record in conn.cursor( f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys - ) - results = {record.get('key'): record.get('job') for record in records} - return [self.deserialize(results.get(key)) for key in keys] + ): + results[record.get("key")] = record.get("job") + return [self.deserialize(results.get(key)) for key in keys] async def iter_jobs( self, statuses: List[Status] = list(Status), batch_size: int = 100, ) -> AsyncIterator[Job]: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn, conn.transaction(): last_key = "" while True: rows = await conn.fetch( @@ -382,16 +396,16 @@ async def iter_jobs( batch_size, ) if rows: - for key, job_bytes in rows: - last_key = key - job = self.deserialize(job_bytes) + for row in rows: + last_key = row.get("key") + job = self.deserialize(row.get("job")) if job: yield job else: break async def abort(self, job: Job, error: str, ttl: float = 5) -> None: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn: status = await self.get_job_status( 
job.key, for_update=True, connection=conn ) @@ -448,7 +462,8 @@ async def _dequeue(self) -> None: async with self._get_dequeue_conn() as conn, conn.transaction(): if not self._waiting: return - results = await conn.fetch( + should_notify = False + async for record in conn.cursor( dedent(f""" WITH locked_job AS ( SELECT key, lock_key @@ -467,17 +482,17 @@ async def _dequeue(self) -> None: RETURNING job """), self.name, - math.ceil(seconds(now())), + int(seconds(now())), self._waiting, - ) - for result in results: - self._job_queue.put_nowait(self.deserialize(result[0])) - if results: - await self._notify(DEQUEUE) + ): + should_notify = True + self._job_queue.put_nowait(self.deserialize(record.get("job"))) + if should_notify: + await self._notify(DEQUEUE) async def _enqueue(self, job: Job) -> Job | None: - async with self.pool.acquire() as conn: - result = await conn.execute( + async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( f""" INSERT INTO {self.jobs_table} (key, job, queue, status, scheduled) VALUES ($1, $2, $3, $4, $5) @@ -497,17 +512,16 @@ async def _enqueue(self, job: Job) -> Job | None: self.serialize(job), self.name, job.status, - job.scheduled or seconds(now()), + job.scheduled or int(seconds(now())), ) - - if not result: + if not await cursor.fetchrow(): return None await self._notify(ENQUEUE, connection=conn) logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) return job async def write_stats(self, stats: QueueStats, ttl: int) -> None: - async with self.pool.acquire() as conn: + async with self.with_connection() as conn: await conn.execute( f""" INSERT INTO {self.stats_table} (worker_id, stats, expire_at) @@ -526,18 +540,16 @@ async def get_job_status( for_update: bool = False, connection: PoolConnectionProxy | None = None, ) -> Status: - async with self.nullcontext( # type: ignore[attr-defined] - connection - ) if connection else self.pool.acquire() as conn: + async with self.with_connection(connection) as conn: result = await conn.fetchval( - f""" + dedent(f""" SELECT status FROM {self.jobs_table} WHERE key = $1 {"FOR UPDATE" if for_update else ""} - """, + """), key, - ) + ) assert result return result @@ -546,7 +558,7 @@ async def _retry(self, job: Job, error: str | None) -> None: if next_retry_delay: scheduled = time.time() + next_retry_delay else: - scheduled = job.scheduled or seconds(now()) + scheduled = job.scheduled or int(seconds(now())) await self.update(job, scheduled=int(scheduled), expire_at=None) @@ -561,11 +573,9 @@ async def _finish( ) -> None: key = job.key - async with self.nullcontext( # type: ignore[attr-defined] - connection - ) if connection else self.pool.acquire() as conn: + async with self.with_connection(connection) as conn: if job.ttl >= 0: - expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None + expire_at = int(seconds(now())) + job.ttl if job.ttl > 0 else None await self.update( job, status=status, expire_at=expire_at, connection=conn ) @@ -577,8 +587,8 @@ async def _finish( """), key, ) - await self.notify(job, conn) - await self._release_job(key) + await self.notify(job, conn) + await self._release_job(key) async def _notify( self, @@ -591,33 +601,16 @@ async def _notify( if data is not None: payload["data"] = data - async with self.nullcontext( # type: ignore[attr-defined] - connection - ) if connection else self.pool.acquire() as conn: + async with self.with_connection(connection) as conn: await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") - 
@asynccontextmanager - async def _get_dequeue_conn(self) -> AsyncGenerator: - async with self._connection_lock: - if self._dequeue_conn: - try: - await self._dequeue_conn.execute("SELECT 1") - except (ConnectionDoesNotExistError, InterfaceError): - await self.pool.release(self._dequeue_conn) - self._dequeue_conn = await self.pool.acquire() - else: - self._dequeue_conn = await self.pool.acquire() - yield self._dequeue_conn - - @asynccontextmanager - async def nullcontext(self, enter_result: Any | None = None) -> AsyncGenerator: - yield enter_result - async def _release_job(self, key: str) -> None: self._releasing.append(key) if self._connection_lock.locked(): return async with self._get_dequeue_conn() as conn: + txn = conn.transaction() + await txn.start() await conn.execute( f""" SELECT pg_advisory_unlock({self.job_lock_keyspace}, lock_key) @@ -626,9 +619,32 @@ async def _release_job(self, key: str) -> None: """, self._releasing, ) - await conn.execute("commit;") + await txn.commit() self._releasing.clear() + @asynccontextmanager + async def _get_dequeue_conn(self) -> AsyncGenerator[PoolConnectionProxy]: + async with self._connection_lock: + if self._dequeue_conn: + try: + await self._dequeue_conn.execute("SELECT 1") + except (ConnectionDoesNotExistError, InterfaceError): + await self.pool.release(self._dequeue_conn) + self._dequeue_conn = await self.pool.acquire() + else: + self._dequeue_conn = await self.pool.acquire() + + yield self._dequeue_conn + + + @asynccontextmanager + async def _nullcontext(self, enter_result: ContextT) -> AsyncGenerator[ContextT]: + """Async version of contextlib.nullcontext + + Async support has been added to contextlib.nullcontext in Python 3.10. + """ + yield enter_result + class ListenMultiplexer(Multiplexer): def __init__(self, pool: Pool, key: str) -> None: @@ -640,8 +656,10 @@ def __init__(self, pool: Pool, key: str) -> None: async def _start(self) -> None: if self._connection is None: self._connection = await self.pool.acquire() + txn = self._connection.transaction() + await txn.start() await self._connection.add_listener(self.key, self._notify_callback) - await self._connection.execute("COMMIT;") + await txn.commit() async def _close(self) -> None: if self._connection: @@ -649,6 +667,12 @@ async def _close(self) -> None: await self._connection.close() self._connection = None - async def _notify_callback(self, connection, pid, channel, payload): + async def _notify_callback( + self, + connection: Connection | PoolConnectionProxy, + pid: int, + channel: str, + payload: Any, + ) -> None: payload_data = json.loads(payload) self.publish(payload_data["key"], payload_data) diff --git a/tests/test_queue.py b/tests/test_queue.py index f1d49c4..b32e0f4 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -16,21 +16,20 @@ from saq.worker import Worker from tests.helpers import ( cleanup_queue, + create_postgres_asyncpg_queue, create_postgres_queue, create_redis_queue, setup_postgres, setup_postgres_asyncpg, teardown_postgres, teardown_postgres_asyncpg, - create_postgres_asyncpg_queue ) if t.TYPE_CHECKING: from unittest.mock import MagicMock - + from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue from saq.queue.postgres import PostgresQueue - from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue from saq.types import Context, CountKind, Function @@ -92,7 +91,7 @@ async def test_enqueue_job_str(self) -> None: async def test_enqueue_dup(self) -> None: job = await 
self.enqueue("test", key="1") - self.assertEqual(job.id, "saq:job:default$1") + self.assertEqual(job.id, "saq:job:default:1") self.assertIsNone(await self.queue.enqueue("test", key="1")) self.assertIsNone(await self.queue.enqueue(job)) @@ -722,9 +721,14 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: result = await cursor.fetchone() self.assertIsNone(result) - async def test_cron_job_close_to_target(self) -> None: - await self.enqueue("test", scheduled=time.time() + 0.5) - job = await self.queue.dequeue(timeout=0.1) + @mock.patch("saq.utils.time") + async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 1000.5 + await self.enqueue("test", scheduled=1001) + + # The job is scheduled to run at 1001, but we're running at 1000.5 + # so it should not be picked up + job = await self.queue.dequeue(timeout=1) assert not job async def test_bad_connection(self) -> None: @@ -764,12 +768,12 @@ class TestAsyncpgPostgresQueue(TestQueue): async def asyncSetUp(self) -> None: await setup_postgres_asyncpg() self.create_queue = create_postgres_asyncpg_queue - self.queue: AsyncpgPostgresQueue = await self.create_queue() + self.queue: PostgresAsyncpgQueue = await self.create_queue() async def asyncTearDown(self) -> None: await super().asyncTearDown() await teardown_postgres_asyncpg() - + @unittest.skip("Not implemented") async def test_job_key(self) -> None: pass @@ -830,7 +834,6 @@ async def test_sweep(self, mock_time: MagicMock) -> None: { job1.key, job2.key, - job3.key, }, ) await job1.refresh() @@ -839,7 +842,7 @@ async def test_sweep(self, mock_time: MagicMock) -> None: self.assertEqual(job1.status, Status.ABORTED) self.assertEqual(job2.status, Status.QUEUED) self.assertEqual(job3.status, Status.QUEUED) - self.assertEqual(await self.count("active"), 2) + self.assertEqual(await self.count("active"), 3) @mock.patch("saq.utils.time") async def test_sweep_stuck(self, mock_time: MagicMock) -> None: @@ -889,7 +892,7 @@ async def test_sweep_jobs(self) -> None: job2 = await self.enqueue("test", ttl=60) await self.queue.finish(job1, Status.COMPLETE) await self.queue.finish(job2, Status.COMPLETE) - await asyncio.sleep(1) + await asyncio.sleep(1.5) await self.queue.sweep() with self.assertRaisesRegex(RuntimeError, "doesn't exist"): @@ -900,11 +903,11 @@ async def test_sweep_jobs(self) -> None: async def test_sweep_stats(self) -> None: # Stats are deleted await self.queue.stats(ttl=1) - await asyncio.sleep(1) + await asyncio.sleep(1.5) await self.queue.sweep() - async with self.queue.pool.acquire() as conn: - result = await conn.fetchrow( - + async with self.queue.pool.acquire() as conn, conn.transaction(): + cursor = await conn.cursor( + """ SELECT stats FROM {} @@ -912,15 +915,15 @@ async def test_sweep_stats(self) -> None: """.format(self.queue.stats_table), self.queue.uuid ) - self.assertIsNone(result) + self.assertIsNone(await cursor.fetchrow()) # Stats are not deleted await self.queue.stats(ttl=60) await asyncio.sleep(1) - await self.queue.sweep() - async with self.queue.pool.acquire() as conn : - result = await conn.fetchrow( - + # await self.queue.sweep() + async with self.queue.pool.acquire() as conn, conn.transaction(): + cursor = await conn.cursor( + """ SELECT stats FROM {} @@ -928,7 +931,7 @@ async def test_sweep_stats(self) -> None: """.format(self.queue.stats_table), self.queue.uuid ) - self.assertIsNotNone(result) + self.assertIsNotNone(await cursor.fetchrow()) async def test_job_lock(self) -> None: query = """ @@ -940,28 
+943,27 @@ async def test_job_lock(self) -> None: """.format(self.queue.jobs_table, self.queue.job_lock_keyspace) job = await self.enqueue("test") await self.dequeue() - async with self.queue.pool.acquire() as conn : + async with self.queue.pool.acquire() as conn, conn.transaction(): result = await conn.fetchval(query, job.key) self.assertEqual(result, 1) await self.finish(job, Status.COMPLETE, result=1) - async with self.queue.pool.acquire() as conn : - result = await conn.execute(query, job.key) - self.assertEqual(result, (0,)) + async with self.queue.pool.acquire() as conn, conn.transaction(): + result = await conn.fetchval(query, job.key) + self.assertEqual(result, 0) async def test_load_dump_pickle(self) -> None: self.queue = await self.create_queue(dump=pickle.dumps, load=pickle.loads) job = await self.enqueue("test") - async with self.queue.pool.acquire() as conn : - result = await conn.fetchrow( - """ + async with self.queue.pool.acquire() as conn, conn.transaction(): + result = await conn.fetchrow(""" SELECT job FROM {} WHERE key =$1 """ .format(self.queue.jobs_table), job.key, - ) + ) assert result fetched_job = pickle.loads(result[0]) self.assertIsInstance(fetched_job, dict) @@ -976,16 +978,16 @@ async def test_finish_ttl_positive(self, mock_time: MagicMock) -> None: job = await self.enqueue("test", ttl=5) await self.dequeue() await self.finish(job, Status.COMPLETE) - async with self.queue.pool.acquire() as conn : + async with self.queue.pool.acquire() as conn: result = await conn.fetchval( - + """ SELECT expire_at FROM {} WHERE key = $1 - """ .format(self.queue.jobs_table), + """.format(self.queue.jobs_table), job.key, - ) + ) self.assertEqual(result,5) @mock.patch("saq.utils.time") @@ -996,14 +998,14 @@ async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None: await self.finish(job, Status.COMPLETE) async with self.queue.pool.acquire() as conn : result = await conn.fetchval( - + """ SELECT expire_at FROM {} WHERE key = $1 """ .format(self.queue.jobs_table), job.key, - ) + ) self.assertEqual(result,None) @mock.patch("saq.utils.time") @@ -1023,6 +1025,16 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: ) self.assertIsNone(result) + @mock.patch("saq.utils.time") + async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None: + mock_time.time.return_value = 1000.5 + await self.enqueue("test", scheduled=1001) + + # The job is scheduled to run at 1001, but we're running at 1000.5 + # so it should not be picked up + job = await self.queue.dequeue(timeout=1) + assert not job + async def test_bad_connection(self) -> None: job = await self.enqueue("test") original_connection = self.queue._dequeue_conn From a672c04720a88682e068ca708b73e085fa0c14d3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 16:47:26 +0000 Subject: [PATCH 08/13] feat: implement lock check for DDLs --- saq/queue/postgres_asyncpg.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py index de71e2c..47f99e1 100644 --- a/saq/queue/postgres_asyncpg.py +++ b/saq/queue/postgres_asyncpg.py @@ -95,6 +95,13 @@ def __init__( async def init_db(self) -> None: async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( + "SELECT pg_try_advisory_lock($1, 0)", self.saq_lock_keyspace, + ) + result = await cursor.fetchrow() + + if result and not result[0]: + return for statement in DDL_STATEMENTS: await conn.execute( statement.format( From 
87667f2606f319799395b749d48ca181160a4587 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 18:55:46 +0000 Subject: [PATCH 09/13] feat: merge upstream changes --- tests/test_queue.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_queue.py b/tests/test_queue.py index b32e0f4..35dbe6a 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -721,14 +721,9 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: result = await cursor.fetchone() self.assertIsNone(result) - @mock.patch("saq.utils.time") - async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None: - mock_time.time.return_value = 1000.5 - await self.enqueue("test", scheduled=1001) - - # The job is scheduled to run at 1001, but we're running at 1000.5 - # so it should not be picked up - job = await self.queue.dequeue(timeout=1) + async def test_cron_job_close_to_target(self) -> None: + await self.enqueue("test", scheduled=time.time() + 0.5) + job = await self.queue.dequeue(timeout=0.1) assert not job async def test_bad_connection(self) -> None: @@ -763,7 +758,6 @@ async def test_priority(self) -> None: self.assertEqual(await self.count("queued"), 1) assert not await self.queue.dequeue(0.01) - class TestAsyncpgPostgresQueue(TestQueue): async def asyncSetUp(self) -> None: await setup_postgres_asyncpg() From 425f0d4bff77f4a47eec2a8d25099cda6446911d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 19:54:47 +0000 Subject: [PATCH 10/13] fix: rebase and remove psycopg --- saq/queue/base.py | 3 - saq/queue/postgres.py | 773 +++++++++++++++------------------- saq/queue/postgres_asyncpg.py | 685 ------------------------------ setup.py | 4 +- tests/helpers.py | 37 +- tests/test_queue.py | 342 +-------------- 6 files changed, 348 insertions(+), 1496 deletions(-) delete mode 100644 saq/queue/postgres_asyncpg.py diff --git a/saq/queue/base.py b/saq/queue/base.py index 542cb03..e5edb09 100644 --- a/saq/queue/base.py +++ b/saq/queue/base.py @@ -156,9 +156,6 @@ def from_url(url: str, **kwargs: t.Any) -> Queue: from saq.queue.redis import RedisQueue return RedisQueue.from_url(url, **kwargs) - if url.startswith("postgres+asyncpg"): - from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue - return PostgresAsyncpgQueue.from_url(url.replace("postgres+asyncpg", "postgres"), **kwargs) if url.startswith("postgres"): from saq.queue.postgres import PostgresQueue diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 3bb6326..6189332 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -35,9 +35,9 @@ ) try: - from psycopg import AsyncConnection, OperationalError - from psycopg.sql import Identifier, SQL - from psycopg_pool import AsyncConnectionPool + from asyncpg import Pool, create_pool, Connection + from asyncpg.exceptions import ConnectionDoesNotExistError, InterfaceError + from asyncpg.pool import PoolConnectionProxy except ModuleNotFoundError as e: raise MissingDependencyError( "Missing dependencies for Postgres. Install them with `pip install saq[postgres]`." 
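One note on the asyncpg pool lifecycle that the hunks below rely on: `asyncpg.create_pool()` returns a pool object that only opens its connections once awaited, which is why `connect()` awaits `self.pool` later rather than doing so in `from_url()`, where no event loop is running yet. A minimal, self-contained sketch of the usual usage (the DSN and pool sizes here are illustrative only):

import asyncio
import asyncpg

async def main() -> None:
    # Awaiting the result of create_pool() opens the pool's connections.
    pool = await asyncpg.create_pool(
        dsn="postgres://localhost/saq", min_size=4, max_size=20
    )
    try:
        async with pool.acquire() as conn:
            print(await conn.fetchval("SELECT 1"))
    finally:
        await pool.close()

asyncio.run(main())
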
@@ -49,6 +49,8 @@ JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" +ContextT = t.TypeVar("ContextT") + class PostgresQueue(Queue): """ @@ -61,11 +63,7 @@ class PostgresQueue(Queue): stats_table: name of the Postgres table SAQ will write stats to (default "saq_stats") dump: lambda that takes a dictionary and outputs bytes (default `json.dumps`) load: lambda that takes str or bytes and outputs a python dictionary (default `json.loads`) - min_size: minimum pool size. (default 4) - The minimum number of Postgres connections. - max_size: maximum pool size. (default 20) - If greater than 0, this limits the maximum number of connections to Postgres. - Otherwise, maintain `min_size` number of connections. + poll_interval: how often to poll for jobs. (default 1) If 0, the queue will not poll for jobs and will only rely on notifications from the server. This mean cron jobs will not be picked up in a timely fashion. @@ -76,23 +74,34 @@ class PostgresQueue(Queue): """ @classmethod - def from_url(cls: type[PostgresQueue], url: str, **kwargs: t.Any) -> PostgresQueue: - """Create a queue from a postgres url.""" - return cls( - AsyncConnectionPool(url, check=AsyncConnectionPool.check_connection, open=False), - **kwargs, - ) + def from_url( # pyright: ignore[reportIncompatibleMethodOverride] + cls: type[PostgresQueue], + url: str, + min_size: int = 4, + max_size: int = 20, + **kwargs: t.Any, + ) -> PostgresQueue: + """Create a queue from a postgres url. + + Args: + url: connection string for the databases + min_size: minimum pool size. (default 4) + The minimum number of Postgres connections. + max_size: maximum pool size. (default 20) + If greater than 0, this limits the maximum number of connections to Postgres. + Otherwise, maintain `min_size` number of connections. 
+ + """ + return cls(create_pool(dsn=url, min_size=min_size, max_size=max_size), **kwargs) def __init__( self, - pool: AsyncConnectionPool, + pool: Pool[t.Any], name: str = "default", jobs_table: str = JOBS_TABLE, stats_table: str = STATS_TABLE, dump: DumpType | None = None, load: LoadType | None = None, - min_size: int = 4, - max_size: int = 20, poll_interval: int = 1, saq_lock_keyspace: int = 0, job_lock_keyspace: int = 1, @@ -100,11 +109,9 @@ def __init__( ) -> None: super().__init__(name=name, dump=dump, load=load) - self.jobs_table = Identifier(jobs_table) - self.stats_table = Identifier(stats_table) - self.pool = pool - self.min_size = min_size - self.max_size = max_size + self.jobs_table = jobs_table + self.stats_table = stats_table + self.pool = pool self.poll_interval = poll_interval self.saq_lock_keyspace = saq_lock_keyspace self.job_lock_keyspace = job_lock_keyspace @@ -112,7 +119,7 @@ def __init__( self._job_queue: asyncio.Queue = asyncio.Queue() self._waiting = 0 # Internal counter of worker tasks waiting for dequeue - self._dequeue_conn: AsyncConnection | None = None + self._dequeue_conn: PoolConnectionProxy | None = None self._connection_lock = asyncio.Lock() self._releasing: list[str] = [] self._has_sweep_lock = False @@ -121,28 +128,38 @@ def __init__( self._dequeue_lock = asyncio.Lock() self._listen_lock = asyncio.Lock() + @asynccontextmanager + async def with_connection( + self, connection: PoolConnectionProxy | None = None + ) -> t.AsyncGenerator[PoolConnectionProxy]: + async with self.nullcontext( + connection + ) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] + yield conn + async def init_db(self) -> None: - async with self.pool.connection() as conn, conn.cursor() as cursor, conn.transaction(): - await cursor.execute( - SQL("SELECT pg_try_advisory_lock(%(key1)s, 0)"), - {"key1": self.saq_lock_keyspace}, + async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( + "SELECT pg_try_advisory_lock($1, 0)", self.saq_lock_keyspace, ) - result = await cursor.fetchone() + result = await cursor.fetchrow() if result and not result[0]: return - for statement in DDL_STATEMENTS: - await cursor.execute( - SQL(statement).format(jobs_table=self.jobs_table, stats_table=self.stats_table) - ) + await conn.execute( + statement.format( + jobs_table=self.jobs_table, stats_table=self.stats_table + ) + ) async def connect(self) -> None: - if self.pool._opened: + if self._dequeue_conn: return - - await self.pool.open() - await self.pool.resize(min_size=self.min_size, max_size=self.max_size) + # the return from the `from_url` call must be awaited. 
The loop isn't running at the time `from_url` is called, so this seemed to make the most sense + self.pool._loop = asyncio.get_event_loop() # type: ignore[attr-defined] + await self.pool + self._dequeue_conn = await self.pool.acquire() await self.init_db() def serialize(self, job: Job) -> bytes | str: @@ -155,45 +172,35 @@ def serialize(self, job: Job) -> bytes | str: async def disconnect(self) -> None: async with self._connection_lock: if self._dequeue_conn: - await self._dequeue_conn.cancel_safe() - await self.pool.putconn(self._dequeue_conn) + await self.pool.release(self._dequeue_conn) self._dequeue_conn = None await self.pool.close() self._has_sweep_lock = False async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo: - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - SELECT worker_id, stats FROM {stats_table} - WHERE NOW() <= TO_TIMESTAMP(expire_at) - """ - ) - ).format(stats_table=self.stats_table), + workers: dict[str, dict[str, t.Any]] = {} + async with self.with_connection() as conn: + results = await conn.fetch( + dedent(f""" + SELECT worker_id, stats FROM {self.stats_table} + WHERE NOW() <= TO_TIMESTAMP(expire_at) + """) ) - results = await cursor.fetchall() - workers: dict[str, dict[str, t.Any]] = dict(results) - + for record in results: + workers[record.get("worker_id")] = record.get("stats") queued = await self.count("queued") active = await self.count("active") incomplete = await self.count("incomplete") if jobs: - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - SELECT job FROM {jobs_table} - WHERE status IN ('new', 'deferred', 'queued', 'active') - """ - ) - ).format(jobs_table=self.jobs_table), + async with self.with_connection() as conn: + results = await conn.fetch( + dedent(f""" + SELECT job FROM {self.jobs_table} + WHERE status IN ('new', 'deferred', 'queued', 'active') + """) ) - results = await cursor.fetchall() - deserialized_jobs = (self.deserialize(result[0]) for result in results) + deserialized_jobs = (self.deserialize(result["job"]) for result in results) jobs_info = [job.to_dict() for job in deserialized_jobs if job] else: jobs_info = [] @@ -208,54 +215,42 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> Qu } async def count(self, kind: CountKind) -> int: - async with self.pool.connection() as conn, conn.cursor() as cursor: + async with self.with_connection() as conn: if kind == "queued": - await cursor.execute( - SQL( - dedent( - """ - SELECT count(*) - FROM {jobs_table} - WHERE status = 'queued' - AND queue = %(queue)s - AND NOW() >= TO_TIMESTAMP(scheduled) - """ - ) - ).format(jobs_table=self.jobs_table), - {"queue": self.name}, + result = await conn.fetchval( + dedent(f""" + SELECT count(*) + FROM {self.jobs_table} + WHERE status = 'queued' + AND queue = $1 + AND NOW() >= TO_TIMESTAMP(scheduled) + """), + self.name, ) elif kind == "active": - await cursor.execute( - SQL( - dedent( - """ - SELECT count(*) FROM {jobs_table} - WHERE status = 'active' - AND queue = %(queue)s - """ - ) - ).format(jobs_table=self.jobs_table), - {"queue": self.name}, + result = await conn.fetchval( + dedent(f""" + SELECT count(*) + FROM {self.jobs_table} + WHERE status = 'active' + AND queue = $1 + """), + self.name, ) elif kind == "incomplete": - await cursor.execute( - SQL( - dedent( - """ - SELECT count(*) FROM {jobs_table} - WHERE status IN ('new', 'deferred', 'queued', 
'active') - AND queue = %(queue)s - """ - ) - ).format(jobs_table=self.jobs_table), - {"queue": self.name}, + result = await conn.fetchval( + dedent(f""" + SELECT count(*) + FROM {self.jobs_table} + WHERE status IN ('new', 'deferred', 'queued', 'active') + AND queue = $1 + """), + self.name, ) else: - raise ValueError("Can't count unknown type {kind}") + raise ValueError(f"Can't count unknown type {kind}") - result = await cursor.fetchone() - assert result - return result[0] + return result async def schedule(self, lock: int = 1) -> t.List[str]: await self._dequeue() @@ -264,89 +259,62 @@ async def schedule(self, lock: int = 1) -> t.List[str]: async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: """Delete jobs and stats past their expiration and sweep stuck jobs""" swept = [] - + if not self._has_sweep_lock: # Attempt to get the sweep lock and hold on to it - async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction(): - await cursor.execute( - SQL("SELECT pg_try_advisory_lock(%(key1)s, hashtext(%(queue)s))"), - { - "key1": self.saq_lock_keyspace, - "queue": self.name, - }, + async with self._get_dequeue_conn() as conn: + result = await conn.fetchval( + dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), + self.saq_lock_keyspace, + self.name, ) - result = await cursor.fetchone() - if result and not result[0]: + if not result: # Could not acquire the sweep lock so another worker must already have it return [] self._has_sweep_lock = True - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - -- Delete expired jobs - DELETE FROM {jobs_table} - WHERE queue = %(queue)s - AND status IN ('aborted', 'complete', 'failed') - AND NOW() >= TO_TIMESTAMP(expire_at); - """ - ) - ).format( - jobs_table=self.jobs_table, - stats_table=self.stats_table, - ), - { - "queue": self.name, - }, + async with self.with_connection() as conn, conn.transaction(): + await conn.execute( + dedent(f""" + -- Delete expired jobs + DELETE FROM {self.jobs_table} + WHERE queue = $1 + AND status IN ('aborted', 'complete', 'failed') + AND NOW() >= TO_TIMESTAMP(expire_at) + """), + self.name, ) - - await cursor.execute( - SQL( - dedent( - """ - -- Delete expired stats - DELETE FROM {stats_table} - WHERE NOW() >= TO_TIMESTAMP(expire_at); - """ - ) - ).format( - jobs_table=self.jobs_table, - stats_table=self.stats_table, - ), + await conn.execute( + dedent(f""" + -- Delete expired stats + DELETE FROM {self.stats_table} + WHERE NOW() >= TO_TIMESTAMP(expire_at); + """), ) - - await cursor.execute( - SQL( - dedent( - """ + results = await conn.fetch( + dedent( + f""" WITH locks AS ( SELECT objid FROM pg_locks WHERE locktype = 'advisory' - AND classid = %(job_lock_keyspace)s + AND classid = $1 AND objsubid = 2 -- key is int pair, not single bigint ) SELECT key, job, objid, status - FROM {jobs_table} + FROM {self.jobs_table} LEFT OUTER JOIN locks ON lock_key = objid - WHERE queue = %(queue)s + WHERE queue = $2 AND status IN ('active', 'aborting'); """ - ) - ).format( - jobs_table=self.jobs_table, ), - { - "queue": self.name, - "job_lock_keyspace": self.job_lock_keyspace, - }, + self.job_lock_keyspace, + self.name, ) - results = await cursor.fetchall() - for key, job_bytes, objid, status in results: + for row in results: + key, job_bytes, objid, status = row.values() job = self.deserialize(job_bytes) assert job if objid and not job.stuck: @@ -386,13 +354,15 @@ async def listen( if stop: break - async def notify(self, job: 
Job, connection: AsyncConnection | None = None) -> None: + async def notify( + self, job: Job, connection: PoolConnectionProxy | None = None + ) -> None: await self._notify(job.key, job.status, connection) async def update( self, job: Job, - connection: AsyncConnection | None = None, + connection: PoolConnectionProxy | None = None, expire_at: float | None = -1, **kwargs: t.Any, ) -> None: @@ -400,70 +370,50 @@ async def update( for k, v in kwargs.items(): setattr(job, k, v) - - async with self.nullcontext(connection) if connection else self.pool.connection() as conn: - await conn.execute( - SQL( - dedent( - """ - UPDATE {jobs_table} SET - job = %(job)s - ,status = %(status)s - ,scheduled = %(scheduled)s - {expire_at} - WHERE key = %(key)s - """ - ) - ).format( - jobs_table=self.jobs_table, - expire_at=SQL(",expire_at = %(expire_at)s" if expire_at != -1 else ""), - ), - { - "job": self.serialize(job), - "status": job.status, - "key": job.key, - "scheduled": job.scheduled, - "expire_at": expire_at, - }, - ) + async with self.with_connection(connection) as conn, conn.transaction(): + if expire_at != -1: + await conn.execute( + dedent(f""" + UPDATE {self.jobs_table} + SET job=$1, status = $2, scheduled = $3, expire_at = $4 + WHERE key = $5 + """), + self.serialize(job), + job.status, + job.scheduled, + expire_at, + job.key, + ) + else: + await conn.execute( + dedent(f""" + UPDATE {self.jobs_table} + SET job=$1, status = $2, scheduled=$3 + WHERE key = $4 + """), + self.serialize(job), + job.status, + job.scheduled, + job.key, + ) await self.notify(job, conn) async def job(self, job_key: str) -> Job | None: - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - SELECT job - FROM {jobs_table} - WHERE key = %(key)s - """ - ) - ).format(jobs_table=self.jobs_table), - {"key": job_key}, + async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( + f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key ) - job = await cursor.fetchone() - if job: - return self.deserialize(job[0]) - return None + record = await cursor.fetchrow() + return self.deserialize(record.get("job")) if record else None async def jobs(self, job_keys: Iterable[str]) -> t.List[Job | None]: keys = list(job_keys) - - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - SELECT key, job - FROM {jobs_table} - WHERE key = ANY(%(keys)s) - """ - ) - ).format(jobs_table=self.jobs_table), - {"keys": keys}, - ) - results: dict[str, bytes | None] = dict(await cursor.fetchall()) + results: dict[str, bytes | None] = {} + async with self.with_connection() as conn, conn.transaction(): + async for record in conn.cursor( + f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys + ): + results[record.get("key")] = record.get("job") return [self.deserialize(results.get(key)) for key in keys] async def iter_jobs( @@ -471,46 +421,34 @@ async def iter_jobs( statuses: t.List[Status] = list(Status), batch_size: int = 100, ) -> t.AsyncIterator[Job]: - async with self.pool.connection() as conn, conn.cursor() as cursor: + async with self.with_connection() as conn, conn.transaction(): last_key = "" - while True: - await cursor.execute( - SQL( - dedent( - """ - SELECT key, job - FROM {jobs_table} - WHERE - status = ANY(%(statuses)s) - AND queue = %(queue)s - AND key > %(last_key)s - ORDER BY key - LIMIT %(batch_size)s - """ - ) - ).format(jobs_table=self.jobs_table), - { - "statuses": 
statuses, - "queue": self.name, - "batch_size": batch_size, - "last_key": last_key, - }, + rows = await conn.fetch( + dedent(f""" + SELECT key, job + FROM {self.jobs_table} + WHERE status = ANY($1) + AND queue = $2 + AND key > $3 + ORDER BY key + LIMIT $4"""), + statuses, + self.name, + last_key, + batch_size, ) - - rows = await cursor.fetchall() - if rows: - for key, job_bytes in rows: - last_key = key - job = self.deserialize(job_bytes) + for row in rows: + last_key = row.get("key") + job = self.deserialize(row.get("job")) if job: yield job else: break async def abort(self, job: Job, error: str, ttl: float = 5) -> None: - async with self.pool.connection() as conn: + async with self.with_connection() as conn: status = await self.get_job_status(job.key, for_update=True, connection=conn) if status == Status.QUEUED: await self.finish(job, Status.ABORTED, error=error, connection=conn) @@ -558,164 +496,120 @@ async def _dequeue(self) -> None: return async with self._dequeue_lock: - async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction(): + async with self._get_dequeue_conn() as conn, conn.transaction(): if not self._waiting: return - await cursor.execute( - SQL( - dedent( - """ - WITH locked_job AS ( - SELECT key, lock_key - FROM {jobs_table} - WHERE status = 'queued' - AND queue = %(queue)s - AND NOW() >= TO_TIMESTAMP(scheduled) - AND priority BETWEEN %(plow)s AND %(phigh)s - AND group_key NOT IN ( + should_notify = False + async for record in conn.cursor( + dedent(f""" + WITH locked_job AS ( + SELECT key, lock_key + FROM {self.jobs_table} + WHERE status = 'queued' + AND queue = $1 + AND NOW() >= TO_TIMESTAMP(scheduled) + AND priority BETWEEN $2 AND $3 + AND group_key NOT IN ( SELECT DISTINCT group_key - FROM {jobs_table} + FROM {self.jobs_table} WHERE status = 'active' - AND queue = %(queue)s + AND queue = $1 AND group_key IS NOT NULL ) - ORDER BY priority, scheduled - LIMIT %(limit)s - FOR UPDATE SKIP LOCKED - ) - UPDATE {jobs_table} SET status = 'active' - FROM locked_job - WHERE {jobs_table}.key = locked_job.key - AND pg_try_advisory_lock({job_lock_keyspace}, locked_job.lock_key) - RETURNING job - """ - ) - ).format( - jobs_table=self.jobs_table, - job_lock_keyspace=self.job_lock_keyspace, - ), - { - "queue": self.name, - "limit": self._waiting, - "plow": self._priorities[0], - "phigh": self._priorities[1], - }, - ) - results = await cursor.fetchall() - - for result in results: - self._job_queue.put_nowait(self.deserialize(result[0])) - - if results: + ORDER BY priority, scheduled + LIMIT $4 + FOR UPDATE SKIP LOCKED + ) + UPDATE {self.jobs_table} SET status = 'active' + FROM locked_job + WHERE {self.jobs_table}.key = locked_job.key + AND pg_try_advisory_lock({self.job_lock_keyspace}, locked_job.lock_key) + RETURNING job + """), + self.name, + self._priorities[0], + self._priorities[1], + self._waiting, + ): + should_notify = True + self._job_queue.put_nowait(self.deserialize(record.get("job"))) + if should_notify: await self._notify(DEQUEUE) async def _enqueue(self, job: Job) -> Job | None: - async with self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - INSERT INTO {jobs_table} ( - key, - job, - queue, - status, - priority, - group_key, - scheduled - ) - VALUES ( - %(key)s, - %(job)s, - %(queue)s, - %(status)s, - %(priority)s, - %(group_key)s, - %(scheduled)s - ) - ON CONFLICT (key) DO UPDATE - SET - job = %(job)s, - queue = %(queue)s, - status = %(status)s, - priority = %(priority)s, - group_key = 
%(group_key)s, - scheduled = %(scheduled)s, - expire_at = null - WHERE - {jobs_table}.status IN ('aborted', 'complete', 'failed') - AND %(scheduled)s > {jobs_table}.scheduled - RETURNING 1 - """ - ) - ).format(jobs_table=self.jobs_table), - { - "key": job.key, - "job": self.serialize(job), - "queue": self.name, - "status": job.status, - "priority": job.priority, - "group_key": job.group_key, - "scheduled": job.scheduled or int(seconds(now())), - }, + async with self.with_connection() as conn, conn.transaction(): + cursor = await conn.cursor( + f""" + INSERT INTO {self.jobs_table} ( + key, + job, + queue, + status, + priority, + group_key, + scheduled + ) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (key) DO UPDATE + SET + job = $2, + queue = $3, + status = $4, + priority = $5, + group_key = $6, + scheduled = $7, + expire_at = null + WHERE + {self.jobs_table}.status IN ('aborted', 'complete', 'failed') + AND $7 > {self.jobs_table}.scheduled + RETURNING 1 + """, + job.key, + self.serialize(job), + self.name, + job.status, + job.priority, + str(job.group_key), + job.scheduled or int(seconds(now())), ) - - if not await cursor.fetchone(): + if not await cursor.fetchrow(): return None await self._notify(ENQUEUE, connection=conn) logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) return job async def write_stats(self, stats: QueueStats, ttl: int) -> None: - async with self.pool.connection() as conn: + async with self.with_connection() as conn: await conn.execute( - SQL( - dedent( - """ - INSERT INTO {stats_table} (worker_id, stats, expire_at) - VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s) - ON CONFLICT (worker_id) DO UPDATE - SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s - """ - ) - ).format(stats_table=self.stats_table), - { - "worker_id": self.uuid, - "stats": json.dumps(stats), - "ttl": ttl, - }, + dedent(f""" + INSERT INTO {self.stats_table} (worker_id, stats, expire_at) + VALUES ($1, $2, EXTRACT(EPOCH FROM NOW()) + $3) + ON CONFLICT (worker_id) DO UPDATE + SET stats = $2, expire_at = EXTRACT(EPOCH FROM NOW()) + $3 + """), + self.uuid, + json.dumps(stats), + ttl, ) async def get_job_status( self, key: str, for_update: bool = False, - connection: AsyncConnection | None = None, + connection: PoolConnectionProxy | None = None, ) -> Status: - async with self.nullcontext( - connection - ) if connection else self.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - dedent( - """ - SELECT status - FROM {jobs_table} - WHERE key = %(key)s - {for_update} - """ - ) - ).format( - jobs_table=self.jobs_table, - for_update=SQL("FOR UPDATE" if for_update else ""), - ), - { - "key": key, - }, - ) - result = await cursor.fetchone() + async with self.with_connection(connection) as conn: + result = await conn.fetchval( + dedent(f""" + SELECT status + FROM {self.jobs_table} + WHERE key = $1 + {"FOR UPDATE" if for_update else ""} + """), + key, + ) assert result - return result[0] + return result async def _retry(self, job: Job, error: str | None) -> None: next_retry_delay = job.next_retry_delay() @@ -733,45 +627,38 @@ async def _finish( *, result: t.Any = None, error: str | None = None, - connection: AsyncConnection | None = None, + connection: PoolConnectionProxy | None = None, ) -> None: key = job.key - async with self.nullcontext( - connection - ) if connection else self.pool.connection() as conn, conn.cursor() as cursor: + async with self.with_connection(connection) as conn: if job.ttl >= 0: - 
expire_at = seconds(now()) + job.ttl if job.ttl > 0 else None + expire_at = int(seconds(now())) + job.ttl if job.ttl > 0 else None await self.update(job, status=status, expire_at=expire_at, connection=conn) else: - await cursor.execute( - SQL( - dedent( - """ - DELETE FROM {jobs_table} - WHERE key = %(key)s - """ - ) - ).format(jobs_table=self.jobs_table), - {"key": key}, + await conn.execute( + dedent(f""" + DELETE FROM {self.jobs_table} + WHERE key = $1 + """), + key, ) - await self.notify(job, conn) - await self._release_job(key) + await self.notify(job, conn) + await self._release_job(key) async def _notify( - self, key: str, data: t.Any | None = None, connection: AsyncConnection | None = None + self, + key: str, + data: t.Any | None = None, + connection: PoolConnectionProxy | None = None, ) -> None: payload = {"key": key} if data is not None: payload["data"] = data - async with self.nullcontext(connection) if connection else self.pool.connection() as conn: - await conn.execute( - SQL("NOTIFY {channel}, {payload}").format( - channel=Identifier(self._channel), payload=json.dumps(payload) - ) - ) + async with self.with_connection(connection) as conn: + await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") @asynccontextmanager async def _get_dequeue_conn(self) -> t.AsyncGenerator: @@ -779,17 +666,18 @@ async def _get_dequeue_conn(self) -> t.AsyncGenerator: if self._dequeue_conn: try: # Pool normally performs this check when getting a connection. - await self.pool.check_connection(self._dequeue_conn) - except OperationalError: + await self._dequeue_conn.execute("SELECT 1") + except (ConnectionDoesNotExistError, InterfaceError): # The connection is bad so return it to the pool and get a new one. - await self.pool.putconn(self._dequeue_conn) - self._dequeue_conn = await self.pool.getconn() + await self.pool.release(self._dequeue_conn) + self._dequeue_conn = await self.pool.acquire() else: - self._dequeue_conn = await self.pool.getconn() + self._dequeue_conn = await self.pool.acquire() + yield self._dequeue_conn @asynccontextmanager - async def nullcontext(self, enter_result: t.Any | None = None) -> t.AsyncGenerator: + async def nullcontext(self, enter_result: ContextT) -> t.AsyncGenerator[ContextT]: """Async version of contextlib.nullcontext Async support has been added to contextlib.nullcontext in Python 3.10. 
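The reserved dequeue connection above is validated with a cheap probe and replaced if it has gone away. A compact sketch of that validate-or-replace pattern, assuming pool is an already-initialized asyncpg pool (the helper name is illustrative):

    from asyncpg.exceptions import ConnectionDoesNotExistError, InterfaceError


    async def ensure_live_connection(pool, conn):
        # Acquire a connection if none is cached yet.
        if conn is None:
            return await pool.acquire()
        try:
            # Cheap liveness probe for the long-lived cached connection.
            await conn.execute("SELECT 1")
            return conn
        except (ConnectionDoesNotExistError, InterfaceError):
            # Connection is dead: return it to the pool and take a fresh one.
            await pool.release(conn)
            return await pool.acquire()
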
@@ -801,36 +689,47 @@ async def _release_job(self, key: str) -> None: if self._connection_lock.locked(): return async with self._get_dequeue_conn() as conn: + txn = conn.transaction() + await txn.start() await conn.execute( - SQL( - dedent( - """ - SELECT pg_advisory_unlock({job_lock_keyspace}, lock_key) - FROM {jobs_table} - WHERE key = ANY(%(keys)s) - """ - ) - ).format( - jobs_table=self.jobs_table, - job_lock_keyspace=self.job_lock_keyspace, - ), - {"keys": self._releasing}, + f""" + SELECT pg_advisory_unlock({self.job_lock_keyspace}, lock_key) + FROM {self.jobs_table} + WHERE key = ANY($1) + """, + self._releasing, ) - await conn.commit() + await txn.commit() self._releasing.clear() class ListenMultiplexer(Multiplexer): - def __init__(self, pool: AsyncConnectionPool, key: str) -> None: + def __init__(self, pool: Pool, key: str) -> None: super().__init__() self.pool = pool self.key = key + self._connection: PoolConnectionProxy | None = None async def _start(self) -> None: - async with self.pool.connection() as conn: - await conn.execute(SQL("LISTEN {}").format(Identifier(self.key))) - await conn.commit() - - async for notify in conn.notifies(): - payload = json.loads(notify.payload) - self.publish(payload["key"], payload) \ No newline at end of file + if self._connection is None: + self._connection = await self.pool.acquire() + txn = self._connection.transaction() + await txn.start() + await self._connection.add_listener(self.key, self._notify_callback) + await txn.commit() + + async def _close(self) -> None: + if self._connection: + await self._connection.remove_listener(self.key, self._notify_callback) + await self._connection.close() + self._connection = None + + async def _notify_callback( + self, + connection: Connection | PoolConnectionProxy, + pid: int, + channel: str, + payload: t.Any, + ) -> None: + payload_data = json.loads(payload) + self.publish(payload_data["key"], payload_data) \ No newline at end of file diff --git a/saq/queue/postgres_asyncpg.py b/saq/queue/postgres_asyncpg.py deleted file mode 100644 index 47f99e1..0000000 --- a/saq/queue/postgres_asyncpg.py +++ /dev/null @@ -1,685 +0,0 @@ -""" -Postgres Queue using asyncpg -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import time -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from textwrap import dedent -from typing import TYPE_CHECKING, Any, AsyncIterator, List, TypeVar - -from saq.errors import MissingDependencyError -from saq.job import Job, Status -from saq.multiplexer import Multiplexer -from saq.queue.base import Queue, logger -from saq.queue.postgres_ddl import DDL_STATEMENTS -from saq.types import ListenCallback -from saq.utils import now, seconds - -if TYPE_CHECKING: - from collections.abc import Iterable - - from saq.types import CountKind, DumpType, LoadType, QueueInfo, QueueStats -try: - from asyncpg import Pool, create_pool, Connection - from asyncpg.exceptions import ConnectionDoesNotExistError, InterfaceError - from asyncpg.pool import PoolConnectionProxy - -except ModuleNotFoundError as e: - raise MissingDependencyError( - "Missing dependencies for Postgres. Install them with `pip install saq[asyncpg]`." - ) from e - -CHANNEL = "saq:{}" -ENQUEUE = "saq:enqueue" -DEQUEUE = "saq:dequeue" -JOBS_TABLE = "saq_jobs" -STATS_TABLE = "saq_stats" - -ContextT = TypeVar("ContextT") - - -class PostgresAsyncpgQueue(Queue): - """ - Queue is used to interact with Postgres using asyncpg. 
- """ - - @classmethod - def from_url( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type[PostgresAsyncpgQueue], - url: str, - min_size: int = 4, - max_size: int = 20, - **kwargs: Any, - ) -> PostgresAsyncpgQueue: - """Create a queue from a postgres url.""" - pool = create_pool(dsn=url, min_size=min_size, max_size=max_size) - return cls(pool, **kwargs) - - def __init__( - self, - pool: Pool[Any], - name: str = "default", - jobs_table: str = JOBS_TABLE, - stats_table: str = STATS_TABLE, - dump: DumpType | None = None, - load: LoadType | None = None, - poll_interval: int = 1, - saq_lock_keyspace: int = 0, - job_lock_keyspace: int = 1, - ) -> None: - super().__init__(name=name, dump=dump, load=load) - - self.jobs_table = jobs_table - self.stats_table = stats_table - self.pool = pool - self.poll_interval = poll_interval - self.saq_lock_keyspace = saq_lock_keyspace - self.job_lock_keyspace = job_lock_keyspace - - self._job_queue: asyncio.Queue = asyncio.Queue() - self._waiting = 0 - self._dequeue_conn: PoolConnectionProxy | None = None - self._connection_lock = asyncio.Lock() - self._releasing: list[str] = [] - self._has_sweep_lock = False - self._channel = CHANNEL.format(self.name) - self._listener = ListenMultiplexer(self.pool, self._channel) - self._dequeue_lock = asyncio.Lock() - self._listen_lock = asyncio.Lock() - - async def init_db(self) -> None: - async with self.with_connection() as conn, conn.transaction(): - cursor = await conn.cursor( - "SELECT pg_try_advisory_lock($1, 0)", self.saq_lock_keyspace, - ) - result = await cursor.fetchrow() - - if result and not result[0]: - return - for statement in DDL_STATEMENTS: - await conn.execute( - statement.format( - jobs_table=self.jobs_table, stats_table=self.stats_table - ) - ) - async def connect(self) -> None: - if self._dequeue_conn: - return - # the return from the `from_url` call must be awaited. 
The loop isn't running at the time `from_url` is called, so this seemed to make the most sense - self.pool._loop = asyncio.get_event_loop() # type: ignore[attr-defined] - await self.pool - self._dequeue_conn = await self.pool.acquire() - await self.init_db() - - def serialize(self, job: Job) -> bytes | str: - serialized = self._dump(job.to_dict()) - if isinstance(serialized, str): - return serialized.encode("utf-8") - return serialized - - @asynccontextmanager - async def with_connection( - self, connection: PoolConnectionProxy | None = None - ) -> AsyncGenerator[PoolConnectionProxy]: - async with self._nullcontext( - connection - ) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] - yield conn - - async def disconnect(self) -> None: - async with self._connection_lock: - if self._dequeue_conn: - await self.pool.release(self._dequeue_conn) - self._dequeue_conn = None - await self.pool.close() - self._has_sweep_lock = False - - async def info( - self, jobs: bool = False, offset: int = 0, limit: int = 10 - ) -> QueueInfo: - workers: dict[str, dict[str, Any]] = {} - async with self.with_connection() as conn: - results = await conn.fetch( - dedent(f""" - SELECT worker_id, stats FROM {self.stats_table} - WHERE $1 <= expire_at - """), - int(seconds(now())), - ) - for record in results: - workers[record.get("worker_id")] = record.get("stats") - queued = await self.count("queued") - active = await self.count("active") - incomplete = await self.count("incomplete") - - if jobs: - async with self.with_connection() as conn: - results = await conn.fetch( - dedent(f""" - SELECT job FROM {self.jobs_table} - WHERE status IN ('new', 'deferred', 'queued', 'active') - """) - ) - deserialized_jobs = (self.deserialize(result["job"]) for result in results) - jobs_info = [job.to_dict() for job in deserialized_jobs if job] - else: - jobs_info = [] - - return { - "workers": workers, - "name": self.name, - "queued": queued, - "active": active, - "scheduled": incomplete - queued - active, - "jobs": jobs_info, - } - - async def count(self, kind: CountKind) -> int: - async with self.with_connection() as conn: - if kind == "queued": - result = await conn.fetchval( - dedent(f""" - SELECT count(*) FROM {self.jobs_table} - WHERE status = 'queued' - AND queue = $1 - AND $2 >= scheduled - """), - self.name, - int(seconds(now())), - ) - elif kind == "active": - result = await conn.fetchval( - dedent(f""" - SELECT count(*) FROM {self.jobs_table} - WHERE status = 'active' - AND queue = $1 - """), - self.name, - ) - elif kind == "incomplete": - result = await conn.fetchval( - dedent(f""" - SELECT count(*) FROM {self.jobs_table} - WHERE status IN ('new', 'deferred', 'queued', 'active') - AND queue = $1 - """), - self.name, - ) - else: - raise ValueError(f"Can't count unknown type {kind}") - - return result - - async def schedule(self, lock: int = 1) -> list[str]: - await self._dequeue() - return [] - - async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: - """Delete jobs and stats past their expiration and sweep stuck jobs""" - - swept = [] - if not self._has_sweep_lock: - # Attempt to get the sweep lock and hold on to it - async with self._get_dequeue_conn() as conn: - result = await conn.fetchval( - dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), - self.saq_lock_keyspace, - self.name, - ) - if not result: - return [] - self._has_sweep_lock = True - - async with self.with_connection() as conn, conn.transaction(): - expired_at = int(seconds(now())) - await conn.execute( - 
dedent(f""" - DELETE FROM {self.jobs_table} - WHERE queue = $1 - AND status IN ('aborted', 'complete', 'failed') - AND $2 >= expire_at - """), - self.name, - expired_at, - ) - await conn.execute( - dedent(f""" - DELETE FROM {self.stats_table} - WHERE $1 >= expire_at; - """), - expired_at, - ) - results = await conn.fetch( - dedent( - f""" - WITH locks AS ( - SELECT objid - FROM pg_locks - WHERE locktype = 'advisory' - AND classid = $1 - AND objsubid = 2 -- key is int pair, not single bigint - ) - SELECT key, job, objid, status - FROM {self.jobs_table} - LEFT OUTER JOIN locks - ON lock_key = objid - WHERE queue = $2 - AND status IN ('active', 'aborting'); - """ - ), - self.job_lock_keyspace, - self.name, - ) - - for row in results: - key, job_bytes, objid, _status = row.values() - job = self.deserialize(job_bytes) - assert job - if objid and not job.stuck: - continue - - swept.append(key) - await self.abort(job, error="swept") - - try: - await job.refresh(abort) - except asyncio.TimeoutError: - logger.info("Could not abort job %s", key) - - logger.info("Sweeping job %s", job.info(logger.isEnabledFor(logging.DEBUG))) - if job.retryable: - await self.retry(job, error="swept") - else: - await self.finish(job, Status.ABORTED, error="swept") - return swept - - async def listen( - self, - job_keys: Iterable[str], - callback: ListenCallback, - timeout: float | None = 10, - ) -> None: - if not job_keys: - return - - async for message in self._listener.listen(*job_keys, timeout=timeout): - job_key = message["key"] - status = Status[message["data"].upper()] - if asyncio.iscoroutinefunction(callback): - stop = await callback(job_key, status) - else: - stop = callback(job_key, status) - if stop: - break - - async def notify( - self, job: Job, connection: PoolConnectionProxy | None = None - ) -> None: - await self._notify(job.key, job.status, connection) - - async def update( - self, - job: Job, - connection: PoolConnectionProxy | None = None, - expire_at: float | None = -1, - **kwargs: Any, - ) -> None: - job.touched = now() - - for k, v in kwargs.items(): - setattr(job, k, v) - async with self.with_connection(connection) as conn, conn.transaction(): - if expire_at != -1: - await conn.execute( - dedent(f""" - UPDATE {self.jobs_table} - SET job=$1, status = $2, scheduled = $3, expire_at = $4 - WHERE key = $5 - """), - self.serialize(job), - job.status, - job.scheduled, - expire_at, - job.key, - ) - else: - await conn.execute( - dedent(f""" - UPDATE {self.jobs_table} - SET job=$1, status = $2, scheduled=$3 - WHERE key = $4 - """), - self.serialize(job), - job.status, - job.scheduled, - job.key, - ) - await self.notify(job, conn) - - async def job(self, job_key: str) -> Job | None: - async with self.with_connection() as conn, conn.transaction(): - cursor = await conn.cursor( - f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key - ) - record = await cursor.fetchrow() - return self.deserialize(record.get("job")) if record else None - - async def jobs(self, job_keys: Iterable[str]) -> List[Job | None]: - keys = list(job_keys) - results = {} - async with self.with_connection() as conn, conn.transaction(): - async for record in conn.cursor( - f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys - ): - results[record.get("key")] = record.get("job") - return [self.deserialize(results.get(key)) for key in keys] - - async def iter_jobs( - self, - statuses: List[Status] = list(Status), - batch_size: int = 100, - ) -> AsyncIterator[Job]: - async with self.with_connection() as conn, 
conn.transaction(): - last_key = "" - while True: - rows = await conn.fetch( - dedent(f""" - SELECT key, job - FROM {self.jobs_table} - WHERE status = ANY($1) - AND queue = $2 - AND key > $3 - ORDER BY key - LIMIT $4"""), - statuses, - self.name, - last_key, - batch_size, - ) - if rows: - for row in rows: - last_key = row.get("key") - job = self.deserialize(row.get("job")) - if job: - yield job - else: - break - - async def abort(self, job: Job, error: str, ttl: float = 5) -> None: - async with self.with_connection() as conn: - status = await self.get_job_status( - job.key, for_update=True, connection=conn - ) - if status == Status.QUEUED: - await self.finish(job, Status.ABORTED, error=error, connection=conn) - else: - await self.update( - job, status=Status.ABORTING, error=error, connection=conn - ) - - async def dequeue(self, timeout: float = 0) -> Job | None: - job = None - - try: - self._waiting += 1 - - if self._job_queue.empty(): - await self._dequeue() - - if not self._job_queue.empty(): - job = self._job_queue.get_nowait() - elif self._listen_lock.locked(): - job = await ( - asyncio.wait_for(self._job_queue.get(), timeout) - if timeout > 0 - else self._job_queue.get() - ) - else: - async with self._listen_lock: - async for payload in self._listener.listen( - ENQUEUE, DEQUEUE, timeout=timeout - ): - if payload["key"] == ENQUEUE: - await self._dequeue() - - if not self._job_queue.empty(): - job = self._job_queue.get_nowait() - break - except (asyncio.TimeoutError, asyncio.CancelledError): - pass - finally: - self._waiting -= 1 - - if job: - self._job_queue.task_done() - - return job - - async def _dequeue(self) -> None: - if self._dequeue_lock.locked(): - return - - async with self._dequeue_lock: - async with self._get_dequeue_conn() as conn, conn.transaction(): - if not self._waiting: - return - should_notify = False - async for record in conn.cursor( - dedent(f""" - WITH locked_job AS ( - SELECT key, lock_key - FROM {self.jobs_table} - WHERE status = 'queued' - AND queue = $1 - AND $2 >= scheduled - ORDER BY scheduled - LIMIT $3 - FOR UPDATE SKIP LOCKED - ) - UPDATE {self.jobs_table} SET status = 'active' - FROM locked_job - WHERE {self.jobs_table}.key = locked_job.key - AND pg_try_advisory_lock({self.job_lock_keyspace}, locked_job.lock_key) - RETURNING job - """), - self.name, - int(seconds(now())), - self._waiting, - ): - should_notify = True - self._job_queue.put_nowait(self.deserialize(record.get("job"))) - if should_notify: - await self._notify(DEQUEUE) - - async def _enqueue(self, job: Job) -> Job | None: - async with self.with_connection() as conn, conn.transaction(): - cursor = await conn.cursor( - f""" - INSERT INTO {self.jobs_table} (key, job, queue, status, scheduled) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (key) DO UPDATE - SET - job = $2, - queue = $3, - status = $4, - scheduled = $5, - expire_at = null - WHERE - {self.jobs_table}.status IN ('aborted', 'complete', 'failed') - AND $5 > {self.jobs_table}.scheduled - RETURNING 1 - """, - job.key, - self.serialize(job), - self.name, - job.status, - job.scheduled or int(seconds(now())), - ) - if not await cursor.fetchrow(): - return None - await self._notify(ENQUEUE, connection=conn) - logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG))) - return job - - async def write_stats(self, stats: QueueStats, ttl: int) -> None: - async with self.with_connection() as conn: - await conn.execute( - f""" - INSERT INTO {self.stats_table} (worker_id, stats, expire_at) - VALUES ($1, $2, $3) - ON CONFLICT 
(worker_id) DO UPDATE - SET stats = $2, expire_at = $3 - """, - self.uuid, - json.dumps(stats), - seconds(now()) + ttl, - ) - - async def get_job_status( - self, - key: str, - for_update: bool = False, - connection: PoolConnectionProxy | None = None, - ) -> Status: - async with self.with_connection(connection) as conn: - result = await conn.fetchval( - dedent(f""" - SELECT status - FROM {self.jobs_table} - WHERE key = $1 - {"FOR UPDATE" if for_update else ""} - """), - key, - ) - assert result - return result - - async def _retry(self, job: Job, error: str | None) -> None: - next_retry_delay = job.next_retry_delay() - if next_retry_delay: - scheduled = time.time() + next_retry_delay - else: - scheduled = job.scheduled or int(seconds(now())) - - await self.update(job, scheduled=int(scheduled), expire_at=None) - - async def _finish( - self, - job: Job, - status: Status, - *, - result: Any = None, - error: str | None = None, - connection: PoolConnectionProxy | None = None, - ) -> None: - key = job.key - - async with self.with_connection(connection) as conn: - if job.ttl >= 0: - expire_at = int(seconds(now())) + job.ttl if job.ttl > 0 else None - await self.update( - job, status=status, expire_at=expire_at, connection=conn - ) - else: - await conn.execute( - dedent(f""" - DELETE FROM {self.jobs_table} - WHERE key = $1 - """), - key, - ) - await self.notify(job, conn) - await self._release_job(key) - - async def _notify( - self, - key: str, - data: Any | None = None, - connection: PoolConnectionProxy | None = None, - ) -> None: - payload = {"key": key} - - if data is not None: - payload["data"] = data - - async with self.with_connection(connection) as conn: - await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") - - async def _release_job(self, key: str) -> None: - self._releasing.append(key) - if self._connection_lock.locked(): - return - async with self._get_dequeue_conn() as conn: - txn = conn.transaction() - await txn.start() - await conn.execute( - f""" - SELECT pg_advisory_unlock({self.job_lock_keyspace}, lock_key) - FROM {self.jobs_table} - WHERE key = ANY($1) - """, - self._releasing, - ) - await txn.commit() - self._releasing.clear() - - @asynccontextmanager - async def _get_dequeue_conn(self) -> AsyncGenerator[PoolConnectionProxy]: - async with self._connection_lock: - if self._dequeue_conn: - try: - await self._dequeue_conn.execute("SELECT 1") - except (ConnectionDoesNotExistError, InterfaceError): - await self.pool.release(self._dequeue_conn) - self._dequeue_conn = await self.pool.acquire() - else: - self._dequeue_conn = await self.pool.acquire() - - yield self._dequeue_conn - - - @asynccontextmanager - async def _nullcontext(self, enter_result: ContextT) -> AsyncGenerator[ContextT]: - """Async version of contextlib.nullcontext - - Async support has been added to contextlib.nullcontext in Python 3.10. 
- """ - yield enter_result - - -class ListenMultiplexer(Multiplexer): - def __init__(self, pool: Pool, key: str) -> None: - super().__init__() - self.pool = pool - self.key = key - self._connection: PoolConnectionProxy | None = None - - async def _start(self) -> None: - if self._connection is None: - self._connection = await self.pool.acquire() - txn = self._connection.transaction() - await txn.start() - await self._connection.add_listener(self.key, self._notify_callback) - await txn.commit() - - async def _close(self) -> None: - if self._connection: - await self._connection.remove_listener(self.key, self._notify_callback) - await self._connection.close() - self._connection = None - - async def _notify_callback( - self, - connection: Connection | PoolConnectionProxy, - pid: int, - channel: str, - payload: Any, - ) -> None: - payload_data = json.loads(payload) - self.publish(payload_data["key"], payload_data) diff --git a/setup.py b/setup.py index cd7b2a8..f8d88d6 100644 --- a/setup.py +++ b/setup.py @@ -36,8 +36,7 @@ extras_require={ "hiredis": ["redis[hiredis]>=4.2.0"], "http": ["aiohttp"], - "postgres": ["psycopg[pool]>=3.2.0"], - "asyncpg": ["asyncpg"], + "postgres": ["asyncpg"], "redis": ["redis>=4.2,<6.0"], "web": ["aiohttp", "aiohttp_basicauth"], "dev": [ @@ -46,7 +45,6 @@ "aiohttp_basicauth", "coverage", "mypy", - "psycopg[pool]>=3.2.0", "pre-commit", "redis>=4.2,<6.0", "ruff", diff --git a/tests/helpers.py b/tests/helpers.py index b10dd7c..366b3ee 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,6 @@ from saq.queue import Queue from saq.queue.postgres import PostgresQueue -from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue as AsyncpgPostgresQueue from saq.queue.redis import RedisQueue POSTGRES_TEST_SCHEMA = "test_saq" @@ -39,41 +38,11 @@ async def cleanup_queue(queue: Queue) -> None: async def setup_postgres() -> None: - async with await psycopg.AsyncConnection.connect( - "postgres://postgres@localhost", autocommit=True - ) as conn: - await conn.execute(f"DROP SCHEMA IF EXISTS {POSTGRES_TEST_SCHEMA} CASCADE") - await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {POSTGRES_TEST_SCHEMA}") - - -async def teardown_postgres() -> None: - async with await psycopg.AsyncConnection.connect( - "postgres://postgres@localhost", autocommit=True - ) as conn: - await conn.execute(f"DROP SCHEMA {POSTGRES_TEST_SCHEMA} CASCADE") - - - - -async def create_postgres_asyncpg_queue(**kwargs: t.Any) -> AsyncpgPostgresQueue: - queue = t.cast( - AsyncpgPostgresQueue, - Queue.from_url( - f"postgres+asyncpg://postgres@localhost?options=--search_path%3D{POSTGRES_TEST_SCHEMA}", - **kwargs, - ), - ) - await queue.connect() - await asyncio.sleep(0.1) # Give some time for the tasks to start - return queue - -async def setup_postgres_asyncpg() -> None: - async with asyncpg.create_pool( - "postgres://postgres@localhost", min_size=1, max_size=10, command_timeout=60 - ) as pool: + async with asyncpg.create_pool("postgres://postgres@localhost") as pool: + await pool.execute(f"DROP SCHEMA IF EXISTS {POSTGRES_TEST_SCHEMA} CASCADE") await pool.execute(f"CREATE SCHEMA IF NOT EXISTS {POSTGRES_TEST_SCHEMA}") -async def teardown_postgres_asyncpg() -> None: +async def teardown_postgres() -> None: async with asyncpg.create_pool("postgres://postgres@localhost") as pool: await pool.execute(f"DROP SCHEMA {POSTGRES_TEST_SCHEMA} CASCADE") diff --git a/tests/test_queue.py b/tests/test_queue.py index 35dbe6a..1364f27 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -15,20 +15,16 @@ from saq.utils 
import uuid1 from saq.worker import Worker from tests.helpers import ( - cleanup_queue, - create_postgres_asyncpg_queue, + cleanup_queue, create_postgres_queue, create_redis_queue, - setup_postgres, - setup_postgres_asyncpg, - teardown_postgres, - teardown_postgres_asyncpg, + setup_postgres, + teardown_postgres, ) if t.TYPE_CHECKING: from unittest.mock import MagicMock - from saq.queue.postgres_asyncpg import PostgresAsyncpgQueue from saq.queue.postgres import PostgresQueue from saq.queue.redis import RedisQueue from saq.types import Context, CountKind, Function @@ -581,328 +577,13 @@ async def test_sweep_jobs(self) -> None: await job2.refresh() self.assertEqual(job2.status, Status.COMPLETE) - async def test_sweep_stats(self) -> None: - # Stats are deleted - await self.queue.stats(ttl=1) - await asyncio.sleep(1.5) - await self.queue.sweep() - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT stats - FROM {} - WHERE worker_id = %s - """ - ).format(self.queue.stats_table), - (self.queue.uuid,), - ) - self.assertIsNone(await cursor.fetchone()) - - # Stats are not deleted - await self.queue.stats(ttl=60) - await asyncio.sleep(1) - await self.queue.sweep() - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT stats - FROM {} - WHERE worker_id = %s - """ - ).format(self.queue.stats_table), - (self.queue.uuid,), - ) - self.assertIsNotNone(await cursor.fetchone()) - - async def test_job_lock(self) -> None: - query = SQL( - """ - SELECT count(*) - FROM {} JOIN pg_locks ON lock_key = objid - WHERE key = %(key)s - AND classid = {} - AND objsubid = 2 -- key is int pair, not single bigint - """ - ).format(self.queue.jobs_table, self.queue.job_lock_keyspace) - job = await self.enqueue("test") - await self.dequeue() - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute(query, {"key": job.key}) - self.assertEqual(await cursor.fetchone(), (1,)) - - await self.finish(job, Status.COMPLETE, result=1) - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute(query, {"key": job.key}) - self.assertEqual(await cursor.fetchone(), (0,)) - - async def test_load_dump_pickle(self) -> None: - self.queue = await self.create_queue(dump=pickle.dumps, load=pickle.loads) - job = await self.enqueue("test") - - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT job - FROM {} - WHERE key = %s - """ - ).format(self.queue.jobs_table), - (job.key,), - ) - result = await cursor.fetchone() - assert result - fetched_job = pickle.loads(result[0]) - self.assertIsInstance(fetched_job, dict) - self.assertEqual(fetched_job["key"], job.key) - - dequeued_job = await self.dequeue() - self.assertEqual(dequeued_job, job) - - @mock.patch("saq.utils.time") - async def test_finish_ttl_positive(self, mock_time: MagicMock) -> None: - mock_time.time.return_value = 0 - job = await self.enqueue("test", ttl=5) - await self.dequeue() - await self.finish(job, Status.COMPLETE) - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT expire_at - FROM {} - WHERE key = %s - """ - ).format(self.queue.jobs_table), - (job.key,), - ) - result = await cursor.fetchone() - self.assertEqual(result, (5,)) - - @mock.patch("saq.utils.time") - async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None: - 
mock_time.time.return_value = 0 - job = await self.enqueue("test", ttl=0) - await self.dequeue() - await self.finish(job, Status.COMPLETE) - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT expire_at - FROM {} - WHERE key = %s - """ - ).format(self.queue.jobs_table), - (job.key,), - ) - result = await cursor.fetchone() - self.assertEqual(result, (None,)) - - @mock.patch("saq.utils.time") - async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: - mock_time.time.return_value = 0 - job = await self.enqueue("test", ttl=-1) - await self.dequeue() - await self.finish(job, Status.COMPLETE) - async with self.queue.pool.connection() as conn, conn.cursor() as cursor: - await cursor.execute( - SQL( - """ - SELECT expire_at - FROM {} - WHERE key = %s - """ - ).format(self.queue.jobs_table), - (job.key,), - ) - result = await cursor.fetchone() - self.assertIsNone(result) - - async def test_cron_job_close_to_target(self) -> None: - await self.enqueue("test", scheduled=time.time() + 0.5) - job = await self.queue.dequeue(timeout=0.1) - assert not job - - async def test_bad_connection(self) -> None: - job = await self.enqueue("test") - - async with self.queue._get_dequeue_conn() as original_connection: - pass - - await original_connection.close() - # Test dequeue still works - self.assertEqual((await self.dequeue()), job) - # Check queue has a new connection - self.assertNotEqual(original_connection, self.queue._dequeue_conn) - - await self.queue.pool.putconn(original_connection) - - async def test_group_key(self) -> None: - job1 = await self.enqueue("test", group_key=1) - assert job1 - job2 = await self.enqueue("test", group_key=1) - assert job2 - self.assertEqual(await self.count("queued"), 2) - - assert await self.dequeue() - self.assertEqual(await self.count("queued"), 1) - assert not await self.queue.dequeue(0.01) - await job1.update(status="finished") - assert await self.dequeue() - - async def test_priority(self) -> None: - assert await self.enqueue("test", priority=-1) - self.assertEqual(await self.count("queued"), 1) - assert not await self.queue.dequeue(0.01) - -class TestAsyncpgPostgresQueue(TestQueue): - async def asyncSetUp(self) -> None: - await setup_postgres_asyncpg() - self.create_queue = create_postgres_asyncpg_queue - self.queue: PostgresAsyncpgQueue = await self.create_queue() - - async def asyncTearDown(self) -> None: - await super().asyncTearDown() - await teardown_postgres_asyncpg() - - @unittest.skip("Not implemented") - async def test_job_key(self) -> None: - pass - - @unittest.skip("Not implemented") - @mock.patch("saq.utils.time") - async def test_schedule(self, mock_time: MagicMock) -> None: - pass - - async def test_enqueue_dup(self) -> None: - job = await self.enqueue("test", key="1") - self.assertEqual(job.id, "1") - self.assertIsNone(await self.queue.enqueue("test", key="1")) - self.assertIsNone(await self.queue.enqueue(job)) - - async def test_abort(self) -> None: - job = await self.enqueue("test", retries=2) - self.assertEqual(await self.count("queued"), 1) - self.assertEqual(await self.count("incomplete"), 1) - await self.queue.abort(job, "test") - self.assertEqual(await self.count("queued"), 0) - self.assertEqual(await self.count("incomplete"), 0) - await job.refresh() - self.assertEqual(job.status, Status.ABORTED) - self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTED) - - job = await self.enqueue("test", retries=2) - await self.dequeue() - self.assertEqual(await 
self.count("queued"), 0) - self.assertEqual(await self.count("incomplete"), 1) - self.assertEqual(await self.count("active"), 1) - await self.queue.abort(job, "test") - self.assertEqual(await self.count("queued"), 0) - self.assertEqual(await self.count("incomplete"), 0) - self.assertEqual(await self.count("active"), 0) - self.assertEqual(await self.queue.get_job_status(job.key), Status.ABORTING) - - @mock.patch("saq.utils.time") - async def test_sweep(self, mock_time: MagicMock) -> None: - mock_time.time.return_value = 1 - job1 = await self.enqueue("test", heartbeat=1, retries=0) - job2 = await self.enqueue("test", timeout=1) - await self.enqueue("test", timeout=2) - await self.enqueue("test", heartbeat=2) - job3 = await self.enqueue("test", timeout=1) - for _ in range(4): - job = await self.dequeue() - job.status = Status.ACTIVE - job.started = 1000 - await self.queue.update(job) - await self.dequeue() - - mock_time.time.return_value = 3 - self.assertEqual(await self.count("active"), 5) - swept = await self.queue.sweep(abort=0.01) - self.assertEqual( - set(swept), - { - job1.key, - job2.key, - }, - ) - await job1.refresh() - await job2.refresh() - await job3.refresh() - self.assertEqual(job1.status, Status.ABORTED) - self.assertEqual(job2.status, Status.QUEUED) - self.assertEqual(job3.status, Status.QUEUED) - self.assertEqual(await self.count("active"), 3) - - @mock.patch("saq.utils.time") - async def test_sweep_stuck(self, mock_time: MagicMock) -> None: - job1 = await self.queue.enqueue("test") - assert job1 - job = await self.dequeue() - job.status = Status.ACTIVE - job.started = 1000 - await self.queue.update(job) - - # Enqueue 2 more jobs that will become stuck - job2 = await self.queue.enqueue("test", retries=0) - assert job2 - job3 = await self.queue.enqueue("test") - assert job3 - - another_queue = await self.create_queue() - for _ in range(2): - job = await another_queue.dequeue() - job.status = Status.ACTIVE - job.started = 1000 - await another_queue.update(job) - - # Disconnect another_queue to simulate worker going down - await another_queue.disconnect() - - mock_time.time.return_value = 3 - self.assertEqual(await self.count("active"), 3) - swept = await self.queue.sweep(abort=0.01) - self.assertEqual( - set(swept), - { - job2.id, - job3.id, - }, - ) - await job1.refresh() - await job2.refresh() - await job3.refresh() - self.assertEqual(job1.status, Status.ACTIVE) - self.assertEqual(job2.status, Status.ABORTED) - self.assertEqual(job3.status, Status.QUEUED) - self.assertEqual(await self.count("active"), 1) - - async def test_sweep_jobs(self) -> None: - job1 = await self.enqueue("test", ttl=1) - job2 = await self.enqueue("test", ttl=60) - await self.queue.finish(job1, Status.COMPLETE) - await self.queue.finish(job2, Status.COMPLETE) - await asyncio.sleep(1.5) - - await self.queue.sweep() - with self.assertRaisesRegex(RuntimeError, "doesn't exist"): - await job1.refresh() - await job2.refresh() - self.assertEqual(job2.status, Status.COMPLETE) - async def test_sweep_stats(self) -> None: # Stats are deleted await self.queue.stats(ttl=1) await asyncio.sleep(1.5) await self.queue.sweep() async with self.queue.pool.acquire() as conn, conn.transaction(): - cursor = await conn.cursor( - - """ + cursor = await conn.cursor(""" SELECT stats FROM {} WHERE worker_id = $1 @@ -914,7 +595,7 @@ async def test_sweep_stats(self) -> None: # Stats are not deleted await self.queue.stats(ttl=60) await asyncio.sleep(1) - # await self.queue.sweep() + await self.queue.sweep() async with 
self.queue.pool.acquire() as conn, conn.transaction(): cursor = await conn.cursor( @@ -1019,14 +700,9 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: ) self.assertIsNone(result) - @mock.patch("saq.utils.time") - async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None: - mock_time.time.return_value = 1000.5 - await self.enqueue("test", scheduled=1001) - - # The job is scheduled to run at 1001, but we're running at 1000.5 - # so it should not be picked up - job = await self.queue.dequeue(timeout=1) + async def test_cron_job_close_to_target(self) -> None: + await self.enqueue("test", scheduled=time.time() + 0.5) + job = await self.queue.dequeue(timeout=0.1) assert not job async def test_bad_connection(self) -> None: @@ -1038,7 +714,6 @@ async def test_bad_connection(self) -> None: self.assertEqual((await self.dequeue()), job) # Check queue has a new connection self.assertNotEqual(original_connection,self.queue._dequeue_conn) - async def test_group_key(self) -> None: job1 = await self.enqueue("test", group_key=1) @@ -1057,4 +732,3 @@ async def test_priority(self) -> None: assert await self.enqueue("test", priority=-1) self.assertEqual(await self.count("queued"), 1) assert not await self.queue.dequeue(0.01) - From 42b5552b16d7b9ea85d41459793a0cf1581904f7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 19:55:29 +0000 Subject: [PATCH 11/13] fix: remove formatting change --- saq/queue/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/saq/queue/base.py b/saq/queue/base.py index e5edb09..33948d5 100644 --- a/saq/queue/base.py +++ b/saq/queue/base.py @@ -156,10 +156,12 @@ def from_url(url: str, **kwargs: t.Any) -> Queue: from saq.queue.redis import RedisQueue return RedisQueue.from_url(url, **kwargs) + if url.startswith("postgres"): from saq.queue.postgres import PostgresQueue return PostgresQueue.from_url(url, **kwargs) + from saq.queue.http import HttpQueue return HttpQueue.from_url(url, **kwargs) From 31f2fa3716f7fa73aa41d8ea74171d92fa6e115b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 19:57:46 +0000 Subject: [PATCH 12/13] fix: linting/formatting --- saq/queue/postgres.py | 47 ++++++++++++++++++------------------------ setup.py | 2 +- tests/helpers.py | 1 - tests/test_queue.py | 48 +++++++++++++++++++++---------------------- 4 files changed, 44 insertions(+), 54 deletions(-) diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 6189332..deeac05 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -80,7 +80,7 @@ def from_url( # pyright: ignore[reportIncompatibleMethodOverride] min_size: int = 4, max_size: int = 20, **kwargs: t.Any, - ) -> PostgresQueue: + ) -> PostgresQueue: """Create a queue from a postgres url. Args: @@ -90,8 +90,8 @@ def from_url( # pyright: ignore[reportIncompatibleMethodOverride] max_size: maximum pool size. (default 20) If greater than 0, this limits the maximum number of connections to Postgres. Otherwise, maintain `min_size` number of connections. 
- - """ + + """ return cls(create_pool(dsn=url, min_size=min_size, max_size=max_size), **kwargs) def __init__( @@ -111,7 +111,7 @@ def __init__( self.jobs_table = jobs_table self.stats_table = stats_table - self.pool = pool + self.pool = pool self.poll_interval = poll_interval self.saq_lock_keyspace = saq_lock_keyspace self.job_lock_keyspace = job_lock_keyspace @@ -132,15 +132,14 @@ def __init__( async def with_connection( self, connection: PoolConnectionProxy | None = None ) -> t.AsyncGenerator[PoolConnectionProxy]: - async with self.nullcontext( - connection - ) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] + async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] yield conn async def init_db(self) -> None: async with self.with_connection() as conn, conn.transaction(): cursor = await conn.cursor( - "SELECT pg_try_advisory_lock($1, 0)", self.saq_lock_keyspace, + "SELECT pg_try_advisory_lock($1, 0)", + self.saq_lock_keyspace, ) result = await cursor.fetchrow() @@ -148,10 +147,8 @@ async def init_db(self) -> None: return for statement in DDL_STATEMENTS: await conn.execute( - statement.format( - jobs_table=self.jobs_table, stats_table=self.stats_table - ) - ) + statement.format(jobs_table=self.jobs_table, stats_table=self.stats_table) + ) async def connect(self) -> None: if self._dequeue_conn: @@ -225,7 +222,7 @@ async def count(self, kind: CountKind) -> int: AND queue = $1 AND NOW() >= TO_TIMESTAMP(scheduled) """), - self.name, + self.name, ) elif kind == "active": result = await conn.fetchval( @@ -259,7 +256,7 @@ async def schedule(self, lock: int = 1) -> t.List[str]: async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: """Delete jobs and stats past their expiration and sweep stuck jobs""" swept = [] - + if not self._has_sweep_lock: # Attempt to get the sweep lock and hold on to it async with self._get_dequeue_conn() as conn: @@ -273,7 +270,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: return [] self._has_sweep_lock = True - async with self.with_connection() as conn, conn.transaction(): + async with self.with_connection() as conn, conn.transaction(): await conn.execute( dedent(f""" -- Delete expired jobs @@ -282,14 +279,14 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: AND status IN ('aborted', 'complete', 'failed') AND NOW() >= TO_TIMESTAMP(expire_at) """), - self.name, + self.name, ) await conn.execute( dedent(f""" -- Delete expired stats DELETE FROM {self.stats_table} WHERE NOW() >= TO_TIMESTAMP(expire_at); - """), + """), ) results = await conn.fetch( dedent( @@ -354,9 +351,7 @@ async def listen( if stop: break - async def notify( - self, job: Job, connection: PoolConnectionProxy | None = None - ) -> None: + async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) -> None: await self._notify(job.key, job.status, connection) async def update( @@ -400,15 +395,13 @@ async def update( async def job(self, job_key: str) -> Job | None: async with self.with_connection() as conn, conn.transaction(): - cursor = await conn.cursor( - f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key - ) + cursor = await conn.cursor(f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key) record = await cursor.fetchrow() return self.deserialize(record.get("job")) if record else None async def jobs(self, job_keys: Iterable[str]) -> t.List[Job | None]: keys = list(job_keys) - results: dict[str, bytes | None] = 
{} + results: dict[str, bytes | None] = {} async with self.with_connection() as conn, conn.transaction(): async for record in conn.cursor( f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys @@ -607,7 +600,7 @@ async def get_job_status( {"FOR UPDATE" if for_update else ""} """), key, - ) + ) assert result return result @@ -643,7 +636,7 @@ async def _finish( """), key, ) - await self.notify(job, conn) + await self.notify(job, conn) await self._release_job(key) async def _notify( @@ -732,4 +725,4 @@ async def _notify_callback( payload: t.Any, ) -> None: payload_data = json.loads(payload) - self.publish(payload_data["key"], payload_data) \ No newline at end of file + self.publish(payload_data["key"], payload_data) diff --git a/setup.py b/setup.py index f8d88d6..67ae305 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ extras_require={ "hiredis": ["redis[hiredis]>=4.2.0"], "http": ["aiohttp"], - "postgres": ["asyncpg"], + "postgres": ["asyncpg"], "redis": ["redis>=4.2,<6.0"], "web": ["aiohttp", "aiohttp_basicauth"], "dev": [ diff --git a/tests/helpers.py b/tests/helpers.py index 366b3ee..02efa08 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,7 +2,6 @@ import typing as t import asyncpg -import psycopg from saq.queue import Queue from saq.queue.postgres import PostgresQueue diff --git a/tests/test_queue.py b/tests/test_queue.py index 1364f27..014597a 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -8,18 +8,17 @@ import unittest from unittest import mock -from psycopg.sql import SQL from saq.job import Job, Status from saq.queue import JobError, Queue from saq.utils import uuid1 from saq.worker import Worker from tests.helpers import ( - cleanup_queue, + cleanup_queue, create_postgres_queue, create_redis_queue, - setup_postgres, - teardown_postgres, + setup_postgres, + teardown_postgres, ) @@ -583,12 +582,13 @@ async def test_sweep_stats(self) -> None: await asyncio.sleep(1.5) await self.queue.sweep() async with self.queue.pool.acquire() as conn, conn.transaction(): - cursor = await conn.cursor(""" + cursor = await conn.cursor( + """ SELECT stats FROM {} WHERE worker_id = $1 """.format(self.queue.stats_table), - self.queue.uuid + self.queue.uuid, ) self.assertIsNone(await cursor.fetchrow()) @@ -597,19 +597,18 @@ async def test_sweep_stats(self) -> None: await asyncio.sleep(1) await self.queue.sweep() async with self.queue.pool.acquire() as conn, conn.transaction(): - cursor = await conn.cursor( - - """ + cursor = await conn.cursor( + """ SELECT stats FROM {} WHERE worker_id = $1 """.format(self.queue.stats_table), - self.queue.uuid + self.queue.uuid, ) self.assertIsNotNone(await cursor.fetchrow()) async def test_job_lock(self) -> None: - query = """ + query = """ SELECT count(*) FROM {} JOIN pg_locks ON lock_key = objid WHERE key = $1 @@ -632,11 +631,12 @@ async def test_load_dump_pickle(self) -> None: job = await self.enqueue("test") async with self.queue.pool.acquire() as conn, conn.transaction(): - result = await conn.fetchrow(""" + result = await conn.fetchrow( + """ SELECT job FROM {} WHERE key =$1 - """ .format(self.queue.jobs_table), + """.format(self.queue.jobs_table), job.key, ) assert result @@ -655,15 +655,14 @@ async def test_finish_ttl_positive(self, mock_time: MagicMock) -> None: await self.finish(job, Status.COMPLETE) async with self.queue.pool.acquire() as conn: result = await conn.fetchval( - - """ + """ SELECT expire_at FROM {} WHERE key = $1 """.format(self.queue.jobs_table), job.key, ) - self.assertEqual(result,5) + 
self.assertEqual(result, 5) @mock.patch("saq.utils.time") async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None: @@ -671,17 +670,16 @@ async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None: job = await self.enqueue("test", ttl=0) await self.dequeue() await self.finish(job, Status.COMPLETE) - async with self.queue.pool.acquire() as conn : + async with self.queue.pool.acquire() as conn: result = await conn.fetchval( - - """ + """ SELECT expire_at FROM {} WHERE key = $1 - """ .format(self.queue.jobs_table), + """.format(self.queue.jobs_table), job.key, ) - self.assertEqual(result,None) + self.assertEqual(result, None) @mock.patch("saq.utils.time") async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: @@ -689,13 +687,13 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None: job = await self.enqueue("test", ttl=-1) await self.dequeue() await self.finish(job, Status.COMPLETE) - async with self.queue.pool.acquire() as conn : + async with self.queue.pool.acquire() as conn: result = await conn.fetchval( - """ + """ SELECT expire_at FROM {} WHERE key = $1 - """ .format(self.queue.jobs_table), + """.format(self.queue.jobs_table), job.key, ) self.assertIsNone(result) @@ -713,7 +711,7 @@ async def test_bad_connection(self) -> None: # Test dequeue still works self.assertEqual((await self.dequeue()), job) # Check queue has a new connection - self.assertNotEqual(original_connection,self.queue._dequeue_conn) + self.assertNotEqual(original_connection, self.queue._dequeue_conn) async def test_group_key(self) -> None: job1 = await self.enqueue("test", group_key=1) From 9f9c73683faf0a492a5e70428e4126b9409234fe Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 13 Oct 2024 20:27:21 +0000 Subject: [PATCH 13/13] feat: additional typing changes --- saq/queue/postgres.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index deeac05..a73bcfb 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -11,7 +11,6 @@ import typing as t from contextlib import asynccontextmanager from textwrap import dedent - from saq.errors import MissingDependencyError from saq.job import ( Job, @@ -23,8 +22,7 @@ from saq.utils import now, seconds if t.TYPE_CHECKING: - from collections.abc import Iterable - + from collections.abc import AsyncGenerator, Iterable from saq.types import ( CountKind, ListenCallback, @@ -131,7 +129,7 @@ def __init__( @asynccontextmanager async def with_connection( self, connection: PoolConnectionProxy | None = None - ) -> t.AsyncGenerator[PoolConnectionProxy]: + ) -> AsyncGenerator[PoolConnectionProxy]: async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined] yield conn @@ -216,7 +214,7 @@ async def count(self, kind: CountKind) -> int: if kind == "queued": result = await conn.fetchval( dedent(f""" - SELECT count(*) + SELECT count(*) FROM {self.jobs_table} WHERE status = 'queued' AND queue = $1 @@ -227,7 +225,7 @@ async def count(self, kind: CountKind) -> int: elif kind == "active": result = await conn.fetchval( dedent(f""" - SELECT count(*) + SELECT count(*) FROM {self.jobs_table} WHERE status = 'active' AND queue = $1 @@ -237,7 +235,7 @@ async def count(self, kind: CountKind) -> int: elif kind == "incomplete": result = await conn.fetchval( dedent(f""" - SELECT count(*) + SELECT count(*) FROM {self.jobs_table} WHERE status IN ('new', 'deferred', 'queued', 'active') AND queue = $1 @@ 
-259,7 +257,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]: if not self._has_sweep_lock: # Attempt to get the sweep lock and hold on to it - async with self._get_dequeue_conn() as conn: + async with self._get_dequeue_conn() as conn, conn.transaction(): result = await conn.fetchval( dedent("SELECT pg_try_advisory_lock($1, hashtext($2))"), self.saq_lock_keyspace, @@ -654,7 +652,7 @@ async def _notify( await conn.execute(f"NOTIFY \"{self._channel}\", '{json.dumps(payload)}'") @asynccontextmanager - async def _get_dequeue_conn(self) -> t.AsyncGenerator: + async def _get_dequeue_conn(self) -> AsyncGenerator: async with self._connection_lock: if self._dequeue_conn: try: @@ -670,7 +668,7 @@ async def _get_dequeue_conn(self) -> t.AsyncGenerator: yield self._dequeue_conn @asynccontextmanager - async def nullcontext(self, enter_result: ContextT) -> t.AsyncGenerator[ContextT]: + async def nullcontext(self, enter_result: ContextT) -> AsyncGenerator[ContextT]: """Async version of contextlib.nullcontext Async support has been added to contextlib.nullcontext in Python 3.10.
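For reference, a minimal stand-alone version of this backport pattern (a sketch only; the method body itself lies outside this hunk):

    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def nullcontext(enter_result=None):
        # No setup or teardown; simply hand the provided value back to the caller.
        yield enter_result

This is what lets `with_connection` treat both cases uniformly: wrap an explicitly passed connection in a no-op context manager, or fall back to `self.pool.acquire()` when no connection is supplied.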