Skip to content

Commit

Permalink
Fix link_type migration (#2025)
Browse files Browse the repository at this point in the history
I couldn't write migrated models to disk because the validation failed
because the Link table had `link_type` twice. One version was filled in
with defaults on load since it was missing, and the second one came from
the migration code.

This removes link_type before renaming, just as was already done for
min_crest_level, and introduces a helper function that captures this
pattern, along with a test for it.
  • Loading branch information
visr authored Feb 3, 2025
1 parent f85ef8f commit 61ed745
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
16 changes: 12 additions & 4 deletions python/ribasim/ribasim/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
# Do the same for write_schema_version in ribasim_qgis/core/geopackage.py


def _rename_column(df, from_colname, to_colname):
"""Rename a column, ensuring we don't end up with two of the same name."""
# If a column has a default value, or is nullable, they are always added.
# Remove that column first, then rename the old column.
# Only call this if from_colname is in the DataFrame.
df.drop(columns="link_type", inplace=True, errors="ignore")
df.rename(columns={from_colname: to_colname}, inplace=True, errors="raise")
return df


def nodeschema_migration(gdf: GeoDataFrame, schema_version: int) -> GeoDataFrame:
if schema_version == 0 and "node_id" in gdf.columns:
warnings.warn("Migrating outdated Node table.", UserWarning)
Expand Down Expand Up @@ -35,7 +45,7 @@ def linkschema_migration(gdf: GeoDataFrame, schema_version: int) -> GeoDataFrame
gdf.index.rename("link_id", inplace=True)
if schema_version < 4 and "edge_type" in gdf.columns:
warnings.warn("Migrating outdated Link table.", UserWarning)
gdf.rename(columns={"edge_type": "link_type"}, inplace=True)
_rename_column(gdf, "edge_type", "link_type")

return gdf

Expand Down Expand Up @@ -91,8 +101,6 @@ def pidcontrolstaticschema_migration(df: DataFrame, schema_version: int) -> Data
def outletstaticschema_migration(df: DataFrame, schema_version: int) -> DataFrame:
if schema_version < 2:
warnings.warn("Migrating outdated Outlet / static table.", UserWarning)
# First remove automatically added empty column.
df.drop(columns="min_upstream_level", inplace=True, errors="ignore")
df.rename(columns={"min_crest_level": "min_upstream_level"}, inplace=True)
_rename_column(df, "min_crest_level", "min_upstream_level")

return df
3 changes: 2 additions & 1 deletion python/ribasim/tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.mark.regression
def test_hws_migration():
def test_hws_migration(tmp_path):
toml_path = root_folder / "models/hws_migration_test/hws.toml"
db_path = root_folder / "models/hws_migration_test/database.gpkg"

Expand All @@ -22,3 +22,4 @@ def test_hws_migration():

assert model.link.df.index.name == "link_id"
assert len(model.link.df) == 454
model.write(tmp_path / "hws_migrated.toml")
15 changes: 15 additions & 0 deletions python/ribasim/tests/test_schemas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from unittest.mock import patch

import pandas as pd
import pytest
import ribasim
from pandas.testing import assert_frame_equal
from pydantic import ValidationError
from ribasim import Model
from ribasim.db_utils import _get_db_schema_version, _set_db_schema_version
from ribasim.migrations import _rename_column
from ribasim.nodes import basin
from ribasim.schemas import BasinProfileSchema
from shapely.geometry import Point
Expand Down Expand Up @@ -47,3 +50,15 @@ def test_geometry_validation():
match="Column 'geometry' failed element-wise validator number 0: <Check is_correct_geometry_type> failure cases",
):
basin.Area(geometry=[Point([1.0, 2.0])])


def test_column_rename():
    """Check that `_rename_column` never leaves two columns with the target name."""
    expected = pd.DataFrame({"link_type": [1]})

    # Target column already present: it is replaced by the renamed column.
    df = pd.DataFrame({"edge_type": [1], "link_type": [2]})
    _rename_column(df, "edge_type", "link_type")
    assert_frame_equal(df, expected)

    # Target column absent: behaves like a plain rename.
    df = pd.DataFrame({"edge_type": [1]})
    _rename_column(df, "edge_type", "link_type")
    assert_frame_equal(df, expected)

    # Source column missing: the rename must fail with a KeyError.
    # Raw string so the regex escapes are not treated as (invalid) string escapes.
    df = pd.DataFrame({"link_type": [2]})
    with pytest.raises(KeyError, match=r"\['edge_type'\] not found in axis"):
        _rename_column(df, "edge_type", "link_type")

0 comments on commit 61ed745

Please sign in to comment.