Skip to content

Commit

Permalink
fix: linting/formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Oct 13, 2024
1 parent 42b5552 commit 31f2fa3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 54 deletions.
47 changes: 20 additions & 27 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def from_url( # pyright: ignore[reportIncompatibleMethodOverride]
min_size: int = 4,
max_size: int = 20,
**kwargs: t.Any,
) -> PostgresQueue:
) -> PostgresQueue:
"""Create a queue from a postgres url.
Args:
Expand All @@ -90,8 +90,8 @@ def from_url( # pyright: ignore[reportIncompatibleMethodOverride]
max_size: maximum pool size. (default 20)
If greater than 0, this limits the maximum number of connections to Postgres.
Otherwise, maintain `min_size` number of connections.
"""
"""
return cls(create_pool(dsn=url, min_size=min_size, max_size=max_size), **kwargs)

def __init__(
Expand All @@ -111,7 +111,7 @@ def __init__(

self.jobs_table = jobs_table
self.stats_table = stats_table
self.pool = pool
self.pool = pool
self.poll_interval = poll_interval
self.saq_lock_keyspace = saq_lock_keyspace
self.job_lock_keyspace = job_lock_keyspace
Expand All @@ -132,26 +132,23 @@ def __init__(
async def with_connection(
self, connection: PoolConnectionProxy | None = None
) -> t.AsyncGenerator[PoolConnectionProxy]:
async with self.nullcontext(
connection
) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined]
async with self.nullcontext(connection) if connection else self.pool.acquire() as conn: # type: ignore[attr-defined]
yield conn

async def init_db(self) -> None:
async with self.with_connection() as conn, conn.transaction():
cursor = await conn.cursor(
"SELECT pg_try_advisory_lock($1, 0)", self.saq_lock_keyspace,
"SELECT pg_try_advisory_lock($1, 0)",
self.saq_lock_keyspace,
)
result = await cursor.fetchrow()

if result and not result[0]:
return
for statement in DDL_STATEMENTS:
await conn.execute(
statement.format(
jobs_table=self.jobs_table, stats_table=self.stats_table
)
)
statement.format(jobs_table=self.jobs_table, stats_table=self.stats_table)
)

async def connect(self) -> None:
if self._dequeue_conn:
Expand Down Expand Up @@ -225,7 +222,7 @@ async def count(self, kind: CountKind) -> int:
AND queue = $1
AND NOW() >= TO_TIMESTAMP(scheduled)
"""),
self.name,
self.name,
)
elif kind == "active":
result = await conn.fetchval(
Expand Down Expand Up @@ -259,7 +256,7 @@ async def schedule(self, lock: int = 1) -> t.List[str]:
async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
"""Delete jobs and stats past their expiration and sweep stuck jobs"""
swept = []

if not self._has_sweep_lock:
# Attempt to get the sweep lock and hold on to it
async with self._get_dequeue_conn() as conn:
Expand All @@ -273,7 +270,7 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
return []
self._has_sweep_lock = True

async with self.with_connection() as conn, conn.transaction():
async with self.with_connection() as conn, conn.transaction():
await conn.execute(
dedent(f"""
-- Delete expired jobs
Expand All @@ -282,14 +279,14 @@ async def sweep(self, lock: int = 60, abort: float = 5.0) -> list[str]:
AND status IN ('aborted', 'complete', 'failed')
AND NOW() >= TO_TIMESTAMP(expire_at)
"""),
self.name,
self.name,
)
await conn.execute(
dedent(f"""
-- Delete expired stats
DELETE FROM {self.stats_table}
WHERE NOW() >= TO_TIMESTAMP(expire_at);
"""),
"""),
)
results = await conn.fetch(
dedent(
Expand Down Expand Up @@ -354,9 +351,7 @@ async def listen(
if stop:
break

async def notify(
self, job: Job, connection: PoolConnectionProxy | None = None
) -> None:
async def notify(self, job: Job, connection: PoolConnectionProxy | None = None) -> None:
await self._notify(job.key, job.status, connection)

async def update(
Expand Down Expand Up @@ -400,15 +395,13 @@ async def update(

async def job(self, job_key: str) -> Job | None:
async with self.with_connection() as conn, conn.transaction():
cursor = await conn.cursor(
f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key
)
cursor = await conn.cursor(f"SELECT job FROM {self.jobs_table} WHERE key = $1", job_key)
record = await cursor.fetchrow()
return self.deserialize(record.get("job")) if record else None

async def jobs(self, job_keys: Iterable[str]) -> t.List[Job | None]:
keys = list(job_keys)
results: dict[str, bytes | None] = {}
results: dict[str, bytes | None] = {}
async with self.with_connection() as conn, conn.transaction():
async for record in conn.cursor(
f"SELECT key, job FROM {self.jobs_table} WHERE key = ANY($1)", keys
Expand Down Expand Up @@ -607,7 +600,7 @@ async def get_job_status(
{"FOR UPDATE" if for_update else ""}
"""),
key,
)
)
assert result
return result

Expand Down Expand Up @@ -643,7 +636,7 @@ async def _finish(
"""),
key,
)
await self.notify(job, conn)
await self.notify(job, conn)
await self._release_job(key)

async def _notify(
Expand Down Expand Up @@ -732,4 +725,4 @@ async def _notify_callback(
payload: t.Any,
) -> None:
payload_data = json.loads(payload)
self.publish(payload_data["key"], payload_data)
self.publish(payload_data["key"], payload_data)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
extras_require={
"hiredis": ["redis[hiredis]>=4.2.0"],
"http": ["aiohttp"],
"postgres": ["asyncpg"],
"postgres": ["asyncpg"],
"redis": ["redis>=4.2,<6.0"],
"web": ["aiohttp", "aiohttp_basicauth"],
"dev": [
Expand Down
1 change: 0 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing as t

import asyncpg
import psycopg

from saq.queue import Queue
from saq.queue.postgres import PostgresQueue
Expand Down
48 changes: 23 additions & 25 deletions tests/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@
import unittest
from unittest import mock

from psycopg.sql import SQL

from saq.job import Job, Status
from saq.queue import JobError, Queue
from saq.utils import uuid1
from saq.worker import Worker
from tests.helpers import (
cleanup_queue,
cleanup_queue,
create_postgres_queue,
create_redis_queue,
setup_postgres,
teardown_postgres,
setup_postgres,
teardown_postgres,
)


Expand Down Expand Up @@ -583,12 +582,13 @@ async def test_sweep_stats(self) -> None:
await asyncio.sleep(1.5)
await self.queue.sweep()
async with self.queue.pool.acquire() as conn, conn.transaction():
cursor = await conn.cursor("""
cursor = await conn.cursor(
"""
SELECT stats
FROM {}
WHERE worker_id = $1
""".format(self.queue.stats_table),
self.queue.uuid
self.queue.uuid,
)
self.assertIsNone(await cursor.fetchrow())

Expand All @@ -597,19 +597,18 @@ async def test_sweep_stats(self) -> None:
await asyncio.sleep(1)
await self.queue.sweep()
async with self.queue.pool.acquire() as conn, conn.transaction():
cursor = await conn.cursor(

"""
cursor = await conn.cursor(
"""
SELECT stats
FROM {}
WHERE worker_id = $1
""".format(self.queue.stats_table),
self.queue.uuid
self.queue.uuid,
)
self.assertIsNotNone(await cursor.fetchrow())

async def test_job_lock(self) -> None:
query = """
query = """
SELECT count(*)
FROM {} JOIN pg_locks ON lock_key = objid
WHERE key = $1
Expand All @@ -632,11 +631,12 @@ async def test_load_dump_pickle(self) -> None:
job = await self.enqueue("test")

async with self.queue.pool.acquire() as conn, conn.transaction():
result = await conn.fetchrow("""
result = await conn.fetchrow(
"""
SELECT job
FROM {}
WHERE key =$1
""" .format(self.queue.jobs_table),
""".format(self.queue.jobs_table),
job.key,
)
assert result
Expand All @@ -655,47 +655,45 @@ async def test_finish_ttl_positive(self, mock_time: MagicMock) -> None:
await self.finish(job, Status.COMPLETE)
async with self.queue.pool.acquire() as conn:
result = await conn.fetchval(

"""
"""
SELECT expire_at
FROM {}
WHERE key = $1
""".format(self.queue.jobs_table),
job.key,
)
self.assertEqual(result,5)
self.assertEqual(result, 5)

@mock.patch("saq.utils.time")
async def test_finish_ttl_neutral(self, mock_time: MagicMock) -> None:
mock_time.time.return_value = 0
job = await self.enqueue("test", ttl=0)
await self.dequeue()
await self.finish(job, Status.COMPLETE)
async with self.queue.pool.acquire() as conn :
async with self.queue.pool.acquire() as conn:
result = await conn.fetchval(

"""
"""
SELECT expire_at
FROM {}
WHERE key = $1
""" .format(self.queue.jobs_table),
""".format(self.queue.jobs_table),
job.key,
)
self.assertEqual(result,None)
self.assertEqual(result, None)

@mock.patch("saq.utils.time")
async def test_finish_ttl_negative(self, mock_time: MagicMock) -> None:
mock_time.time.return_value = 0
job = await self.enqueue("test", ttl=-1)
await self.dequeue()
await self.finish(job, Status.COMPLETE)
async with self.queue.pool.acquire() as conn :
async with self.queue.pool.acquire() as conn:
result = await conn.fetchval(
"""
"""
SELECT expire_at
FROM {}
WHERE key = $1
""" .format(self.queue.jobs_table),
""".format(self.queue.jobs_table),
job.key,
)
self.assertIsNone(result)
Expand All @@ -713,7 +711,7 @@ async def test_bad_connection(self) -> None:
# Test dequeue still works
self.assertEqual((await self.dequeue()), job)
# Check queue has a new connection
self.assertNotEqual(original_connection,self.queue._dequeue_conn)
self.assertNotEqual(original_connection, self.queue._dequeue_conn)

async def test_group_key(self) -> None:
job1 = await self.enqueue("test", group_key=1)
Expand Down

0 comments on commit 31f2fa3

Please sign in to comment.