diff --git a/saq/queue/postgres.py b/saq/queue/postgres.py index 587d151..91c4155 100644 --- a/saq/queue/postgres.py +++ b/saq/queue/postgres.py @@ -20,7 +20,7 @@ ) from saq.multiplexer import Multiplexer from saq.queue.base import Queue, logger -from saq.queue.postgres_ddl import DDL_STATEMENTS +from saq.queue.postgres_migrations import get_migrations from saq.utils import now, now_seconds if t.TYPE_CHECKING: @@ -49,6 +49,7 @@ DEQUEUE = "saq:dequeue" JOBS_TABLE = "saq_jobs" STATS_TABLE = "saq_stats" +VERSIONS_TABLE = "saq_versions" class PostgresQueue(Queue): @@ -58,6 +59,7 @@ class PostgresQueue(Queue): Args: pool: instance of psycopg_pool.AsyncConnectionPool name: name of the queue (default "default") + versions_table: name of the Postgres table SAQ will use to maintain migrations jobs_table: name of the Postgres table SAQ will write jobs to (default "saq_jobs") stats_table: name of the Postgres table SAQ will write stats to (default "saq_stats") dump: lambda that takes a dictionary and outputs bytes (default `json.dumps`) @@ -88,6 +90,7 @@ def __init__( self, pool: AsyncConnectionPool, name: str = "default", + versions_table: str = VERSIONS_TABLE, jobs_table: str = JOBS_TABLE, stats_table: str = STATS_TABLE, dump: DumpType | None = None, @@ -101,6 +104,7 @@ def __init__( ) -> None: super().__init__(name=name, dump=dump, load=load) + self.versions_table = Identifier(versions_table) self.jobs_table = Identifier(jobs_table) self.stats_table = Identifier(stats_table) self.pool = pool @@ -128,14 +132,58 @@ async def init_db(self) -> None: {"key1": self.saq_lock_keyspace}, ) result = await cursor.fetchone() - if result and not result[0]: return - for statement in DDL_STATEMENTS: - await cursor.execute( - SQL(statement).format(jobs_table=self.jobs_table, stats_table=self.stats_table) - ) + await cursor.execute( + SQL( + dedent(""" + CREATE TABLE IF NOT EXISTS {versions_table} ( + version INT + ); + """) + ).format(versions_table=self.versions_table) + ) + + migrations 
= get_migrations( + jobs_table=self.jobs_table, + stats_table=self.stats_table, + ) + target_version = len(migrations) + await cursor.execute( + SQL( + dedent( + """ + SELECT version FROM {versions_table} + """ + ) + ).format(versions_table=self.versions_table), + ) + result = await cursor.fetchone() + if result is not None: + current_version = result[0] + if current_version == target_version: + return + if current_version > target_version: + raise ValueError("The library version is behind the schema version.") + + current = result[0] if result else 0 + for migration in migrations[current:]: + for migration_statement in migration: + await cursor.execute( + migration_statement, + ) + + await cursor.execute( + SQL( + dedent( + """ + DELETE FROM {versions_table}; + INSERT INTO {versions_table} (version) VALUES ({target_version}); + """ + ) + ).format(versions_table=self.versions_table, target_version=target_version), + ) async def connect(self) -> None: if self.pool._opened: diff --git a/saq/queue/postgres_ddl.py b/saq/queue/postgres_ddl.py deleted file mode 100644 index a2256bc..0000000 --- a/saq/queue/postgres_ddl.py +++ /dev/null @@ -1,31 +0,0 @@ -CREATE_JOBS_TABLE = """ -CREATE TABLE IF NOT EXISTS {jobs_table} ( - key TEXT PRIMARY KEY, - lock_key SERIAL NOT NULL, - job BYTEA NOT NULL, - queue TEXT NOT NULL, - status TEXT NOT NULL, - priority SMALLINT NOT NULL DEFAULT 0, - group_key TEXT, - scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), - expire_at BIGINT -); -""" - -CREATE_JOBS_DEQUEUE_INDEX = """ -CREATE INDEX IF NOT EXISTS saq_jobs_dequeue_idx ON {jobs_table} (status, queue, scheduled) -""" - -CREATE_STATS_TABLE = """ -CREATE TABLE IF NOT EXISTS {stats_table} ( - worker_id TEXT PRIMARY KEY, - stats JSONB, - expire_at BIGINT -); -""" - -DDL_STATEMENTS = [ - CREATE_JOBS_TABLE, - CREATE_JOBS_DEQUEUE_INDEX, - CREATE_STATS_TABLE, -] diff --git a/saq/queue/postgres_migrations.py b/saq/queue/postgres_migrations.py new file mode 100644 index 
"""Ordered schema migrations for the Postgres queue backend.

Each migration is a list of SQL statements that must run in order; the
position of a migration in the returned list (1-based) is its schema
version number, which the queue records in its versions table.
"""

import typing as t

from psycopg.sql import Identifier, SQL, Composed
from textwrap import dedent


def get_migrations(
    jobs_table: Identifier,
    stats_table: Identifier,
) -> t.List[t.List[Composed]]:
    """Return every schema migration, oldest first.

    Args:
        jobs_table: identifier of the table holding queued jobs.
        stats_table: identifier of the table holding worker stats.

    Returns:
        A list of migrations; each migration is a list of composed,
        ready-to-execute SQL statements.
    """
    # Migration 1: the original DDL — jobs table, its dequeue index,
    # and the stats table.
    create_jobs_table = SQL(
        dedent("""
CREATE TABLE IF NOT EXISTS {jobs_table} (
    key TEXT PRIMARY KEY,
    lock_key SERIAL NOT NULL,
    job BYTEA NOT NULL,
    queue TEXT NOT NULL,
    status TEXT NOT NULL,
    priority SMALLINT NOT NULL DEFAULT 0,
    group_key TEXT,
    scheduled BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
    expire_at BIGINT
);
""")
    ).format(jobs_table=jobs_table)

    create_dequeue_index = SQL(
        dedent("""
CREATE INDEX IF NOT EXISTS saq_jobs_dequeue_idx ON {jobs_table} (status, queue, scheduled);
""")
    ).format(jobs_table=jobs_table)

    create_stats_table = SQL(
        dedent("""
CREATE TABLE IF NOT EXISTS {stats_table} (
    worker_id TEXT PRIMARY KEY,
    stats JSONB,
    expire_at BIGINT
);
""")
    ).format(stats_table=stats_table)

    return [
        [create_jobs_table, create_dequeue_index, create_stats_table],
    ]