Skip to content

Commit

Permalink
feat: introduce simple migration framework
Browse files Browse the repository at this point in the history
  • Loading branch information
benfdking committed Jan 14, 2025
1 parent b7d1760 commit a9b90a1
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 37 deletions.
60 changes: 54 additions & 6 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from saq.multiplexer import Multiplexer
from saq.queue.base import Queue, logger
from saq.queue.postgres_ddl import DDL_STATEMENTS
from saq.queue.postgres_migrations import get_migrations
from saq.utils import now, now_seconds

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +49,7 @@
DEQUEUE = "saq:dequeue"
JOBS_TABLE = "saq_jobs"
STATS_TABLE = "saq_stats"
VERSIONS_TABLE = "saq_versions"


class PostgresQueue(Queue):
Expand All @@ -58,6 +59,7 @@ class PostgresQueue(Queue):
Args:
pool: instance of psycopg_pool.AsyncConnectionPool
name: name of the queue (default "default")
versions_table: name of the Postgres table SAQ will use to maintain migrations
jobs_table: name of the Postgres table SAQ will write jobs to (default "saq_jobs")
stats_table: name of the Postgres table SAQ will write stats to (default "saq_stats")
dump: lambda that takes a dictionary and outputs bytes (default `json.dumps`)
Expand Down Expand Up @@ -88,6 +90,7 @@ def __init__(
self,
pool: AsyncConnectionPool,
name: str = "default",
versions_table: str = VERSIONS_TABLE,
jobs_table: str = JOBS_TABLE,
stats_table: str = STATS_TABLE,
dump: DumpType | None = None,
Expand All @@ -101,6 +104,7 @@ def __init__(
) -> None:
super().__init__(name=name, dump=dump, load=load)

self.versions_table = Identifier(versions_table)
self.jobs_table = Identifier(jobs_table)
self.stats_table = Identifier(stats_table)
self.pool = pool
Expand Down Expand Up @@ -128,14 +132,58 @@ async def init_db(self) -> None:
{"key1": self.saq_lock_keyspace},
)
result = await cursor.fetchone()

if result and not result[0]:
return

for statement in DDL_STATEMENTS:
await cursor.execute(
SQL(statement).format(jobs_table=self.jobs_table, stats_table=self.stats_table)
)
await cursor.execute(
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {versions_table} (
version INT
);
""")
).format(versions_table=self.versions_table)
)

migrations = get_migrations(
jobs_table=self.jobs_table,
stats_table=self.stats_table,
)
target_version = len(migrations)
await cursor.execute(
SQL(
dedent(
"""
SELECT version FROM {versions_table}
"""
)
).format(versions_table=self.versions_table),
)
result = await cursor.fetchone()
if result is not None:
current_version = result[0]
if current_version == target_version:
return
if current_version > target_version:
raise ValueError("The library version is behind the schema version.")

current = result[0] if result else 0
for migration in migrations[current:]:
for migration_statement in migration:
await cursor.execute(
migration_statement,
)

await cursor.execute(
SQL(
dedent(
"""
DELETE FROM {versions_table};
INSERT INTO {versions_table} (version) VALUES ({target_version});
"""
)
).format(versions_table=self.versions_table, target_version=target_version),
)

async def connect(self) -> None:
if self.pool._opened:
Expand Down
31 changes: 0 additions & 31 deletions saq/queue/postgres_ddl.py

This file was deleted.

43 changes: 43 additions & 0 deletions saq/queue/postgres_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing as t

from psycopg.sql import Identifier, SQL, Composed
from textwrap import dedent


def get_migrations(
jobs_table: Identifier,
stats_table: Identifier,
) -> t.List[t.List[Composed]]:
return [
[
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {jobs_table} (
key TEXT PRIMARY KEY,
lock_key SERIAL NOT NULL,
job BYTEA NOT NULL,
queue TEXT NOT NULL,
status TEXT NOT NULL,
priority SMALLINT NOT NULL DEFAULT 0,
group_key TEXT,
scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
expire_at BIGINT
);
""")
).format(jobs_table=jobs_table),
SQL(
dedent("""
CREATE INDEX IF NOT EXISTS saq_jobs_dequeue_idx ON {jobs_table} (status, queue, scheduled);
""")
).format(jobs_table=jobs_table),
SQL(
dedent("""
CREATE TABLE IF NOT EXISTS {stats_table} (
worker_id TEXT PRIMARY KEY,
stats JSONB,
expire_at BIGINT
);
""")
).format(stats_table=stats_table),
],
]

0 comments on commit a9b90a1

Please sign in to comment.