chore: move to a singular table
benfdking committed Jan 21, 2025
1 parent 8e22b2b commit 099a0a2
Showing 10 changed files with 102 additions and 158 deletions.
41 changes: 28 additions & 13 deletions saq/queue/base.py
@@ -31,6 +31,7 @@
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)


@@ -131,24 +132,25 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
"""
Returns & updates stats on the queue.
Write stats and metadata for a worker.
Args:
worker_id: The worker id, passed in rather than taken from the queue instance to ensure that the stats
are attributed to the worker and not the queue instance in the proxy server.
queue_key: The key of the queue.
metadata: The metadata to write.
stats: The stats to write.
ttl: The time-to-live in seconds.
"""
pass

@abstractmethod
async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
pass

@abstractmethod
async def _retry(self, job: Job, error: str | None) -> None:
pass
@@ -200,15 +202,19 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
async def worker_info(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict] = None, ttl: int = 60
) -> WorkerInfo:
"""
Method to be used by workers to update stats.
Method to be used by workers to update worker info.
Args:
worker_id: The worker id.
ttl: Time stats are valid for in seconds.
queue_key: The key of the queue.
metadata: The metadata to write.
Returns: The stats.
Returns: Worker info.
"""
stats: WorkerStats = {
"complete": self.complete,
@@ -217,8 +223,17 @@ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
"aborted": self.aborted,
"uptime": now() - self.started,
}
await self.write_stats(worker_id, stats, ttl)
return stats
info: WorkerInfo = {
"stats": stats,
"queue_key": queue_key,
"metadata": metadata,
}
await self.write_worker_info(
worker_id,
info,
ttl=ttl,
)
return info

def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
self._before_enqueues[id(callback)] = callback
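
The hunks above consolidate the old stats/write_stats pair and the separate write_worker_metadata abstract method into a single worker_info/write_worker_info path. A minimal sketch of how a worker heartbeat might call the new API, assuming a connected queue instance; the id, key, metadata values, and intervals are illustrative, not from this commit:

    import asyncio

    async def heartbeat(queue, worker_id: str, queue_key: str) -> None:
        # Hypothetical loop; one call now writes stats, queue_key, and
        # metadata together, where two separate calls were needed before.
        while True:
            await queue.worker_info(
                worker_id,
                queue_key=queue_key,
                metadata={"host": "worker-1"},  # optional, may be None
                ttl=60,  # the worker row stays valid for 60 seconds
            )
            await asyncio.sleep(30)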
38 changes: 18 additions & 20 deletions saq/queue/http.py
@@ -7,7 +7,6 @@
import json
import typing as t


from saq.errors import MissingDependencyError
from saq.job import Job, Status
from saq.queue.base import Queue
@@ -18,7 +17,7 @@
from saq.types import (
CountKind,
QueueInfo,
WorkerStats,
WorkerInfo,
)

try:
@@ -102,17 +101,15 @@ async def process(self, body: str) -> str | None:
jobs=req["jobs"], offset=req["offset"], limit=req["limit"]
)
)
if kind == "write_stats":
await self.queue.write_stats(
worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
)
return None
if kind == "write_worker_metadata":
await self.queue.write_worker_metadata(
metadata=req["metadata"],
ttl=req["ttl"],
queue_key=req["queue_key"],
if kind == "write_worker_info":
await self.queue.write_worker_info(
worker_id=req["worker_id"],
info={
"stats": req["stats"],
"queue_key": req["queue_key"],
"metadata": req["metadata"],
},
ttl=req["ttl"],
)
return None
raise ValueError(f"Invalid request {body}")
@@ -216,18 +213,19 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
await self._send(
"write_worker_metadata",
"write_worker_info",
worker_id=worker_id,
metadata=metadata,
stats=info["stats"],
ttl=ttl,
queue_key=queue_key,
queue_key=info["queue_key"],
metadata=info["metadata"],
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
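
The proxy path above forwards the same fields over HTTP: the client-side write_worker_info flattens the WorkerInfo dict into the request, and the server-side process() rebuilds it before calling the real queue. A hedged sketch of such a payload; the "kind" envelope key and all values are assumptions inferred from the keys process() reads, not taken from this commit:

    import json

    # Assumed request shape for kind == "write_worker_info"; values illustrative.
    body = json.dumps(
        {
            "kind": "write_worker_info",
            "worker_id": "worker-1",
            "stats": {"complete": 10, "failed": 0, "retried": 1, "aborted": 0, "uptime": 120},
            "queue_key": "saq:default",
            "metadata": {"host": "worker-1"},
            "ttl": 60,
        }
    )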
66 changes: 14 additions & 52 deletions saq/queue/postgres.py
@@ -32,7 +32,6 @@
DumpType,
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)

@@ -51,7 +50,6 @@
JOBS_TABLE = "saq_jobs"
STATS_TABLE = "saq_stats"
VERSIONS_TABLE = "saq_versions"
METADATA_TABLE = "saq_worker_metadata"


class PostgresQueue(Queue):
@@ -95,7 +93,6 @@ def __init__(
versions_table: str = VERSIONS_TABLE,
jobs_table: str = JOBS_TABLE,
stats_table: str = STATS_TABLE,
metadata_table: str = METADATA_TABLE,
dump: DumpType | None = None,
load: LoadType | None = None,
min_size: int = 4,
@@ -110,7 +107,6 @@
self.versions_table = Identifier(versions_table)
self.jobs_table = Identifier(jobs_table)
self.stats_table = Identifier(stats_table)
self.metadata_table = Identifier(metadata_table)
self.pool = pool
self.min_size = min_size
self.max_size = max_size
@@ -152,7 +148,6 @@ async def init_db(self) -> None:
migrations = get_migrations(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
worker_metadata_table=self.metadata_table,
)
target_version = migrations[-1][0]
await cursor.execute(
@@ -223,43 +218,18 @@ async def disconnect(self) -> None:
await self.pool.close()
self._has_sweep_lock = False

async def write_worker_metadata(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict], ttl: int
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {worker_metadata_table} (worker_id, queue_key, metadata, expire_at)
VALUES
(%(worker_id)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET metadata = %(metadata)s, queue_key = %(queue_key)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(worker_metadata_table=self.metadata_table),
{
"worker_id": worker_id,
"metadata": json.dumps(metadata),
"queue_key": queue_key,
"ttl": ttl,
},
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
async with self.pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
SQL(
dedent(
"""
SELECT stats.worker_id, stats.stats, meta.queue_key, meta.metadata
SELECT worker_id, stats, queue_key, metadata
FROM {stats_table} stats
LEFT JOIN {worker_metadata_table} meta ON meta.worker_id = stats.worker_id
WHERE NOW() <= TO_TIMESTAMP(stats.expire_at)
"""
)
).format(stats_table=self.stats_table, worker_metadata_table=self.metadata_table),
).format(stats_table=self.stats_table),
)
results = await cursor.fetchall()
workers: dict[str, WorkerInfo] = {
@@ -413,21 +383,6 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
"""
-- Delete expired worker metadata
DELETE FROM {worker_metadata_table}
WHERE %(now)s >= expire_at;
"""
)
).format(
worker_metadata_table=self.metadata_table,
),
{"now": now_ts},
)

await cursor.execute(
SQL(
dedent(
@@ -778,23 +733,30 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
INSERT INTO {stats_table} (worker_id, stats, queue_key, metadata, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": worker_id,
"stats": json.dumps(stats),
"stats": json.dumps(info["stats"]),
"ttl": ttl,
"queue_key": info["queue_key"],
"metadata": json.dumps(info["metadata"]),
},
)

Expand Down
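With write_worker_info writing queue_key and metadata straight into the stats table, the dropped saq_worker_metadata table and its sweep cleanup above become unnecessary. A hedged sketch of the consolidated table's shape; column types for the pre-existing columns are assumptions inferred from the upsert above and the old metadata DDL, not DDL from this commit:

    # Assumed consolidated shape of saq_stats after the migration below.
    STATS_TABLE_SKETCH = """
    CREATE TABLE IF NOT EXISTS saq_stats (
        worker_id TEXT PRIMARY KEY, -- upsert conflicts on worker_id
        stats JSONB,                -- written as json.dumps(info["stats"])
        queue_key TEXT,             -- added by the migration below
        metadata JSONB,             -- added by the migration below
        expire_at BIGINT            -- epoch seconds: NOW() + ttl
    )
    """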
11 changes: 3 additions & 8 deletions saq/queue/postgres_migrations.py
@@ -7,7 +7,6 @@
def get_migrations(
jobs_table: Identifier,
stats_table: Identifier,
worker_metadata_table: Identifier,
) -> t.List[t.Tuple[int, t.List[Composed]]]:
return [
(
@@ -49,14 +48,10 @@ def get_migrations(
[
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {worker_metadata_table} (
worker_id TEXT PRIMARY KEY,
queue_key TEXT NOT NULL,
expire_at BIGINT NOT NULL,
metadata JSONB
);
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS metadata JSONB;
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS queue_key TEXT;
""")
).format(worker_metadata_table=worker_metadata_table),
).format(stats_table=stats_table),
],
),
]
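
Since the metadata table no longer exists, get_migrations drops its third parameter and the final migration becomes two ALTER TABLE statements on the stats table. A minimal sketch of the updated call, mirroring the table-name defaults in postgres.py; the surrounding init_db wiring is abbreviated in this diff, so treat the snippet as illustrative:

    from psycopg.sql import Identifier

    migrations = get_migrations(
        jobs_table=Identifier("saq_jobs"),
        stats_table=Identifier("saq_stats"),
    )
    target_version = migrations[-1][0]  # init_db migrates up to this version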
