Skip to content

Commit

Permalink
Secure LISTEN and NOTIFY queries
Browse files Browse the repository at this point in the history
  • Loading branch information
vchan committed Aug 23, 2024
1 parent 77aae0a commit 8ceb15f
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)

try:
from psycopg import AsyncConnection
from psycopg import AsyncConnection, Notify
from psycopg.sql import Identifier, SQL
from psycopg.types import json
from psycopg_pool import AsyncConnectionPool
Expand Down Expand Up @@ -332,25 +332,24 @@ async def listen(
if not job_keys:
return

async def _listen() -> None:
async with self.pool.connection() as conn:
for key in job_keys:
await conn.execute(f'LISTEN "{key}"')
await conn.commit()
gen = conn.notifies()
async for notify in gen:
payload = self._load(notify.payload)
key = payload["key"]
status = Status[payload["status"].upper()]
if asyncio.iscoroutinefunction(callback):
stop = await callback(key, status)
else:
stop = callback(key, status)

if stop:
await gen.aclose()

await asyncio.wait_for(_listen(), timeout or None)
async def _listen(gen: t.AsyncGenerator[Notify]) -> None:
async for notify in gen:
payload = self._load(notify.payload)
key = payload["key"]
status = Status[payload["status"].upper()]
if asyncio.iscoroutinefunction(callback):
stop = await callback(key, status)
else:
stop = callback(key, status)

if stop:
await gen.aclose()

async with self.pool.connection() as conn:
for key in job_keys:
await conn.execute(SQL("LISTEN {}").format(Identifier(key)))
await conn.commit()
await asyncio.wait_for(_listen(conn.notifies()), timeout or None)

async def notify(self, job: Job, connection: AsyncConnection | None = None) -> None:
payload = self._dump({"key": job.key, "status": job.status})
Expand All @@ -362,7 +361,11 @@ async def _notify(
async with (
self.nullcontext(connection) if connection else self.pool.connection()
) as conn:
await conn.execute(f"NOTIFY \"{channel}\", '{payload}'")
await conn.execute(
SQL("NOTIFY {channel}, {payload}").format(
channel=Identifier(channel), payload=payload
)
)

async def update(
self,
Expand Down Expand Up @@ -463,11 +466,10 @@ async def finish(
).format(jobs_table=self.jobs_table),
{"key": key},
)
await self.notify(job, conn)
await self._release_job(key)

self._update_stats(status)

await self.notify(job, conn)
logger.info("Finished %s", job.info(logger.isEnabledFor(logging.DEBUG)))

async def dequeue(self, timeout: float = 0) -> Job | None:
Expand Down Expand Up @@ -560,7 +562,7 @@ async def dequeue_timer(self, poll_interval: int) -> None:
async def listen_for_enqueues(self, timeout: float | None = None) -> None:
"""Wakes up a single dequeue task when a Postgres enqueue notification is received."""
async with self.pool.connection() as conn:
await conn.execute(f'LISTEN "{ENQUEUE_CHANNEL}"')
await conn.execute(SQL("LISTEN {}").format(Identifier(ENQUEUE_CHANNEL)))
await conn.commit()
gen = conn.notifies(timeout=timeout)
async for _ in gen:
Expand Down

0 comments on commit 8ceb15f

Please sign in to comment.