feat!: add priorities and groups to postgres #176

Merged Oct 13, 2024 (1 commit)
Empty file added benchmarks/__init__.py
4 changes: 3 additions & 1 deletion benchmarks/simple.py
@@ -2,7 +2,7 @@
import sys
import time

-from benchmarks.funcs import *
+from funcs import *


SEM = asyncio.Semaphore(20)
@@ -68,6 +68,8 @@ async def enqueue(func):
while await queue.count("incomplete"):
await asyncio.sleep(0.1)
print(f"SAQ process {N} sleep {time.time() - now}")
+await worker.stop()
+await queue.disconnect()


def bench_rq():
4 changes: 3 additions & 1 deletion examples/simple.py
@@ -19,7 +19,10 @@ async def cron_job(ctx):
print("executing cron job")


+queue = Queue.from_url("postgres://postgres@localhost")

settings = {
"queue": queue,
"functions": [sleeper, adder],
"concurrency": 100,
"cron_jobs": [CronJob(cron_job, cron="* * * * * */5")],
@@ -33,7 +36,6 @@ async def cron_job(ctx):


async def enqueue(func, **kwargs):
-queue = Queue.from_url("redis://localhost")
for _ in range(10000):
await queue.enqueue(func, **{k: v() for k, v in kwargs.items()})

18 changes: 11 additions & 7 deletions saq/job.py
@@ -101,14 +101,16 @@ class Job:
Don't set these, but you can read them.

Parameters:
-attempts (int): number of attempts a job has had
-completed (int): job completion time epoch seconds
-queued (int): job enqueued time epoch seconds
-started (int): job started time epoch seconds
-touched (int): job touched/updated time epoch seconds
+attempts: number of attempts a job has had
+completed: job completion time epoch seconds
+queued: job enqueued time epoch seconds
+started: job started time epoch seconds
+touched: job touched/updated time epoch seconds
result: payload containing the results, this is the return of the function provided, must be serializable, defaults to json
-error (str | None): stack trace if a runtime error occurs
-status (Status): Status Enum, default to Status.New
+error: stack trace if a runtime error occurs
+status: Status Enum, default to Status.New
+priority: The priority of a job, only available in postgres.
+group_key: Only one job per group can be active at any time, only available in postgres.
"""

function: str
@@ -131,6 +133,8 @@ class Job:
result: t.Any = None
error: str | None = None
status: Status = Status.NEW
+priority: int = 0
+group_key: str | None = None
meta: dict[t.Any, t.Any] = dataclasses.field(default_factory=dict)

_EXCLUDE_NON_FULL = {
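For orientation, here is a minimal usage sketch of the two new fields (not part of the PR; the "send_email" function name and the connection URL are illustrative, and enqueue is assumed to forward Job fields as keyword arguments, as the new tests below do):

import asyncio

from saq import Queue


async def main() -> None:
    queue = Queue.from_url("postgres://postgres@localhost")
    await queue.connect()
    # Lower values dequeue first (the new SQL orders by priority, then
    # scheduled), so this job runs after any default-priority (0) jobs.
    await queue.enqueue("send_email", priority=1)
    # Jobs sharing a group_key never run concurrently: the second stays
    # queued until the first leaves the 'active' status.
    await queue.enqueue("send_email", group_key="user-42")
    await queue.enqueue("send_email", group_key="user-42")
    await queue.disconnect()


asyncio.run(main())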
66 changes: 47 additions & 19 deletions saq/queue/postgres.py
@@ -72,6 +72,7 @@ class PostgresQueue(Queue):
saq_lock_keyspace: The first of two advisory lock keys used by SAQ. (default 0)
SAQ uses advisory locks for coordinating tasks between its workers, e.g. sweeping.
job_lock_keyspace: The first of two advisory lock keys used for jobs. (default 1)
+priorities: The priority range to dequeue (default (0, 32767))
"""

@classmethod
@@ -95,6 +96,7 @@ def __init__(
poll_interval: int = 1,
saq_lock_keyspace: int = 0,
job_lock_keyspace: int = 1,
+priorities: tuple[int, int] = (0, 32767),
) -> None:
super().__init__(name=name, dump=dump, load=load)

@@ -106,6 +108,7 @@ def __init__(
self.poll_interval = poll_interval
self.saq_lock_keyspace = saq_lock_keyspace
self.job_lock_keyspace = job_lock_keyspace
+self._priorities = priorities

self._job_queue: asyncio.Queue = asyncio.Queue()
self._waiting = 0 # Internal counter of worker tasks waiting for dequeue
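The priorities window makes it possible to split one queue across heterogeneous workers. A sketch under assumptions (the URL and bounds are illustrative, and from_url is assumed to forward extra keyword arguments to __init__):

from saq.queue.postgres import PostgresQueue

# One worker drains only urgent jobs; a second takes everything else.
# Jobs outside a worker's window are simply invisible to its dequeue.
urgent = PostgresQueue.from_url("postgres://postgres@localhost", priorities=(0, 10))
bulk = PostgresQueue.from_url("postgres://postgres@localhost", priorities=(11, 32767))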
@@ -165,11 +168,10 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
dedent(
"""
SELECT worker_id, stats FROM {stats_table}
-WHERE %(now)s <= expire_at
+WHERE NOW() <= TO_TIMESTAMP(expire_at)
"""
)
).format(stats_table=self.stats_table),
{"now": seconds(now())},
)
results = await cursor.fetchall()
workers: dict[str, dict[str, t.Any]] = dict(results)
@@ -212,14 +214,15 @@ async def count(self, kind: CountKind) -> int:
SQL(
dedent(
"""
-SELECT count(*) FROM {jobs_table}
+SELECT count(*)
+FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
-AND %(now)s >= scheduled
+AND NOW() >= TO_TIMESTAMP(scheduled)
"""
)
).format(jobs_table=self.jobs_table),
{"queue": self.name, "now": seconds(now())},
{"queue": self.name},
)
elif kind == "active":
await cursor.execute(
@@ -287,7 +290,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
DELETE FROM {jobs_table}
WHERE queue = %(queue)s
AND status IN ('aborted', 'complete', 'failed')
-AND %(now)s >= expire_at;
+AND NOW() >= TO_TIMESTAMP(expire_at);
"""
)
).format(
@@ -296,7 +299,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
),
{
"queue": self.name,
"now": seconds(now()),
},
)

@@ -306,16 +308,13 @@
"""
-- Delete expired stats
DELETE FROM {stats_table}
-WHERE %(now)s >= expire_at;
+WHERE NOW() >= TO_TIMESTAMP(expire_at);
"""
)
).format(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
),
-{
-"now": seconds(now()),
-},
)

await cursor.execute(
@@ -571,8 +570,16 @@ async def _dequeue(self) -> None:
FROM {jobs_table}
WHERE status = 'queued'
AND queue = %(queue)s
-AND %(now)s >= scheduled
-ORDER BY scheduled
+AND NOW() >= TO_TIMESTAMP(scheduled)
+AND priority BETWEEN %(plow)s AND %(phigh)s
+AND group_key NOT IN (
+SELECT DISTINCT group_key
+FROM {jobs_table}
+WHERE status = 'active'
+AND queue = %(queue)s
+AND group_key IS NOT NULL
+)
+ORDER BY priority, scheduled
LIMIT %(limit)s
FOR UPDATE SKIP LOCKED
)
@@ -589,8 +596,9 @@
),
{
"queue": self.name,
"now": seconds(now()),
"limit": self._waiting,
"plow": self._priorities[0],
"phigh": self._priorities[1],
},
)
results = await cursor.fetchall()
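Read together, the rewritten SELECT does three things: it restricts candidates to the worker's priority window, it skips any group that already has an active job, and it orders by priority before scheduled time, with FOR UPDATE SKIP LOCKED letting concurrent workers claim disjoint rows. A condensed standalone paraphrase follows (plain SQL, with literal values in place of the query parameters and the default table name assumed); note it uses NOT EXISTS, which, unlike NOT IN against a non-empty set, also admits queued rows whose group_key IS NULL:

SELECT key
FROM saq_jobs AS queued_job
WHERE status = 'queued'
  AND queue = 'default'
  AND NOW() >= TO_TIMESTAMP(scheduled)
  AND priority BETWEEN 0 AND 32767
  -- group exclusion: no job from the same group may currently be active
  AND NOT EXISTS (
      SELECT 1
      FROM saq_jobs AS active_job
      WHERE active_job.status = 'active'
        AND active_job.queue = queued_job.queue
        AND active_job.group_key = queued_job.group_key
  )
ORDER BY priority, scheduled
LIMIT 1
FOR UPDATE SKIP LOCKED;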
@@ -607,13 +615,31 @@ async def _enqueue(self, job: Job) -> Job | None:
SQL(
dedent(
"""
-INSERT INTO {jobs_table} (key, job, queue, status, scheduled)
-VALUES (%(key)s, %(job)s, %(queue)s, %(status)s, %(scheduled)s)
+INSERT INTO {jobs_table} (
+key,
+job,
+queue,
+status,
+priority,
+group_key,
+scheduled
+)
+VALUES (
+%(key)s,
+%(job)s,
+%(queue)s,
+%(status)s,
+%(priority)s,
+%(group_key)s,
+%(scheduled)s
+)
ON CONFLICT (key) DO UPDATE
SET
job = %(job)s,
queue = %(queue)s,
status = %(status)s,
+priority = %(priority)s,
+group_key = %(group_key)s,
scheduled = %(scheduled)s,
expire_at = null
WHERE
@@ -628,6 +654,8 @@ async def _enqueue(self, job: Job) -> Job | None:
"job": self.serialize(job),
"queue": self.name,
"status": job.status,
"priority": job.priority,
"group_key": job.group_key,
"scheduled": job.scheduled or int(seconds(now())),
},
)
@@ -645,16 +673,16 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None:
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
-VALUES (%(worker_id)s, %(stats)s, %(expire_at)s)
+VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
-SET stats = %(stats)s, expire_at = %(expire_at)s
+SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": self.uuid,
"stats": json.dumps(stats),
"expire_at": seconds(now()) + ttl,
"ttl": ttl,
},
)
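A pattern worth noting across this file: every comparison that previously interpolated a Python-side %(now)s now asks Postgres for the time (NOW(), or EXTRACT(EPOCH FROM NOW()) when writing epoch-second columns), so expiry and sweeping decisions no longer depend on each worker's local clock. A condensed illustration of the upsert above (plain SQL; the saq_stats table name and the literal values are assumptions):

INSERT INTO saq_stats (worker_id, stats, expire_at)
VALUES ('worker-1', '{}', EXTRACT(EPOCH FROM NOW()) + 60)
ON CONFLICT (worker_id) DO UPDATE
SET stats = EXCLUDED.stats,
    expire_at = EXTRACT(EPOCH FROM NOW()) + 60;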

5 changes: 3 additions & 2 deletions saq/queue/postgres_ddl.py
@@ -2,11 +2,12 @@
CREATE TABLE IF NOT EXISTS {jobs_table} (
key TEXT PRIMARY KEY,
lock_key SERIAL NOT NULL,
-queued BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM now()),
job BYTEA NOT NULL,
queue TEXT NOT NULL,
status TEXT NOT NULL,
-scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM now()),
+priority SMALLINT NOT NULL DEFAULT 0,
+group_key TEXT,
+scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
expire_at BIGINT
);
"""
29 changes: 21 additions & 8 deletions tests/test_queue.py
@@ -718,14 +718,9 @@ async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None:
result = await cursor.fetchone()
self.assertIsNone(result)

@mock.patch("saq.utils.time")
async def test_cron_job_close_to_target(self, mock_time: MagicMock) -> None:
mock_time.time.return_value = 1000.5
await self.enqueue("test", scheduled=1001)

# The job is scheduled to run at 1001, but we're running at 1000.5
# so it should not be picked up
job = await self.queue.dequeue(timeout=1)
async def test_cron_job_close_to_target(self) -> None:
await self.enqueue("test", scheduled=time.time() + 0.5)
job = await self.queue.dequeue(timeout=0.1)
assert not job

async def test_bad_connection(self) -> None:
@@ -741,3 +736,21 @@ async def test_bad_connection(self) -> None:
self.assertNotEqual(original_connection, self.queue._dequeue_conn)

await self.queue.pool.putconn(original_connection)

+async def test_group_key(self) -> None:
+job1 = await self.enqueue("test", group_key=1)
+assert job1
+job2 = await self.enqueue("test", group_key=1)
+assert job2
+self.assertEqual(await self.count("queued"), 2)
+
+assert await self.dequeue()
+self.assertEqual(await self.count("queued"), 1)
+assert not await self.queue.dequeue(0.01)
+await job1.update(status="finished")
+assert await self.dequeue()
+
+async def test_priority(self) -> None:
+assert await self.enqueue("test", priority=-1)
+self.assertEqual(await self.count("queued"), 1)
+assert not await self.queue.dequeue(0.01)
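test_priority passes because of the default dequeue window, (0, 32767): a job enqueued with priority=-1 falls outside it, so it counts as queued but is never picked up. A worker meant to treat negative priorities as more urgent would widen its window (illustrative values):

from saq.queue.postgres import PostgresQueue

# -32768..32767 is the full range of the new SMALLINT priority column.
queue = PostgresQueue.from_url("postgres://postgres@localhost", priorities=(-32768, 32767))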
8 changes: 3 additions & 5 deletions tests/test_worker.py
@@ -3,7 +3,6 @@
import asyncio
import contextvars
import logging
-import time
import typing as t
import unittest
from unittest import mock
@@ -522,8 +521,7 @@ async def test_schedule(self, mock_time: MagicMock) -> None:
self.skipTest("Not implemented")

@mock.patch("saq.worker.logger")
@mock.patch("saq.utils.time")
async def test_cron(self, mock_time: MagicMock, mock_logger: MagicMock) -> None:
async def test_cron(self, mock_logger: MagicMock) -> None:
with self.assertRaises(ValueError):
Worker(
self.queue,
@@ -534,15 +532,15 @@ async def test_cron(self, mock_time: MagicMock, mock_logger: MagicMock) -> None:
worker = Worker(
self.queue,
functions=functions,
cron_jobs=[CronJob(cron, cron="* * * * *")],
cron_jobs=[CronJob(cron, cron="* * * * * *")],
)
self.assertEqual(await self.queue.count("queued"), 0)
self.assertEqual(await self.queue.count("incomplete"), 0)
await worker.schedule()
self.assertEqual(await self.queue.count("queued"), 0)
self.assertEqual(await self.queue.count("incomplete"), 1)
+await asyncio.sleep(1)

-mock_time.time.return_value = time.time() + 60
self.assertEqual(await self.queue.count("queued"), 1)
self.assertEqual(await self.queue.count("incomplete"), 1)
