chore: move to a singular table
benfdking committed Jan 21, 2025
1 parent 8e22b2b commit c01c907
Showing 10 changed files with 111 additions and 161 deletions.
41 changes: 28 additions & 13 deletions saq/queue/base.py
@@ -31,6 +31,7 @@
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)


@@ -131,24 +132,25 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
"""
Returns & updates stats on the queue.
Write stats and metadata for a worker.
Args:
worker_id: The worker id, passed in rather than taken from the queue instance to ensure that the stats
are attributed to the worker and not the queue instance in the proxy server.
queue_key: The key of the queue.
metadata: The metadata to write.
stats: The stats to write.
ttl: The time-to-live in seconds.
"""
pass

@abstractmethod
async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
pass

@abstractmethod
async def _retry(self, job: Job, error: str | None) -> None:
pass
@@ -200,15 +202,19 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
async def worker_info(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict] = None, ttl: int = 60
) -> WorkerInfo:
"""
Method to be used by workers to update stats.
Method to be used by workers to update worker info.
Args:
worker_id: The worker id.
ttl: Time stats are valid for in seconds.
queue_key: The key of the queue.
metadata: The metadata to write.
Returns: The stats.
Returns: Worker info.
"""
stats: WorkerStats = {
"complete": self.complete,
@@ -217,8 +223,17 @@ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
"aborted": self.aborted,
"uptime": now() - self.started,
}
await self.write_stats(worker_id, stats, ttl)
return stats
info: WorkerInfo = {
"stats": stats,
"queue_key": queue_key,
"metadata": metadata,
}
await self.write_worker_info(
worker_id,
info,
ttl=ttl,
)
return info

def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
self._before_enqueues[id(callback)] = callback
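
Below, a minimal usage sketch of the consolidated call from a worker's perspective. The refresh loop, worker id, and metadata values are illustrative assumptions, not part of this commit, and it assumes the queue key is simply the queue's name:

import asyncio

async def publish_worker_info(queue, worker_id: str) -> None:
    # Hypothetical periodic refresh: re-publish stats and metadata
    # before the default 60-second TTL expires.
    while True:
        await queue.worker_info(
            worker_id,
            queue_key=queue.name,           # assumption: queue key == queue name
            metadata={"host": "worker-1"},  # illustrative metadata
            ttl=60,
        )
        await asyncio.sleep(30)
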
38 changes: 18 additions & 20 deletions saq/queue/http.py
@@ -7,7 +7,6 @@
import json
import typing as t


from saq.errors import MissingDependencyError
from saq.job import Job, Status
from saq.queue.base import Queue
@@ -18,7 +17,7 @@
from saq.types import (
CountKind,
QueueInfo,
WorkerStats,
WorkerInfo,
)

try:
@@ -102,17 +101,15 @@ async def process(self, body: str) -> str | None:
jobs=req["jobs"], offset=req["offset"], limit=req["limit"]
)
)
if kind == "write_stats":
await self.queue.write_stats(
worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
)
return None
if kind == "write_worker_metadata":
await self.queue.write_worker_metadata(
metadata=req["metadata"],
ttl=req["ttl"],
queue_key=req["queue_key"],
if kind == "write_worker_info":
await self.queue.write_worker_info(
worker_id=req["worker_id"],
info={
"stats": req["stats"],
"queue_key": req["queue_key"],
"metadata": req["metadata"],
},
ttl=req["ttl"],
)
return None
raise ValueError(f"Invalid request {body}")
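
For illustration, a sketch of the payload the branch above consumes; the key names are read off the req[...] accesses, while the "kind" dispatch key and all values are assumptions:

# Hypothetical wire payload for the consolidated kind: the client flattens
# WorkerInfo into top-level keys, and process() reassembles them into one
# info dict before calling write_worker_info().
request = {
    "kind": "write_worker_info",
    "worker_id": "worker-1",
    "stats": {"complete": 10, "aborted": 0, "uptime": 12345},
    "queue_key": "default",
    "metadata": {"host": "worker-1"},
    "ttl": 60,
}
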
@@ -216,18 +213,19 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
await self._send(
"write_worker_metadata",
"write_worker_info",
worker_id=worker_id,
metadata=metadata,
stats=info["stats"],
ttl=ttl,
queue_key=queue_key,
queue_key=info["queue_key"],
metadata=info["metadata"],
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
74 changes: 19 additions & 55 deletions saq/queue/postgres.py
@@ -32,7 +32,6 @@
DumpType,
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)

@@ -51,7 +50,6 @@
JOBS_TABLE = "saq_jobs"
STATS_TABLE = "saq_stats"
VERSIONS_TABLE = "saq_versions"
METADATA_TABLE = "saq_worker_metadata"


class PostgresQueue(Queue):
@@ -95,7 +93,6 @@ def __init__(
versions_table: str = VERSIONS_TABLE,
jobs_table: str = JOBS_TABLE,
stats_table: str = STATS_TABLE,
metadata_table: str = METADATA_TABLE,
dump: DumpType | None = None,
load: LoadType | None = None,
min_size: int = 4,
@@ -110,7 +107,6 @@ def __init__(
self.versions_table = Identifier(versions_table)
self.jobs_table = Identifier(jobs_table)
self.stats_table = Identifier(stats_table)
self.metadata_table = Identifier(metadata_table)
self.pool = pool
self.min_size = min_size
self.max_size = max_size
@@ -152,7 +148,6 @@ async def init_db(self) -> None:
migrations = get_migrations(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
worker_metadata_table=self.metadata_table,
)
target_version = migrations[-1][0]
await cursor.execute(
@@ -223,43 +218,20 @@ async def disconnect(self) -> None:
await self.pool.close()
self._has_sweep_lock = False

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {worker_metadata_table} (worker_id, queue_key, metadata, expire_at)
VALUES
(%(worker_id)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET metadata = %(metadata)s, queue_key = %(queue_key)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(worker_metadata_table=self.metadata_table),
{
"worker_id": worker_id,
"metadata": json.dumps(metadata),
"queue_key": queue_key,
"ttl": ttl,
},
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
async with self.pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
SQL(
dedent(
"""
SELECT stats.worker_id, stats.stats, meta.queue_key, meta.metadata
FROM {stats_table} stats
LEFT JOIN {worker_metadata_table} meta ON meta.worker_id = stats.worker_id
WHERE NOW() <= TO_TIMESTAMP(stats.expire_at)
"""
SELECT worker_id, stats, queue_key, metadata
FROM {stats_table}
WHERE expire_at >= EXTRACT(EPOCH FROM NOW())
AND queue_key = %(queue)s
"""
)
).format(stats_table=self.stats_table, worker_metadata_table=self.metadata_table),
).format(stats_table=self.stats_table),
{"queue": self.name},
)
results = await cursor.fetchall()
workers: dict[str, WorkerInfo] = {
@@ -413,21 +385,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
"""
-- Delete expired worker metadata
DELETE FROM {worker_metadata_table}
WHERE %(now)s >= expire_at;
"""
)
).format(
worker_metadata_table=self.metadata_table,
),
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
@@ -778,23 +735,30 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
INSERT INTO {stats_table} (worker_id, stats, queue_key, metadata, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": worker_id,
"stats": json.dumps(stats),
"stats": json.dumps(info["stats"]),
"ttl": ttl,
"queue_key": info["queue_key"],
"metadata": json.dumps(info["metadata"]),
},
)

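
A rough read-back sketch against the now-singular table, reusing the pool/cursor pattern from this file; the helper itself and the literal saq_stats name are hypothetical (the real queries go through self.stats_table):

async def fetch_worker_row(queue, worker_id: str):
    # Hypothetical helper: read one worker's consolidated row.
    async with queue.pool.connection() as conn, conn.cursor() as cursor:
        await cursor.execute(
            "SELECT worker_id, stats, queue_key, metadata, expire_at"
            " FROM saq_stats WHERE worker_id = %(worker_id)s",
            {"worker_id": worker_id},
        )
        # Rows outlive expire_at until swept; readers filter on it instead.
        return await cursor.fetchone()
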
13 changes: 5 additions & 8 deletions saq/queue/postgres_migrations.py
@@ -7,7 +7,6 @@
def get_migrations(
jobs_table: Identifier,
stats_table: Identifier,
worker_metadata_table: Identifier,
) -> t.List[t.Tuple[int, t.List[Composed]]]:
return [
(
@@ -49,14 +48,12 @@ def get_migrations(
[
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {worker_metadata_table} (
worker_id TEXT PRIMARY KEY,
queue_key TEXT NOT NULL,
expire_at BIGINT NOT NULL,
metadata JSONB
);
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS metadata JSONB;
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS queue_key TEXT;
CREATE INDEX IF NOT EXISTS saq_stats_expire_at_idx ON {stats_table} (expire_at);
CREATE INDEX IF NOT EXISTS saq_stats_queue_key_idx ON {stats_table} (queue_key);
""")
).format(worker_metadata_table=worker_metadata_table),
).format(stats_table=stats_table),
],
),
]
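
To make the end state concrete, a sketch of the stats table's effective shape after this migration. The worker_id/stats/expire_at columns come from an earlier migration that is collapsed out of this diff, so their exact definitions here are assumptions:

from textwrap import dedent

# Hypothetical consolidated schema; only metadata and queue_key are
# actually added by the migration above.
EFFECTIVE_STATS_SCHEMA = dedent("""
    CREATE TABLE saq_stats (
        worker_id TEXT PRIMARY KEY,
        stats JSONB,
        expire_at BIGINT,
        metadata JSONB,
        queue_key TEXT
    );
""")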