Commit f38422b

Merge pull request #171 from tobymao/toby/fix_solo_cron
fix: concurrency 1 needs to queue cron jobs
2 parents: b0f7a61 + 99d23c5

3 files changed: +81 -57 lines
saq/queue/postgres.py (+61 -56)
@@ -46,6 +46,7 @@

 CHANNEL = "saq:{}"
 ENQUEUE = "saq:enqueue"
+DEQUEUE = "saq:dequeue"
 JOBS_TABLE = "saq_jobs"
 STATS_TABLE = "saq_stats"

@@ -126,14 +127,11 @@ async def init_db(self) -> None:
            )

     async def connect(self) -> None:
-        if self._dequeue_conn:
-            # If connection exists, connect() was already called
+        if self.pool._opened:
             return

         await self.pool.open()
         await self.pool.resize(min_size=self.min_size, max_size=self.max_size)
-        # Reserve a connection for dequeue and advisory locks
-        self._dequeue_conn = await self.pool.getconn()
         await self.init_db()

     def serialize(self, job: Job) -> bytes | str:
@@ -531,8 +529,9 @@ async def dequeue(self, timeout: float = 0) -> Job | None:
            )
         else:
             async with self._listen_lock:
-                async for _ in self._listener.listen(ENQUEUE, timeout=timeout):
-                    await self._dequeue()
+                async for payload in self._listener.listen(ENQUEUE, DEQUEUE, timeout=timeout):
+                    if payload["key"] == ENQUEUE:
+                        await self._dequeue()

         if not self._job_queue.empty():
             job = self._job_queue.get_nowait()
@@ -547,6 +546,53 @@ async def dequeue(self, timeout: float = 0) -> Job | None:

         return job

+    async def _dequeue(self) -> None:
+        if self._dequeue_lock.locked():
+            return
+
+        async with self._dequeue_lock:
+            async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction():
+                if not self._waiting:
+                    return
+                await cursor.execute(
+                    SQL(
+                        dedent(
+                            """
+                            WITH locked_job AS (
+                              SELECT key, lock_key
+                              FROM {jobs_table}
+                              WHERE status = 'queued'
+                                AND queue = %(queue)s
+                                AND %(now)s >= scheduled
+                              ORDER BY scheduled
+                              LIMIT %(limit)s
+                              FOR UPDATE SKIP LOCKED
+                            )
+                            UPDATE {jobs_table} SET status = 'active'
+                            FROM locked_job
+                            WHERE {jobs_table}.key = locked_job.key
+                              AND pg_try_advisory_lock({job_lock_keyspace}, locked_job.lock_key)
+                            RETURNING job
+                            """
+                        )
+                    ).format(
+                        jobs_table=self.jobs_table,
+                        job_lock_keyspace=self.job_lock_keyspace,
+                    ),
+                    {
+                        "queue": self.name,
+                        "now": math.ceil(seconds(now())),
+                        "limit": self._waiting,
+                    },
+                )
+                results = await cursor.fetchall()
+
+            for result in results:
+                self._job_queue.put_nowait(self.deserialize(result[0]))
+
+            if results:
+                await self._notify(DEQUEUE)
+
     async def _enqueue(self, job: Job) -> Job | None:
         async with self.pool.connection() as conn, conn.cursor() as cursor:
             await cursor.execute(
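
The two hunks above carry the fix: _dequeue now drains ready jobs into the worker's local queue and then broadcasts on the new DEQUEUE channel, while dequeue listens on both ENQUEUE and DEQUEUE but only re-queries the database for ENQUEUE notifications, so a DEQUEUE notification simply wakes a waiting worker to re-check its local queue. Below is a minimal, self-contained sketch of that notify/dispatch pattern; the Listener class, helper names, and timeout handling are simplified stand-ins invented for illustration, not saq's actual implementation.

# Minimal sketch (assumed names, not saq's real classes) of the notify/dispatch
# pattern: an ENQUEUE notification triggers a database fetch, while a DEQUEUE
# notification merely wakes listeners so they re-check their local queue.
from __future__ import annotations

import asyncio

ENQUEUE = "saq:enqueue"
DEQUEUE = "saq:dequeue"


class Listener:
    """Stand-in for the Postgres LISTEN/NOTIFY listener."""

    def __init__(self) -> None:
        self.events: asyncio.Queue[dict] = asyncio.Queue()

    def notify(self, key: str) -> None:
        self.events.put_nowait({"key": key})

    async def listen(self, *keys: str, timeout: float):
        # Yield matching payloads until no event arrives within `timeout`.
        while True:
            try:
                payload = await asyncio.wait_for(self.events.get(), timeout)
            except asyncio.TimeoutError:
                return
            if payload["key"] in keys:
                yield payload


async def main() -> None:
    listener = Listener()
    local_jobs: asyncio.Queue[str] = asyncio.Queue()

    async def dequeue() -> str | None:
        async for payload in listener.listen(ENQUEUE, DEQUEUE, timeout=0.5):
            if payload["key"] == ENQUEUE:
                local_jobs.put_nowait("job-1")  # pretend we fetched a row from Postgres
                listener.notify(DEQUEUE)        # wake any other waiting workers
            if not local_jobs.empty():
                return local_jobs.get_nowait()
        return None

    waiter = asyncio.create_task(dequeue())
    listener.notify(ENQUEUE)  # an enqueue notification arrives while we wait
    print(await waiter)       # -> job-1


asyncio.run(main())
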
@@ -676,49 +722,6 @@ async def _finish(
                await self.notify(job, conn)
         await self._release_job(key)

-    async def _dequeue(self) -> None:
-        if self._dequeue_lock.locked():
-            return
-
-        async with self._dequeue_lock:
-            async with self._get_dequeue_conn() as conn, conn.cursor() as cursor, conn.transaction():
-                if not self._waiting:
-                    return
-                await cursor.execute(
-                    SQL(
-                        dedent(
-                            """
-                            WITH locked_job AS (
-                              SELECT key, lock_key
-                              FROM {jobs_table}
-                              WHERE status = 'queued'
-                                AND queue = %(queue)s
-                                AND %(now)s >= scheduled
-                              ORDER BY scheduled
-                              LIMIT %(limit)s
-                              FOR UPDATE SKIP LOCKED
-                            )
-                            UPDATE {jobs_table} SET status = 'active'
-                            FROM locked_job
-                            WHERE {jobs_table}.key = locked_job.key
-                              AND pg_try_advisory_lock({job_lock_keyspace}, locked_job.lock_key)
-                            RETURNING job
-                            """
-                        )
-                    ).format(
-                        jobs_table=self.jobs_table,
-                        job_lock_keyspace=self.job_lock_keyspace,
-                    ),
-                    {
-                        "queue": self.name,
-                        "now": math.ceil(seconds(now())),
-                        "limit": self._waiting,
-                    },
-                )
-                results = await cursor.fetchall()
-                for result in results:
-                    self._job_queue.put_nowait(self.deserialize(result[0]))
-
     async def _notify(
         self, key: str, data: t.Any | None = None, connection: AsyncConnection | None = None
     ) -> None:
@@ -736,14 +739,16 @@ async def _notify(

     @asynccontextmanager
     async def _get_dequeue_conn(self) -> t.AsyncGenerator:
-        assert self._dequeue_conn
         async with self._connection_lock:
-            try:
-                # Pool normally performs this check when getting a connection.
-                await self.pool.check_connection(self._dequeue_conn)
-            except OperationalError:
-                # The connection is bad so return it to the pool and get a new one.
-                await self.pool.putconn(self._dequeue_conn)
+            if self._dequeue_conn:
+                try:
+                    # Pool normally performs this check when getting a connection.
+                    await self.pool.check_connection(self._dequeue_conn)
+                except OperationalError:
+                    # The connection is bad so return it to the pool and get a new one.
+                    await self.pool.putconn(self._dequeue_conn)
+                    self._dequeue_conn = await self.pool.getconn()
+            else:
                 self._dequeue_conn = await self.pool.getconn()
             yield self._dequeue_conn
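
The hunk above makes the reserved dequeue connection lazy: instead of asserting that connect() already grabbed one, _get_dequeue_conn now takes a connection from the pool on first use, health-checks it on reuse, and swaps in a fresh one when the old connection has gone bad, all under the connection lock. A schematic sketch of that pattern, using invented stand-in classes rather than psycopg_pool's real pool:

# Schematic sketch (stand-in classes, not psycopg_pool) of the lazy
# reserved-connection pattern: acquire on first use, validate on reuse,
# replace when broken -- which is what test_bad_connection below simulates.
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager


class BrokenConnection(Exception):
    pass


class FakeConn:
    def __init__(self, name: str) -> None:
        self.name = name
        self.closed = False


class FakePool:
    def __init__(self) -> None:
        self._count = 0

    async def getconn(self) -> FakeConn:
        self._count += 1
        return FakeConn(f"conn-{self._count}")

    async def putconn(self, conn: FakeConn) -> None:
        pass  # a real pool would recycle or discard the connection here

    async def check_connection(self, conn: FakeConn) -> None:
        if conn.closed:
            raise BrokenConnection(conn.name)


class Dequeuer:
    def __init__(self, pool: FakePool) -> None:
        self.pool = pool
        self._conn: FakeConn | None = None
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def reserved_conn(self):
        async with self._lock:
            if self._conn:
                try:
                    await self.pool.check_connection(self._conn)
                except BrokenConnection:
                    await self.pool.putconn(self._conn)
                    self._conn = await self.pool.getconn()
            else:
                self._conn = await self.pool.getconn()
            yield self._conn


async def main() -> None:
    dequeuer = Dequeuer(FakePool())
    async with dequeuer.reserved_conn() as conn:
        print(conn.name)  # conn-1, acquired lazily on first use
    conn.closed = True    # simulate the connection dropping
    async with dequeuer.reserved_conn() as conn:
        print(conn.name)  # conn-2, the bad connection was replaced


asyncio.run(main())
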

tests/test_queue.py (+4 -1)
@@ -720,7 +720,10 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None:

     async def test_bad_connection(self) -> None:
         job = await self.enqueue("test")
-        original_connection = self.queue._dequeue_conn
+
+        async with self.queue._get_dequeue_conn() as original_connection:
+            pass
+
         await original_connection.close()
         # Test dequeue still works
         self.assertEqual((await self.dequeue()), job)

tests/test_worker.py (+16)
@@ -482,6 +482,22 @@ async def handler(_ctx: Context) -> None:
         await asyncio.sleep(6)
         self.assertEqual(state["counter"], 0)

+    async def test_cron_solo_worker(self) -> None:
+        state = {"counter": 0}
+
+        async def handler(_ctx: Context) -> None:
+            state["counter"] += 1
+
+        self.worker = Worker(
+            self.queue,
+            functions=[],
+            cron_jobs=[CronJob(handler, cron="* * * * * */1")],
+            concurrency=1,
+        )
+        asyncio.create_task(self.worker.start())
+        await asyncio.sleep(2)
+        self.assertGreater(state["counter"], 0)
+

 class TestWorkerRedisQueue(TestWorker):
     async def asyncSetUp(self) -> None:
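
The new test above is the regression check for the commit message: a worker with concurrency=1 whose only work is a cron job must still queue and run it. For reference, a hedged user-facing sketch of the same setup; the connection URL and handler are placeholders, and the imports assume saq's top-level exports.

# Hedged usage sketch: a single-concurrency worker whose only work is a cron job.
# The DSN and handler body are placeholders, not part of the diff.
import asyncio

from saq import CronJob, Queue, Worker


async def tick(ctx):
    print("cron tick")


queue = Queue.from_url("postgres://localhost/saq_demo")  # placeholder DSN

worker = Worker(
    queue,
    functions=[],
    cron_jobs=[CronJob(tick, cron="* * * * * */1")],  # same schedule as the test above
    concurrency=1,  # the case this commit fixes
)

if __name__ == "__main__":
    asyncio.run(worker.start())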
