
Commit

chore: move to a singular table
benfdking committed Jan 21, 2025
1 parent 8e22b2b commit 855b913
Showing 10 changed files with 93 additions and 120 deletions.
29 changes: 19 additions & 10 deletions saq/queue/base.py
@@ -131,24 +131,27 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_stats_and_metadata(
self,
worker_id: str,
queue_key: str,
metadata: t.Optional[dict],
stats: WorkerStats,
ttl: int,
) -> None:
"""
Returns & updates stats on the queue.
Write stats and metadata for a worker.
Args:
worker_id: The worker id, passed in rather than taken from the queue instance to ensure that the stats
are attributed to the worker and not the queue instance in the proxy server.
queue_key: The key of the queue.
metadata: The metadata to write.
stats: The stats to write.
ttl: The time-to-live in seconds.
"""
pass

@abstractmethod
async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
pass

@abstractmethod
async def _retry(self, job: Job, error: str | None) -> None:
pass
@@ -200,13 +203,17 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
async def stats_and_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict] = None, ttl: int = 60
) -> WorkerStats:
"""
Method to be used by workers to update stats.
Args:
worker_id: The worker id.
ttl: Time stats are valid for in seconds.
queue_key: The key of the queue.
metadata: The metadata to write.
Returns: The stats.
"""
@@ -217,7 +224,9 @@ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
"aborted": self.aborted,
"uptime": now() - self.started,
}
await self.write_stats(worker_id, stats, ttl)
await self.write_stats_and_metadata(
worker_id, queue_key=queue_key, metadata=metadata, stats=stats, ttl=ttl
)
return stats

def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
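The net effect in the base class: the separate write_stats and write_worker_metadata abstract methods collapse into a single write_stats_and_metadata, and the worker-facing helper stats() becomes stats_and_metadata(). A minimal sketch of how a worker might call the new helper on a timer; the worker attributes used here (id, metadata) and the loop wiring are assumptions for illustration, not part of this commit.

import asyncio

async def stats_and_metadata_loop(worker, interval: float = 10.0) -> None:
    # Push this worker's stats and metadata in a single write, repeatedly.
    # worker.id / worker.queue / worker.metadata are assumed attribute names.
    while True:
        await worker.queue.stats_and_metadata(
            worker_id=worker.id,
            queue_key=worker.queue.name,   # assumed: the queue's name serves as its key
            metadata=worker.metadata,      # optional dict, may be None
            ttl=60,                        # entry expires after 60 seconds
        )
        await asyncio.sleep(interval)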
32 changes: 15 additions & 17 deletions saq/queue/http.py
@@ -7,7 +7,6 @@
import json
import typing as t


from saq.errors import MissingDependencyError
from saq.job import Job, Status
from saq.queue.base import Queue
@@ -102,17 +101,13 @@ async def process(self, body: str) -> str | None:
jobs=req["jobs"], offset=req["offset"], limit=req["limit"]
)
)
if kind == "write_stats":
await self.queue.write_stats(
worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
)
return None
if kind == "write_worker_metadata":
await self.queue.write_worker_metadata(
metadata=req["metadata"],
if kind == "write_stats_and_metadata":
await self.queue.write_stats_and_metadata(
worker_id=req["worker_id"],
stats=req["stats"],
ttl=req["ttl"],
queue_key=req["queue_key"],
worker_id=req["worker_id"],
metadata=req["metadata"],
)
return None
raise ValueError(f"Invalid request {body}")
@@ -216,18 +211,21 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
async def write_stats_and_metadata(
self,
worker_id: str,
queue_key: str,
metadata: t.Optional[dict],
stats: WorkerStats,
ttl: int,
) -> None:
await self._send(
"write_worker_metadata",
"write_stats_and_metadata",
worker_id=worker_id,
metadata=metadata,
stats=stats,
ttl=ttl,
queue_key=queue_key,
metadata=metadata,
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
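For the HTTP proxy, the two request kinds write_stats and write_worker_metadata fold into a single write_stats_and_metadata kind. A rough sketch of the payload the client-side _send call would carry, assuming the envelope is a JSON object holding the kind plus the keyword arguments (the exact wire format is not shown in this diff); the concrete values are placeholders.

import json

request_body = json.dumps(
    {
        "kind": "write_stats_and_metadata",
        "worker_id": "worker-1",                      # placeholder worker id
        "queue_key": "default",                       # placeholder queue key
        "metadata": {"host": "worker-1.internal"},    # placeholder metadata
        "stats": {"complete": 10, "failed": 0, "retried": 1, "aborted": 0, "uptime": 1234},
        "ttl": 60,
    }
)

# HttpProxy.process(body) matches on req["kind"] and forwards the remaining
# fields to queue.write_stats_and_metadata(...), returning None on success.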
45 changes: 17 additions & 28 deletions saq/queue/postgres.py
@@ -51,7 +51,6 @@
JOBS_TABLE = "saq_jobs"
STATS_TABLE = "saq_stats"
VERSIONS_TABLE = "saq_versions"
METADATA_TABLE = "saq_worker_metadata"


class PostgresQueue(Queue):
@@ -95,7 +94,6 @@ def __init__(
versions_table: str = VERSIONS_TABLE,
jobs_table: str = JOBS_TABLE,
stats_table: str = STATS_TABLE,
metadata_table: str = METADATA_TABLE,
dump: DumpType | None = None,
load: LoadType | None = None,
min_size: int = 4,
@@ -110,7 +108,6 @@ def __init__(
self.versions_table = Identifier(versions_table)
self.jobs_table = Identifier(jobs_table)
self.stats_table = Identifier(stats_table)
self.metadata_table = Identifier(metadata_table)
self.pool = pool
self.min_size = min_size
self.max_size = max_size
@@ -152,7 +149,6 @@ async def init_db(self) -> None:
migrations = get_migrations(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
worker_metadata_table=self.metadata_table,
)
target_version = migrations[-1][0]
await cursor.execute(
@@ -231,14 +227,14 @@ async def write_worker_metadata(
SQL(
dedent(
"""
INSERT INTO {worker_metadata_table} (worker_id, queue_key, metadata, expire_at)
INSERT INTO {stats_table} (worker_id, queue_key, metadata, expire_at)
VALUES
(%(worker_id)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET metadata = %(metadata)s, queue_key = %(queue_key)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(worker_metadata_table=self.metadata_table),
).format(stats_table=self.stats_table),
{
"worker_id": worker_id,
"metadata": json.dumps(metadata),
@@ -253,13 +249,12 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
SQL(
dedent(
"""
SELECT stats.worker_id, stats.stats, meta.queue_key, meta.metadata
SELECT worker_id, stats, queue_key, metadata
FROM {stats_table} stats
LEFT JOIN {worker_metadata_table} meta ON meta.worker_id = stats.worker_id
WHERE NOW() <= TO_TIMESTAMP(stats.expire_at)
"""
)
).format(stats_table=self.stats_table, worker_metadata_table=self.metadata_table),
).format(stats_table=self.stats_table),
)
results = await cursor.fetchall()
workers: dict[str, WorkerInfo] = {
@@ -413,21 +408,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
"""
-- Delete expired worker metadata
DELETE FROM {worker_metadata_table}
WHERE %(now)s >= expire_at;
"""
)
).format(
worker_metadata_table=self.metadata_table,
),
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
@@ -778,23 +758,32 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_stats_and_metadata(
self,
worker_id: str,
queue_key: str,
metadata: t.Optional[dict],
stats: WorkerStats,
ttl: int,
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
INSERT INTO {stats_table} (worker_id, stats, queue_key, metadata, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": worker_id,
"stats": json.dumps(stats),
"ttl": ttl,
"queue_key": queue_key,
"metadata": json.dumps(metadata),
},
)

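In the Postgres backend the saq_worker_metadata table goes away: queue_key and metadata move into the stats table, so one upsert keyed on worker_id covers both writes. The statement below is the same upsert with the {stats_table} identifier resolved to the default saq_stats name, written out purely as a sketch; the real code composes it with psycopg's SQL()/Identifier() as in the diff, and the parameter values are placeholders.

UPSERT_STATS_AND_METADATA = """
INSERT INTO saq_stats (worker_id, stats, queue_key, metadata, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s,
    expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""

params = {
    "worker_id": "worker-1",                        # placeholder worker id
    "stats": '{"complete": 10, "failed": 0}',       # json.dumps(stats)
    "queue_key": "default",                         # placeholder queue key
    "metadata": '{"host": "worker-1.internal"}',    # json.dumps(metadata)
    "ttl": 60,
}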
11 changes: 3 additions & 8 deletions saq/queue/postgres_migrations.py
@@ -7,7 +7,6 @@
def get_migrations(
jobs_table: Identifier,
stats_table: Identifier,
worker_metadata_table: Identifier,
) -> t.List[t.Tuple[int, t.List[Composed]]]:
return [
(
@@ -49,14 +48,10 @@ def get_migrations(
[
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {worker_metadata_table} (
worker_id TEXT PRIMARY KEY,
queue_key TEXT NOT NULL,
expire_at BIGINT NOT NULL,
metadata JSONB
);
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS metadata JSONB;
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS queue_key TEXT;
""")
).format(worker_metadata_table=worker_metadata_table),
).format(stats_table=stats_table),
],
),
]
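The migration for this change becomes two ALTER TABLE ... ADD COLUMN IF NOT EXISTS statements on the stats table instead of a CREATE TABLE for a separate metadata table. A simplified sketch of applying the (version, statements) pairs returned by get_migrations(); the real init_db() in the diff tracks the current version in a versions table, which this sketch glosses over with a plain argument.

from psycopg.sql import Identifier

from saq.queue.postgres_migrations import get_migrations


async def apply_migrations(cursor, current_version: int = 0) -> int:
    # Run every migration newer than current_version, in order, and return
    # the target version (migrations[-1][0], as init_db() computes it).
    migrations = get_migrations(
        jobs_table=Identifier("saq_jobs"),
        stats_table=Identifier("saq_stats"),
    )
    for version, statements in migrations:
        if version <= current_version:
            continue
        for statement in statements:
            await cursor.execute(statement)
    return migrations[-1][0]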
47 changes: 18 additions & 29 deletions saq/queue/redis.py
@@ -131,30 +131,15 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
*_, worker_uuid = key_str.split(":")
worker_uuids.append(worker_uuid)

worker_stats = await self.redis.mget(
self.namespace(f"stats:{worker_uuid}") for worker_uuid in worker_uuids
)
worker_metadata = await self.redis.mget(
self.namespace(f"metadata:{worker_uuid}") for worker_uuid in worker_uuids
self.namespace(f"metadata_and_stats:{worker_uuid}") for worker_uuid in worker_uuids
)
workers: dict[str, WorkerInfo] = {}
worker_metadata_dict = dict(zip(worker_uuids, worker_metadata))
worker_stats_dict = dict(zip(worker_uuids, worker_stats))
for worker in worker_uuids:
workers[worker] = {
"queue_key": None,
"stats": None,
"metadata": None,
}
metadata = worker_metadata_dict[worker]
metadata = worker_metadata_dict.get(worker)
if metadata:
metadata_obj = json.loads(metadata)
workers[worker]["metadata"] = metadata_obj["metadata"]
workers[worker]["queue_key"] = metadata_obj["queue_key"]
stats = worker_stats_dict[worker]
if stats:
stats_obj = json.loads(stats)
workers[worker]["stats"] = stats_obj
workers[worker] = json.loads(metadata)

queued = await self.count("queued")
active = await self.count("active")
@@ -405,26 +390,30 @@ async def finish_abort(self, job: Job) -> None:
await self.redis.delete(job.abort_id)
await super().finish_abort(job)

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_stats_and_metadata(
self,
worker_id: str,
queue_key: str,
metadata: t.Optional[dict],
stats: WorkerStats,
ttl: int,
) -> None:
current = now()
data: WorkerInfo = {
"queue_key": queue_key,
"metadata": metadata,
"stats": stats,
}
async with self.redis.pipeline(transaction=True) as pipe:
key = self.namespace(f"stats:{worker_id}")
key = self.namespace(f"metadata_and_stats:{worker_id}")
await (
pipe.setex(key, ttl, json.dumps(stats))
pipe.setex(key, ttl, json.dumps(data))
.zremrangebyscore(self._stats, 0, current)
.zadd(self._stats, {key: current + millis(ttl)})
.expire(self._stats, ttl)
.execute()
)

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
async with self.redis.pipeline(transaction=True) as pipe:
key = self.namespace(f"metadata:{worker_id}")
metadata = {"queue_key": queue_key, "metadata": metadata}
await pipe.setex(key, ttl, json.dumps(metadata)).expire(key, ttl).execute()

async def _retry(self, job: Job, error: str | None) -> None:
job_id = job.id
next_retry_delay = job.next_retry_delay()
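In the Redis backend the two per-worker keys stats:{worker_id} and metadata:{worker_id} are replaced by a single metadata_and_stats:{worker_id} key whose value is a JSON document in the WorkerInfo shape. A sketch of what ends up stored; the "saq" namespace prefix and the concrete field values are placeholders, while the {queue_key, metadata, stats} shape comes from the diff.

import json

key = "saq:metadata_and_stats:worker-1"   # self.namespace(f"metadata_and_stats:{worker_id}")

value = json.dumps(
    {
        "queue_key": "default",                       # placeholder queue key
        "metadata": {"host": "worker-1.internal"},    # worker-supplied metadata, may be None
        "stats": {"complete": 10, "failed": 0, "retried": 1, "aborted": 0, "uptime": 1234},
    }
)

# Because the stored document already matches WorkerInfo, info() can simply do
# workers[worker] = json.loads(value) instead of stitching two keys together.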
6 changes: 2 additions & 4 deletions saq/types.py
@@ -97,14 +97,12 @@ class TimersDict(t.TypedDict):

schedule: int
"How often we poll to schedule jobs in seconds (default 1)"
stats: int
"How often to update stats in seconds (default 10)"
stats_and_metadata: int
"How often to update stats and metadata in seconds (default 10)"
sweep: int
"How often to clean up stuck jobs in seconds (default 60)"
abort: int
"How often to check if a job is aborted in seconds (default 1)"
metadata: int
"How often to write worker metadata in seconds (default 60)"


class PartialTimersDict(TimersDict, total=False):
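In the timers dict, the separate metadata interval disappears and stats is renamed stats_and_metadata. A small sketch of overriding the interval through PartialTimersDict; passing timers= to Worker is an assumption not shown in this diff.

from saq.types import PartialTimersDict

timers: PartialTimersDict = {
    "stats_and_metadata": 30,  # write stats + metadata every 30s instead of the default 10
    "sweep": 60,
}
# e.g. Worker(queue, functions=[...], timers=timers)  # assumed keyword, not shown here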