
Commit

chore: address feeedback
benfdking committed Jan 8, 2025
1 parent c7c6d0e commit 395c5a6
Showing 7 changed files with 35 additions and 52 deletions.
8 changes: 4 additions & 4 deletions saq/queue/base.py
@@ -30,7 +30,7 @@
DumpType,
LoadType,
QueueInfo,
- QueueStats,
+ WorkerStats,
)


@@ -131,7 +131,7 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
- async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
+ async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
"""
Returns & updates stats on the queue.
@@ -194,7 +194,7 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

- async def stats(self, worker_id: str, ttl: int = 60) -> QueueStats:
+ async def stats(self, worker_id: str, ttl: int = 60) -> WorkerStats:
"""
Method to be used by workers to update stats.
@@ -204,7 +204,7 @@ async def stats(self, worker_id: str, ttl: int = 60) -> QueueStats:
Returns: The stats.
"""
- stats: QueueStats = {
+ stats: WorkerStats = {
"complete": self.complete,
"failed": self.failed,
"retried": self.retried,
4 changes: 2 additions & 2 deletions saq/queue/http.py
@@ -18,7 +18,7 @@
from saq.types import (
CountKind,
QueueInfo,
- QueueStats,
+ WorkerStats,
)

try:
@@ -208,7 +208,7 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

- async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
+ async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
4 changes: 2 additions & 2 deletions saq/queue/postgres.py
@@ -32,7 +32,7 @@
DumpType,
LoadType,
QueueInfo,
- QueueStats,
+ WorkerStats,
)

try:
@@ -670,7 +670,7 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

- async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
+ async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
4 changes: 2 additions & 2 deletions saq/queue/redis.py
@@ -38,7 +38,7 @@
ListenCallback,
LoadType,
QueueInfo,
- QueueStats,
+ WorkerStats,
VersionTuple,
)

@@ -389,7 +389,7 @@ async def finish_abort(self, job: Job) -> None:
await self.redis.delete(job.abort_id)
await super().finish_abort(job)

- async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
+ async def write_stats(self, worker_id: str, stats: WorkerStats, ttl: int) -> None:
current = now()
async with self.redis.pipeline(transaction=True) as pipe:
key = self.namespace(f"stats:{worker_id}")
4 changes: 2 additions & 2 deletions saq/types.py
@@ -63,9 +63,9 @@ class QueueInfo(t.TypedDict):
"A truncated list containing the jobs that are scheduled to execute soonest"


- class QueueStats(t.TypedDict):
+ class WorkerStats(t.TypedDict):
"""
- Queue Stats, could also be used for Worker Stats
+ Worker Stats
"""

complete: int
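
For context, here is a minimal, hedged sketch of the renamed type in use. Only the complete, failed, and retried fields are visible in this diff; the uptime field and the literal values below are assumptions for illustration, not part of the change.

import typing as t


class WorkerStats(t.TypedDict):
    """Worker Stats"""

    complete: int
    failed: int
    retried: int
    uptime: int  # assumed; the diff truncates the remaining fields


# A worker builds an instance of this shape and hands it to write_stats.
stats: WorkerStats = {"complete": 10, "failed": 1, "retried": 2, "uptime": 3600}
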
6 changes: 3 additions & 3 deletions saq/worker.py
@@ -36,7 +36,7 @@
ReceivesContext,
SettingsDict,
TimersDict,
- QueueStats,
+ WorkerStats,
)


@@ -118,7 +118,7 @@ def __init__(
self.burst_jobs_processed = 0
self.burst_jobs_processed_lock = threading.Lock()
self.burst_condition_met = False
- self.id = id if id is not None else uuid1()
+ self.id = uuid1() if id is None else id

if self.burst:
if self.dequeue_timeout <= 0:
@@ -218,7 +218,7 @@ async def schedule(self, lock: int = 1) -> None:
if scheduled:
logger.info("Scheduled %s", scheduled)

- async def stats(self, ttl: int = 60) -> QueueStats:
+ async def stats(self, ttl: int = 60) -> WorkerStats:
return await self.queue.stats(self.id, ttl)

async def upkeep(self) -> list[Task[None]]:
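
To see how the worker-side pieces above fit together, here is a hedged sketch of a worker reporting stats under an explicit id: Worker.stats forwards self.id to Queue.stats, which builds the WorkerStats dict and calls write_stats keyed by that id. The Redis URL, the id string, and the echo job are placeholders, not part of this change.

import asyncio

from saq import Queue, Worker
from saq.types import Context


async def echo(_ctx: Context, *, a: int) -> int:
    return a


async def report_stats() -> None:
    queue = Queue.from_url("redis://localhost:6379")  # placeholder URL
    await queue.connect()
    # id is optional; when omitted, the worker falls back to uuid1() as above.
    worker = Worker(queue=queue, functions=[echo], id="worker-1")
    stats = await worker.stats(ttl=60)  # keyed by worker.id, not the queue name
    print(stats["complete"], stats["failed"], stats["retried"])
    await queue.disconnect()


if __name__ == "__main__":
    asyncio.run(report_stats())
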
57 changes: 20 additions & 37 deletions tests/test_http_stats.py
@@ -1,58 +1,44 @@
"""Validate that the worker id in the context of the http proxy is the id of the worker rather than the queue."""

import unittest
- from http.server import HTTPServer, BaseHTTPRequestHandler
+ from aiohttp import web

from saq import Queue, Worker
from saq.queue.http import HttpProxy
from saq.types import Context
from tests.helpers import setup_postgres, create_postgres_queue, teardown_postgres
- import asyncio
- import threading


async def echo(_ctx: Context, *, a: int) -> int:
return a


- class ProxyRequestHandler(BaseHTTPRequestHandler):
- def __init__(self, *args, proxy=None, **kwargs):
- self.proxy = proxy
- super().__init__(*args, **kwargs)
-
- def do_POST(self):
- length = int(self.headers["Content-Length"])
- body = self.rfile.read(length).decode("utf-8")
- response = asyncio.run(self.proxy.process(body))
+ class TestQueue(unittest.IsolatedAsyncioTestCase):
+ async def handle_post(self, request):
+ body = await request.text()
+ response = await self.proxy.process(body)
if response:
- self.send_response(200)
- self.send_header("Content-type", "application/json")
- self.end_headers()
- self.wfile.write(response.encode("utf-8"))
+ return web.Response(text=response, content_type="application/json")
else:
- self.send_response(200)
- self.end_headers()
+ return web.Response(status=200)


- class TestQueue(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
await setup_postgres()
+ self.queue = await create_postgres_queue()
+ self.proxy = HttpProxy(queue=self.queue)
+ self.app = web.Application()
+ self.app.add_routes([web.post("/", self.handle_post)])
+ self.runner = web.AppRunner(self.app)
+ await self.runner.setup()
+ self.site = web.TCPSite(self.runner, "localhost", 8080)
+ await self.site.start()

async def asyncTearDown(self) -> None:
await teardown_postgres()
+ await self.site.stop()
+ await self.runner.cleanup()

async def test_http_proxy_with_two_workers(self) -> None:
- queue = await create_postgres_queue()
- proxy = HttpProxy(queue=queue)
-
- server = HTTPServer(
- ("localhost", 8080),
- lambda *args, **kwargs: ProxyRequestHandler(*args, proxy=proxy, **kwargs),
- )
- server_thread = threading.Thread(target=server.serve_forever)
- server_thread.daemon = True
- server_thread.start()

queue1 = Queue.from_url("http://localhost:8080/")
await queue1.connect()
queue2 = Queue.from_url("http://localhost:8080/")
@@ -69,12 +55,12 @@ async def test_http_proxy_with_two_workers(self) -> None:
)
await worker2.stats()
local_worker = Worker(
- queue=queue,
+ queue=self.queue,
functions=[echo],
)
await local_worker.stats()

- root_info = await queue.info()
+ root_info = await self.queue.info()
info1 = await queue1.info()
info2 = await queue2.info()

@@ -85,7 +71,4 @@

await queue1.disconnect()
await queue2.disconnect()
- await queue.disconnect()
-
- server.shutdown()
- server_thread.join()
+ await self.queue.disconnect()
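
For anyone wanting to reproduce the proxy setup outside the test suite, the sketch below serves an HttpProxy with aiohttp in the same way the rewritten test does. The port and the Postgres URL are placeholder assumptions, and the handler simply mirrors handle_post above; it is not a dedicated saq entry point.

import asyncio

from aiohttp import web

from saq import Queue
from saq.queue.http import HttpProxy


async def main() -> None:
    queue = Queue.from_url("postgres://localhost:5432/saq")  # placeholder URL
    await queue.connect()
    proxy = HttpProxy(queue=queue)

    async def handle_post(request: web.Request) -> web.Response:
        body = await request.text()
        response = await proxy.process(body)
        if response:
            return web.Response(text=response, content_type="application/json")
        return web.Response(status=200)

    app = web.Application()
    app.add_routes([web.post("/", handle_post)])
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "localhost", 8080)
    await site.start()
    try:
        await asyncio.Event().wait()  # serve until cancelled (Ctrl+C)
    finally:
        await site.stop()
        await runner.cleanup()
        await queue.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
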

