-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support migration progress tracking (#141)
Add `progress_func` argument to schema.upgrade
- Loading branch information
1 parent
10a51dc
commit e08ed14
Showing
3 changed files
with
151 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import logging | ||
from typing import Callable, TYPE_CHECKING | ||
|
||
from alembic.config import Config | ||
from alembic.script import ScriptDirectory | ||
|
||
if TYPE_CHECKING: | ||
from .schema import ModelSchema | ||
else: | ||
ModelSchema = None | ||
|
||
|
||
class ProgressHandler(logging.Handler): | ||
def __init__(self, progress_func, total_steps): | ||
super().__init__() | ||
self.progress_func = progress_func | ||
self.total_steps = total_steps | ||
self.current_step = 0 | ||
|
||
def emit(self, record): | ||
msg = record.getMessage() | ||
if msg.startswith("Running upgrade"): | ||
self.progress_func(100 * self.current_step / self.total_steps) | ||
self.current_step += 1 | ||
|
||
|
||
def get_upgrade_steps_count( | ||
config: Config, current_revision: int, target_revision: str = "head" | ||
) -> int: | ||
""" | ||
Count number of upgrade steps for a schematisation upgrade. | ||
Args: | ||
config: Config parameter containing the configuration information | ||
current_revision: current revision as integer | ||
target_revision: target revision as zero-padded 4 digit string or "head" | ||
""" | ||
if target_revision != "head": | ||
try: | ||
int(target_revision) | ||
except TypeError: | ||
# this should lead to issues in the upgrade pipeline, lets not take over that error handling here | ||
return 0 | ||
# walk_revisions also includes the revision from current_revision to previous | ||
# reduce the number of steps with 1 | ||
offset = -1 | ||
# The first defined revision is 200; revision numbers < 200 will cause walk_revisions to fail | ||
if current_revision < 200: | ||
current_revision = 200 | ||
# set offset to 0 because previous to current is not included in walk_revisions | ||
offset = 0 | ||
if target_revision != "head" and int(target_revision) < current_revision: | ||
# assume that this will be correctly handled by alembic | ||
return 0 | ||
current_revision_str = f"{current_revision:04d}" | ||
script = ScriptDirectory.from_config(config) | ||
# Determine upgrade steps | ||
revisions = script.walk_revisions(current_revision_str, target_revision) | ||
return len(list(revisions)) + offset | ||
|
||
|
||
def setup_logging( | ||
schema: ModelSchema, | ||
target_revision: str, | ||
config: Config, | ||
progress_func: Callable[[float], None], | ||
): | ||
""" | ||
Set up logging for schematisation upgrade | ||
Args: | ||
schema: ModelSchema object representing the current schema of the application | ||
target_revision: A str specifying the target revision for migration | ||
config: Config object containing configuration settings | ||
progress_func: A Callable with a single argument of type float, used to track progress during migration | ||
""" | ||
n_steps = get_upgrade_steps_count(config, schema.get_version(), target_revision) | ||
logger = logging.getLogger("alembic.runtime.migration") | ||
logger.setLevel(logging.INFO) | ||
handler = ProgressHandler(progress_func, total_steps=n_steps) | ||
logger.addHandler(handler) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import logging | ||
from pathlib import Path | ||
from unittest.mock import call, MagicMock | ||
|
||
import pytest | ||
|
||
from threedi_schema.application import upgrade_utils | ||
from threedi_schema.application.schema import get_alembic_config | ||
from threedi_schema.application.threedi_database import ThreediDatabase | ||
|
||
data_dir = Path(__file__).parent / "data" | ||
|
||
|
||
def test_progress_handler(): | ||
progress_func = MagicMock() | ||
mock_record = MagicMock(levelno=logging.INFO, percent=40) | ||
expected_calls = [call(100 * i / 5) for i in range(5)] | ||
handler = upgrade_utils.ProgressHandler(progress_func, total_steps=5) | ||
for _ in range(5): | ||
handler.handle(mock_record) | ||
assert progress_func.call_args_list == expected_calls | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_revision, nsteps_expected", [("0226", 5), ("0200", 0), (None, 0)] | ||
) | ||
def test_get_upgrade_steps_count(target_revision, nsteps_expected): | ||
schema = ThreediDatabase(data_dir.joinpath("v2_bergermeer_221.sqlite")).schema | ||
nsteps = upgrade_utils.get_upgrade_steps_count( | ||
config=get_alembic_config(), | ||
current_revision=schema.get_version(), | ||
target_revision=target_revision, | ||
) | ||
assert nsteps == nsteps_expected | ||
|
||
|
||
def test_get_upgrade_steps_count_pre_200(oldest_sqlite): | ||
schema = oldest_sqlite.schema | ||
nsteps = upgrade_utils.get_upgrade_steps_count( | ||
config=get_alembic_config(), | ||
current_revision=schema.get_version(), | ||
target_revision="0226", | ||
) | ||
assert nsteps == 27 | ||
|
||
|
||
def test_upgrade_with_progress_func(oldest_sqlite): | ||
schema = oldest_sqlite.schema | ||
progress_func = MagicMock() | ||
schema.upgrade( | ||
backup=False, | ||
upgrade_spatialite_version=False, | ||
progress_func=progress_func, | ||
revision="0201", | ||
) | ||
assert progress_func.call_args_list == [call(0.0), call(50.0)] |