diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 94cf982..7513ad4 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -33,7 +33,7 @@ ) try: - from psycopg import AsyncConnection + from psycopg import AsyncConnection, Notify from psycopg.sql import Identifier, SQL from psycopg.types import json from psycopg_pool import AsyncConnectionPool @@ -332,25 +332,24 @@ async def listen( if not job_keys: return - async def _listen() -> None: - async with self.pool.connection() as conn: - for key in job_keys: - await conn.execute(f'LISTEN "{key}"') - await conn.commit() - gen = conn.notifies() - async for notify in gen: - payload = self._load(notify.payload) - key = payload["key"] - status = Status[payload["status"].upper()] - if asyncio.iscoroutinefunction(callback): - stop = await callback(key, status) - else: - stop = callback(key, status) - - if stop: - await gen.aclose() - - await asyncio.wait_for(_listen(), timeout or None) + async def _listen(gen: t.AsyncGenerator[Notify]) -> None: + async for notify in gen: + payload = self._load(notify.payload) + key = payload["key"] + status = Status[payload["status"].upper()] + if asyncio.iscoroutinefunction(callback): + stop = await callback(key, status) + else: + stop = callback(key, status) + + if stop: + await gen.aclose() + + async with self.pool.connection() as conn: + for key in job_keys: + await conn.execute(SQL("LISTEN {}").format(Identifier(key))) + await conn.commit() + await asyncio.wait_for(_listen(conn.notifies()), timeout or None) async def notify(self, job: Job, connection: AsyncConnection | None = None) -> None: payload = self._dump({"key": job.key, "status": job.status}) @@ -362,7 +361,11 @@ async def _notify( async with ( self.nullcontext(connection) if connection else self.pool.connection() ) as conn: - await conn.execute(f"NOTIFY \"{channel}\", '{payload}'") + await conn.execute( + SQL("NOTIFY {channel}, {payload}").format( + channel=Identifier(channel), payload=payload + ) + ) async def update( self, @@ -463,11 +466,10 @@ async def finish( ).format(jobs_table=self.jobs_table), {"key": key}, ) + await self.notify(job, conn) await self._release_job(key) self._update_stats(status) - - await self.notify(job, conn) logger.info("Finished %s", job.info(logger.isEnabledFor(logging.DEBUG))) async def dequeue(self, timeout: float = 0) -> Job | None: @@ -560,7 +562,7 @@ async def dequeue_timer(self, poll_interval: int) -> None: async def listen_for_enqueues(self, timeout: float | None = None) -> None: """Wakes up a single dequeue task when a Postgres enqueue notification is received.""" async with self.pool.connection() as conn: - await conn.execute(f'LISTEN "{ENQUEUE_CHANNEL}"') + await conn.execute(SQL("LISTEN {}").format(Identifier(ENQUEUE_CHANNEL))) await conn.commit() gen = conn.notifies(timeout=timeout) async for _ in gen: