From e08ed14109a72ce30fc3859802601c19c55e26ad Mon Sep 17 00:00:00 2001 From: margrietpalm Date: Tue, 26 Nov 2024 15:08:25 +0100 Subject: [PATCH] Support migration progress tracking (#141) Add `progress_func` argument to schema.upgrade --- threedi_schema/application/schema.py | 18 ++++- threedi_schema/application/upgrade_utils.py | 81 +++++++++++++++++++++ threedi_schema/tests/test_upgrade_utils.py | 56 ++++++++++++++ 3 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 threedi_schema/application/upgrade_utils.py create mode 100644 threedi_schema/tests/test_upgrade_utils.py diff --git a/threedi_schema/application/schema.py b/threedi_schema/application/schema.py index 9442d0a..cd05b7f 100644 --- a/threedi_schema/application/schema.py +++ b/threedi_schema/application/schema.py @@ -17,6 +17,7 @@ from ..infrastructure.spatial_index import ensure_spatial_indexes from ..infrastructure.spatialite_versions import copy_models, get_spatialite_version from .errors import MigrationMissingError, UpgradeFailedError +from .upgrade_utils import setup_logging __all__ = ["ModelSchema"] @@ -39,11 +40,12 @@ def get_schema_version(): return int(env.get_head_revision()) -def _upgrade_database(db, revision="head", unsafe=True): +def _upgrade_database(db, revision="head", unsafe=True, progress_func=None): """Upgrade ThreediDatabase instance""" engine = db.engine - config = get_alembic_config(engine, unsafe=unsafe) + if progress_func is not None: + setup_logging(db.schema, revision, config, progress_func) alembic_command.upgrade(config, revision) @@ -88,6 +90,7 @@ def upgrade( backup=True, upgrade_spatialite_version=False, convert_to_geopackage=False, + progress_func=None, ): """Upgrade the database to the latest version. @@ -106,6 +109,9 @@ def upgrade( Specify 'convert_to_geopackage=True' to also convert from spatialite to geopackage file version after the upgrade. + + Specify a 'progress_func' to handle progress updates. `progress_func` should + expect a single argument representing the fraction of progress """ try: rev_nr = get_schema_version() if revision == "head" else int(revision) @@ -127,9 +133,13 @@ def upgrade( ) if backup: with self.db.file_transaction() as work_db: - _upgrade_database(work_db, revision=revision, unsafe=True) + _upgrade_database( + work_db, revision=revision, unsafe=True, progress_func=progress_func + ) else: - _upgrade_database(self.db, revision=revision, unsafe=False) + _upgrade_database( + self.db, revision=revision, unsafe=False, progress_func=progress_func + ) if upgrade_spatialite_version: self.upgrade_spatialite_version() elif convert_to_geopackage: diff --git a/threedi_schema/application/upgrade_utils.py b/threedi_schema/application/upgrade_utils.py new file mode 100644 index 0000000..cdf39f5 --- /dev/null +++ b/threedi_schema/application/upgrade_utils.py @@ -0,0 +1,81 @@ +import logging +from typing import Callable, TYPE_CHECKING + +from alembic.config import Config +from alembic.script import ScriptDirectory + +if TYPE_CHECKING: + from .schema import ModelSchema +else: + ModelSchema = None + + +class ProgressHandler(logging.Handler): + def __init__(self, progress_func, total_steps): + super().__init__() + self.progress_func = progress_func + self.total_steps = total_steps + self.current_step = 0 + + def emit(self, record): + msg = record.getMessage() + if msg.startswith("Running upgrade"): + self.progress_func(100 * self.current_step / self.total_steps) + self.current_step += 1 + + +def get_upgrade_steps_count( + config: Config, current_revision: int, target_revision: str = "head" +) -> int: + """ + Count number of upgrade steps for a schematisation upgrade. + + Args: + config: Config parameter containing the configuration information + current_revision: current revision as integer + target_revision: target revision as zero-padded 4 digit string or "head" + """ + if target_revision != "head": + try: + int(target_revision) + except TypeError: + # this should lead to issues in the upgrade pipeline, lets not take over that error handling here + return 0 + # walk_revisions also includes the revision from current_revision to previous + # reduce the number of steps with 1 + offset = -1 + # The first defined revision is 200; revision numbers < 200 will cause walk_revisions to fail + if current_revision < 200: + current_revision = 200 + # set offset to 0 because previous to current is not included in walk_revisions + offset = 0 + if target_revision != "head" and int(target_revision) < current_revision: + # assume that this will be correctly handled by alembic + return 0 + current_revision_str = f"{current_revision:04d}" + script = ScriptDirectory.from_config(config) + # Determine upgrade steps + revisions = script.walk_revisions(current_revision_str, target_revision) + return len(list(revisions)) + offset + + +def setup_logging( + schema: ModelSchema, + target_revision: str, + config: Config, + progress_func: Callable[[float], None], +): + """ + Set up logging for schematisation upgrade + + Args: + schema: ModelSchema object representing the current schema of the application + target_revision: A str specifying the target revision for migration + config: Config object containing configuration settings + progress_func: A Callable with a single argument of type float, used to track progress during migration + """ + n_steps = get_upgrade_steps_count(config, schema.get_version(), target_revision) + logger = logging.getLogger("alembic.runtime.migration") + logger.setLevel(logging.INFO) + handler = ProgressHandler(progress_func, total_steps=n_steps) + logger.addHandler(handler) diff --git a/threedi_schema/tests/test_upgrade_utils.py b/threedi_schema/tests/test_upgrade_utils.py new file mode 100644 index 0000000..21463b7 --- /dev/null +++ b/threedi_schema/tests/test_upgrade_utils.py @@ -0,0 +1,56 @@ +import logging +from pathlib import Path +from unittest.mock import call, MagicMock + +import pytest + +from threedi_schema.application import upgrade_utils +from threedi_schema.application.schema import get_alembic_config +from threedi_schema.application.threedi_database import ThreediDatabase + +data_dir = Path(__file__).parent / "data" + + +def test_progress_handler(): + progress_func = MagicMock() + mock_record = MagicMock(levelno=logging.INFO, percent=40) + expected_calls = [call(100 * i / 5) for i in range(5)] + handler = upgrade_utils.ProgressHandler(progress_func, total_steps=5) + for _ in range(5): + handler.handle(mock_record) + assert progress_func.call_args_list == expected_calls + + +@pytest.mark.parametrize( + "target_revision, nsteps_expected", [("0226", 5), ("0200", 0), (None, 0)] +) +def test_get_upgrade_steps_count(target_revision, nsteps_expected): + schema = ThreediDatabase(data_dir.joinpath("v2_bergermeer_221.sqlite")).schema + nsteps = upgrade_utils.get_upgrade_steps_count( + config=get_alembic_config(), + current_revision=schema.get_version(), + target_revision=target_revision, + ) + assert nsteps == nsteps_expected + + +def test_get_upgrade_steps_count_pre_200(oldest_sqlite): + schema = oldest_sqlite.schema + nsteps = upgrade_utils.get_upgrade_steps_count( + config=get_alembic_config(), + current_revision=schema.get_version(), + target_revision="0226", + ) + assert nsteps == 27 + + +def test_upgrade_with_progress_func(oldest_sqlite): + schema = oldest_sqlite.schema + progress_func = MagicMock() + schema.upgrade( + backup=False, + upgrade_spatialite_version=False, + progress_func=progress_func, + revision="0201", + ) + assert progress_func.call_args_list == [call(0.0), call(50.0)]