Skip to content

Commit

Permalink
Support migration progress tracking (#141)
Browse files Browse the repository at this point in the history
Add `progress_func` argument to schema.upgrade
  • Loading branch information
margrietpalm authored Nov 26, 2024
1 parent 10a51dc commit e08ed14
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 4 deletions.
18 changes: 14 additions & 4 deletions threedi_schema/application/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..infrastructure.spatial_index import ensure_spatial_indexes
from ..infrastructure.spatialite_versions import copy_models, get_spatialite_version
from .errors import MigrationMissingError, UpgradeFailedError
from .upgrade_utils import setup_logging

__all__ = ["ModelSchema"]

Expand All @@ -39,11 +40,12 @@ def get_schema_version():
return int(env.get_head_revision())


def _upgrade_database(db, revision="head", unsafe=True):
def _upgrade_database(db, revision="head", unsafe=True, progress_func=None):
    """Run the alembic upgrade for a ThreediDatabase instance.

    Args:
        db: database whose ``engine`` is used to build the alembic config and
            whose ``schema`` is handed to the progress logging setup.
        revision: target revision passed on to alembic (default ``"head"``).
        unsafe: forwarded unchanged to ``get_alembic_config``.
        progress_func: optional callable that receives progress updates; when
            it is ``None`` no progress logging is configured.
    """
    config = get_alembic_config(db.engine, unsafe=unsafe)

    # Progress reporting is opt-in: only hook into the logging when asked to.
    if progress_func is not None:
        setup_logging(db.schema, revision, config, progress_func)

    alembic_command.upgrade(config, revision)


Expand Down Expand Up @@ -88,6 +90,7 @@ def upgrade(
backup=True,
upgrade_spatialite_version=False,
convert_to_geopackage=False,
progress_func=None,
):
"""Upgrade the database to the latest version.
Expand All @@ -106,6 +109,9 @@ def upgrade(
Specify 'convert_to_geopackage=True' to also convert from spatialite
to geopackage file version after the upgrade.
Specify a 'progress_func' to handle progress updates. `progress_func` should
accept a single argument: the progress as a percentage (a float between 0 and 100).
"""
try:
rev_nr = get_schema_version() if revision == "head" else int(revision)
Expand All @@ -127,9 +133,13 @@ def upgrade(
)
if backup:
with self.db.file_transaction() as work_db:
_upgrade_database(work_db, revision=revision, unsafe=True)
_upgrade_database(
work_db, revision=revision, unsafe=True, progress_func=progress_func
)
else:
_upgrade_database(self.db, revision=revision, unsafe=False)
_upgrade_database(
self.db, revision=revision, unsafe=False, progress_func=progress_func
)
if upgrade_spatialite_version:
self.upgrade_spatialite_version()
elif convert_to_geopackage:
Expand Down
81 changes: 81 additions & 0 deletions threedi_schema/application/upgrade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import logging
from typing import Callable, TYPE_CHECKING

from alembic.config import Config
from alembic.script import ScriptDirectory

if TYPE_CHECKING:
from .schema import ModelSchema
else:
ModelSchema = None


class ProgressHandler(logging.Handler):
    """Logging handler that turns "Running upgrade" log records into progress updates.

    Every record whose message starts with ``"Running upgrade"`` counts as one
    step. For each such record ``progress_func`` is called with the percentage
    of steps that were already counted *before* this record, so the first call
    reports ``0.0`` and ``100`` itself is never reported.

    Args:
        progress_func: callable taking a single float, the progress percentage.
        total_steps: expected number of steps; used as the denominator.
    """

    def __init__(self, progress_func, total_steps):
        super().__init__()
        self.progress_func = progress_func
        self.total_steps = total_steps
        # Number of "Running upgrade" records seen so far.
        self.current_step = 0

    def emit(self, record):
        """Report progress when ``record`` announces an upgrade step."""
        msg = record.getMessage()
        if msg.startswith("Running upgrade"):
            # Without a positive step count no percentage can be computed;
            # skip the callback instead of raising ZeroDivisionError from
            # inside the logging call (which would abort the caller).
            if self.total_steps > 0:
                self.progress_func(100 * self.current_step / self.total_steps)
            self.current_step += 1


def get_upgrade_steps_count(
    config: Config, current_revision: int, target_revision: str = "head"
) -> int:
    """
    Count number of upgrade steps for a schematisation upgrade.

    Args:
        config: Config parameter containing the configuration information
        current_revision: current revision as integer
        target_revision: target revision as zero-padded 4 digit string or "head"

    Returns:
        The number of upgrade steps, or 0 when ``target_revision`` cannot be
        interpreted as a number or lies before ``current_revision``.
    """
    target_number = None
    if target_revision != "head":
        try:
            target_number = int(target_revision)
        except (TypeError, ValueError):
            # int() raises TypeError for e.g. None and ValueError for a
            # non-numeric string. Either way this should lead to issues in the
            # upgrade pipeline; do not take over that error handling here.
            return 0
    # walk_revisions also includes the revision from current_revision to previous
    # reduce the number of steps with 1
    offset = -1
    # The first defined revision is 200; revision numbers < 200 will cause walk_revisions to fail
    if current_revision < 200:
        current_revision = 200
        # set offset to 0 because previous to current is not included in walk_revisions
        offset = 0
    if target_number is not None and target_number < current_revision:
        # assume that this will be correctly handled by alembic
        return 0
    current_revision_str = f"{current_revision:04d}"
    script = ScriptDirectory.from_config(config)
    # Determine upgrade steps
    revisions = script.walk_revisions(current_revision_str, target_revision)
    return len(list(revisions)) + offset


def setup_logging(
    schema: ModelSchema,
    target_revision: str,
    config: Config,
    progress_func: Callable[[float], None],
):
    """
    Set up logging for schematisation upgrade

    Attaches a ProgressHandler to the "alembic.runtime.migration" logger and
    sets that logger's level to INFO.

    Args:
        schema: ModelSchema object; its ``get_version()`` supplies the current revision
        target_revision: A str specifying the target revision for migration
        config: Config object containing configuration settings
        progress_func: A Callable with a single argument of type float (the
            progress percentage), used to track progress during migration
    """
    n_steps = get_upgrade_steps_count(config, schema.get_version(), target_revision)
    logger = logging.getLogger("alembic.runtime.migration")
    logger.setLevel(logging.INFO)
    # The logger is shared by the whole process. Remove progress handlers left
    # behind by earlier upgrades, otherwise their callbacks would keep being
    # called (with step counts beyond their own total) on every later upgrade.
    for stale in [h for h in logger.handlers if isinstance(h, ProgressHandler)]:
        logger.removeHandler(stale)
    logger.addHandler(ProgressHandler(progress_func, total_steps=n_steps))
56 changes: 56 additions & 0 deletions threedi_schema/tests/test_upgrade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from pathlib import Path
from unittest.mock import call, MagicMock

import pytest

from threedi_schema.application import upgrade_utils
from threedi_schema.application.schema import get_alembic_config
from threedi_schema.application.threedi_database import ThreediDatabase

data_dir = Path(__file__).parent / "data"


def test_progress_handler():
    """ProgressHandler reports a percentage for each "Running upgrade" record only."""
    progress_func = MagicMock()
    handler = upgrade_utils.ProgressHandler(progress_func, total_steps=5)

    def make_record(msg):
        # Use real log records: with a MagicMock record, getMessage().startswith()
        # returns a truthy mock and the message filter would never be exercised.
        return logging.LogRecord(
            name="alembic.runtime.migration",
            level=logging.INFO,
            pathname=__file__,
            lineno=0,
            msg=msg,
            args=(),
            exc_info=None,
        )

    for i in range(5):
        # Records with another message must not count as a step.
        handler.handle(make_record("Some unrelated message"))
        handler.handle(make_record(f"Running upgrade {i} -> {i + 1}"))

    expected_calls = [call(100 * i / 5) for i in range(5)]
    assert progress_func.call_args_list == expected_calls


@pytest.mark.parametrize(
    "target_revision, nsteps_expected", [("0226", 5), ("0200", 0), (None, 0)]
)
def test_get_upgrade_steps_count(target_revision, nsteps_expected):
    """Check the step count for a later, an earlier and an unusable target revision."""
    sqlite_path = data_dir / "v2_bergermeer_221.sqlite"
    db_schema = ThreediDatabase(sqlite_path).schema
    alembic_config = get_alembic_config()
    current = db_schema.get_version()
    counted = upgrade_utils.get_upgrade_steps_count(
        config=alembic_config,
        current_revision=current,
        target_revision=target_revision,
    )
    assert counted == nsteps_expected


def test_get_upgrade_steps_count_pre_200(oldest_sqlite):
    """Check the step count to "0226" when starting from the ``oldest_sqlite`` fixture."""
    alembic_config = get_alembic_config()
    start_revision = oldest_sqlite.schema.get_version()
    counted = upgrade_utils.get_upgrade_steps_count(
        config=alembic_config,
        current_revision=start_revision,
        target_revision="0226",
    )
    assert counted == 27


def test_upgrade_with_progress_func(oldest_sqlite):
    """Upgrading the oldest fixture to "0201" reports progress 0.0 and then 50.0."""
    reported = MagicMock()
    oldest_sqlite.schema.upgrade(
        revision="0201",
        backup=False,
        upgrade_spatialite_version=False,
        progress_func=reported,
    )
    expected = [call(0.0), call(50.0)]
    assert reported.call_args_list == expected

0 comments on commit e08ed14

Please sign in to comment.