feat!: adding metadata to workers (#201)
* feat!: adding metadata to workers

- alter the stats table to store worker metadata
- unify functions so that there is only the single worker_info concept, to simplify things
benfdking authored Jan 21, 2025
1 parent e560f07 commit 5b32588
Showing 10 changed files with 180 additions and 61 deletions.
35 changes: 28 additions & 7 deletions saq/queue/base.py
@@ -31,6 +31,7 @@
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)


@@ -131,13 +132,20 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
"""
Returns & updates stats on the queue.
Write stats and metadata for a worker.
Args:
worker_id: The worker id, passed in rather than taken from the queue instance to ensure that the stats
are attributed to the worker and not the queue instance in the proxy server.
info: The worker info (stats, queue_key and metadata) to write.
ttl: The time-to-live in seconds.
"""
@@ -194,15 +202,19 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
async def worker_info(
self, worker_id: str, queue_key: str, metadata: t.Optional[dict] = None, ttl: int = 60
) -> WorkerInfo:
"""
Method to be used by workers to update stats.
Method to be used by workers to update worker info.
Args:
worker_id: The worker id.
ttl: Time stats are valid for in seconds.
queue_key: The key of the queue.
metadata: The metadata to write.
Returns: The stats.
Returns: Worker info.
"""
stats: WorkerStats = {
"complete": self.complete,
@@ -211,8 +223,17 @@ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
"aborted": self.aborted,
"uptime": now() - self.started,
}
await self.write_stats(worker_id, stats, ttl)
return stats
info: WorkerInfo = {
"stats": stats,
"queue_key": queue_key,
"metadata": metadata,
}
await self.write_worker_info(
worker_id,
info,
ttl=ttl,
)
return info

def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
self._before_enqueues[id(callback)] = callback
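For a sense of how the unified API reads from the worker side, here is a minimal sketch. The Queue.from_url helper, worker id, queue_key choice, and metadata payload are illustrative assumptions rather than anything taken from this diff:

```python
import asyncio

from saq import Queue

async def heartbeat() -> None:
    queue = Queue.from_url("redis://localhost")  # illustrative URL
    await queue.connect()
    # One call now publishes stats plus arbitrary metadata, where the old
    # write_stats path carried stats only.
    info = await queue.worker_info(
        worker_id="worker-1",          # illustrative worker id
        queue_key=queue.name,          # assumed: the queue's key/name
        metadata={"host": "app-01"},   # free-form worker metadata
        ttl=60,                        # info expires after 60 seconds
    )
    print(info["stats"], info["metadata"])

asyncio.run(heartbeat())
```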
31 changes: 24 additions & 7 deletions saq/queue/http.py
@@ -7,7 +7,6 @@
import json
import typing as t


from saq.errors import MissingDependencyError
from saq.job import Job, Status
from saq.queue.base import Queue
@@ -18,7 +17,7 @@
from saq.types import (
CountKind,
QueueInfo,
WorkerStats,
WorkerInfo,
)

try:
@@ -102,9 +101,15 @@ async def process(self, body: str) -> str | None:
jobs=req["jobs"], offset=req["offset"], limit=req["limit"]
)
)
if kind == "write_stats":
await self.queue.write_stats(
worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
if kind == "write_worker_info":
await self.queue.write_worker_info(
worker_id=req["worker_id"],
info={
"stats": req["stats"],
"queue_key": req["queue_key"],
"metadata": req["metadata"],
},
ttl=req["ttl"],
)
return None
raise ValueError(f"Invalid request {body}")
@@ -208,8 +213,20 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
await self._send(
"write_worker_info",
worker_id=worker_id,
stats=info["stats"],
ttl=ttl,
queue_key=info["queue_key"],
metadata=info["metadata"],
)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
return json.loads(await self._send("info", jobs=jobs, offset=offset, limit=limit))
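The wire format follows from the handler above: process() unpacks worker_id, stats, queue_key, metadata, and ttl from the request, so the body sent by the client's write_worker_info presumably looks like the sketch below (the top-level "kind" key is an assumption about how _send frames requests; all values are examples):

```python
import json

# Illustrative request body for the proxy's write_worker_info branch.
request_body = json.dumps(
    {
        "kind": "write_worker_info",
        "worker_id": "worker-1",
        "stats": {"complete": 10, "failed": 0, "retried": 1, "aborted": 0, "uptime": 123456},
        "queue_key": "default",
        "metadata": {"region": "eu-west-1"},
        "ttl": 60,
    }
)
```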
37 changes: 27 additions & 10 deletions saq/queue/postgres.py
@@ -32,7 +32,7 @@
DumpType,
LoadType,
QueueInfo,
WorkerStats,
WorkerInfo,
)

try:
@@ -224,14 +224,24 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
SQL(
dedent(
"""
SELECT worker_id, stats FROM {stats_table}
WHERE NOW() <= TO_TIMESTAMP(expire_at)
"""
SELECT worker_id, stats, queue_key, metadata
FROM {stats_table}
WHERE expire_at >= EXTRACT(EPOCH FROM NOW())
AND queue_key = %(queue)s
"""
)
).format(stats_table=self.stats_table),
{"queue": self.name},
)
results = await cursor.fetchall()
workers: dict[str, dict[str, t.Any]] = dict(results)
workers: dict[str, WorkerInfo] = {
worker_id: {
"stats": stats,
"metadata": metadata,
"queue_key": queue_key,
}
for worker_id, stats, queue_key, metadata in results
}

queued = await self.count("queued")
active = await self.count("active")
@@ -725,23 +735,30 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
dedent(
"""
INSERT INTO {stats_table} (worker_id, stats, expire_at)
VALUES (%(worker_id)s, %(stats)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
INSERT INTO {stats_table} (worker_id, stats, queue_key, metadata, expire_at)
VALUES (%(worker_id)s, %(stats)s, %(queue_key)s, %(metadata)s, EXTRACT(EPOCH FROM NOW()) + %(ttl)s)
ON CONFLICT (worker_id) DO UPDATE
SET stats = %(stats)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
SET stats = %(stats)s, queue_key = %(queue_key)s, metadata = %(metadata)s, expire_at = EXTRACT(EPOCH FROM NOW()) + %(ttl)s
"""
)
).format(stats_table=self.stats_table),
{
"worker_id": worker_id,
"stats": json.dumps(stats),
"stats": json.dumps(info["stats"]),
"ttl": ttl,
"queue_key": info["queue_key"],
"metadata": json.dumps(info["metadata"]),
},
)

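To see the effect of the upsert, the table can be inspected directly. A sketch, assuming the default saq_stats table name and a local psycopg connection string:

```python
import asyncio

import psycopg

async def show_worker_rows() -> None:
    # Each worker holds exactly one row; the upsert above refreshes
    # stats, queue_key, metadata, and expire_at on every heartbeat.
    async with await psycopg.AsyncConnection.connect("postgresql://localhost/saq") as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT worker_id, queue_key, metadata, expire_at FROM saq_stats"
            )
            for row in await cur.fetchall():
                print(row)

asyncio.run(show_worker_rows())
```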
13 changes: 13 additions & 0 deletions saq/queue/postgres_migrations.py
@@ -43,4 +43,17 @@ def get_migrations(
).format(stats_table=stats_table),
],
),
(
2,
[
SQL(
dedent("""
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS metadata JSONB;
ALTER TABLE {stats_table} ADD COLUMN IF NOT EXISTS queue_key TEXT;
CREATE INDEX IF NOT EXISTS saq_stats_expire_at_idx ON {stats_table} (expire_at);
CREATE INDEX IF NOT EXISTS saq_stats_queue_key_idx ON {stats_table} (queue_key);
""")
).format(stats_table=stats_table),
],
),
]
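Each entry pairs a schema version with the statements that bring the stats table up to that version; a runner only needs to apply entries newer than the stored version. A minimal sketch of that idea (saq's actual migration bookkeeping may differ):

```python
from typing import Any, Iterable

async def apply_migrations(
    conn: Any, migrations: Iterable[tuple[int, list]], current_version: int
) -> None:
    # migrations is the (version, [SQL]) list produced by get_migrations above.
    for version, statements in sorted(migrations, key=lambda m: m[0]):
        if version <= current_version:
            continue  # already applied
        for statement in statements:
            await conn.execute(statement)
```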
31 changes: 18 additions & 13 deletions saq/queue/redis.py
@@ -38,8 +38,8 @@
ListenCallback,
LoadType,
QueueInfo,
WorkerStats,
VersionTuple,
WorkerInfo,
)

ID_PREFIX = "saq:job:"
@@ -130,15 +130,15 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
*_, worker_uuid = key_str.split(":")
worker_uuids.append(worker_uuid)

worker_stats = await self.redis.mget(
self.namespace(f"stats:{worker_uuid}") for worker_uuid in worker_uuids
worker_metadata = await self.redis.mget(
self.namespace(f"worker_info:{worker_uuid}") for worker_uuid in worker_uuids
)

worker_info = {}
for worker_uuid, stats in zip(worker_uuids, worker_stats):
if stats:
stats_obj = json.loads(stats)
worker_info[worker_uuid] = stats_obj
workers: dict[str, WorkerInfo] = {}
worker_metadata_dict = dict(zip(worker_uuids, worker_metadata))
for worker in worker_uuids:
metadata = worker_metadata_dict.get(worker)
if metadata:
workers[worker] = json.loads(metadata)

queued = await self.count("queued")
active = await self.count("active")
@@ -166,7 +166,7 @@ async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
job_info = []

return {
"workers": worker_info,
"workers": workers,
"name": self.name,
"queued": queued,
"active": active,
@@ -389,12 +389,17 @@ async def finish_abort(self, job: Job) -> None:
await self.redis.delete(job.abort_id)
await super().finish_abort(job)

async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async def write_worker_info(
self,
worker_id: str,
info: WorkerInfo,
ttl: int,
) -> None:
current = now()
async with self.redis.pipeline(transaction=True) as pipe:
key = self.namespace(f"stats:{worker_id}")
key = self.namespace(f"worker_info:{worker_id}")
await (
pipe.setex(key, ttl, json.dumps(stats))
pipe.setex(key, ttl, json.dumps(info))
.zremrangebyscore(self._stats, 0, current)
.zadd(self._stats, {key: current + millis(ttl)})
.expire(self._stats, ttl)
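The move from stats:* to worker_info:* keys is observable from any Redis client. A sketch assuming redis-py and a queue named "default", on the assumption that saq namespaces keys as saq:&lt;queue name&gt;:&lt;key&gt;:

```python
import asyncio
import json

from redis.asyncio import Redis

async def peek(worker_id: str) -> None:
    redis = Redis.from_url("redis://localhost")
    # setex above stores the whole WorkerInfo dict as JSON under a TTL.
    raw = await redis.get(f"saq:default:worker_info:{worker_id}")
    if raw:
        info = json.loads(raw)
        print(info["queue_key"], info["stats"], info["metadata"])
    await redis.aclose()

asyncio.run(peek("worker-1"))
```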
16 changes: 13 additions & 3 deletions saq/types.py
@@ -45,12 +45,22 @@ class JobTaskContext(t.TypedDict, total=False):
"If this task has been aborted, this is the reason"


class WorkerInfo(t.TypedDict):
"""
Worker Info
"""

queue_key: t.Optional[str]
stats: t.Optional[WorkerStats]
metadata: t.Optional[dict[str, t.Any]]


class QueueInfo(t.TypedDict):
"""
Queue Info
"""

workers: dict[str, dict[str, t.Any]]
workers: dict[str, WorkerInfo]
"Worker information"
name: str
"Queue name"
@@ -87,8 +97,8 @@ class TimersDict(t.TypedDict):

schedule: int
"How often we poll to schedule jobs in seconds (default 1)"
stats: int
"How often to update stats in seconds (default 10)"
worker_info: int
"How often to update worker info, stats and metadata in seconds (default 10)"
sweep: int
"How often to clean up stuck jobs in seconds (default 60)"
abort: int
Expand Down