From 61ed745f0077a912711cda1a29e65ae3d7ce63b6 Mon Sep 17 00:00:00 2001
From: Martijn Visser
Date: Mon, 3 Feb 2025 10:22:19 +0100
Subject: [PATCH] Fix link_type migration (#2025)

I couldn't write migrated models to disk because the validation failed
because the Link table had `link_type` twice. One version was filled in
with defaults on load since it was missing, and the second one came from
the migration code.

This removes link_type before renaming, just like was done for
min_crest_level, and uses a new method to capture this pattern and test
it.
---
 python/ribasim/ribasim/migrations.py    | 16 ++++++++++++----
 python/ribasim/tests/test_migrations.py |  3 ++-
 python/ribasim/tests/test_schemas.py    | 15 +++++++++++++++
 3 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/python/ribasim/ribasim/migrations.py b/python/ribasim/ribasim/migrations.py
index f818f10d6..81f0e4952 100644
--- a/python/ribasim/ribasim/migrations.py
+++ b/python/ribasim/ribasim/migrations.py
@@ -7,6 +7,16 @@
 # Do the same for write_schema_version in ribasim_qgis/core/geopackage.py
 
 
+def _rename_column(df, from_colname, to_colname):
+    """Rename a column, ensuring we don't end up with two of the same name."""
+    # If a column has a default value, or is nullable, they are always added.
+    # Remove that column first, then rename the old column.
+    # Only call this if from_colname is in the DataFrame.
+    df.drop(columns=to_colname, inplace=True, errors="ignore")
+    df.rename(columns={from_colname: to_colname}, inplace=True, errors="raise")
+    return df
+
+
 def nodeschema_migration(gdf: GeoDataFrame, schema_version: int) -> GeoDataFrame:
     if schema_version == 0 and "node_id" in gdf.columns:
         warnings.warn("Migrating outdated Node table.", UserWarning)
@@ -35,7 +45,7 @@ def linkschema_migration(gdf: GeoDataFrame, schema_version: int) -> GeoDataFrame
         gdf.index.rename("link_id", inplace=True)
     if schema_version < 4 and "edge_type" in gdf.columns:
         warnings.warn("Migrating outdated Link table.", UserWarning)
-        gdf.rename(columns={"edge_type": "link_type"}, inplace=True)
+        _rename_column(gdf, "edge_type", "link_type")
 
     return gdf
 
@@ -91,8 +101,6 @@ def pidcontrolstaticschema_migration(df: DataFrame, schema_version: int) -> Data
 def outletstaticschema_migration(df: DataFrame, schema_version: int) -> DataFrame:
     if schema_version < 2:
         warnings.warn("Migrating outdated Outlet / static table.", UserWarning)
-        # First remove automatically added empty column.
-        df.drop(columns="min_upstream_level", inplace=True, errors="ignore")
-        df.rename(columns={"min_crest_level": "min_upstream_level"}, inplace=True)
+        _rename_column(df, "min_crest_level", "min_upstream_level")
 
     return df
diff --git a/python/ribasim/tests/test_migrations.py b/python/ribasim/tests/test_migrations.py
index 6ea0c4698..17980f47b 100644
--- a/python/ribasim/tests/test_migrations.py
+++ b/python/ribasim/tests/test_migrations.py
@@ -9,7 +9,7 @@
 
 
 @pytest.mark.regression
-def test_hws_migration():
+def test_hws_migration(tmp_path):
     toml_path = root_folder / "models/hws_migration_test/hws.toml"
     db_path = root_folder / "models/hws_migration_test/database.gpkg"
 
@@ -22,3 +22,4 @@ def test_hws_migration():
 
     assert model.link.df.index.name == "link_id"
     assert len(model.link.df) == 454
+    model.write(tmp_path / "hws_migrated.toml")
diff --git a/python/ribasim/tests/test_schemas.py b/python/ribasim/tests/test_schemas.py
index 7693d52cc..e1ab09f88 100644
--- a/python/ribasim/tests/test_schemas.py
+++ b/python/ribasim/tests/test_schemas.py
@@ -1,10 +1,13 @@
 from unittest.mock import patch
 
+import pandas as pd
 import pytest
 import ribasim
+from pandas.testing import assert_frame_equal
 from pydantic import ValidationError
 from ribasim import Model
 from ribasim.db_utils import _get_db_schema_version, _set_db_schema_version
+from ribasim.migrations import _rename_column
 from ribasim.nodes import basin
 from ribasim.schemas import BasinProfileSchema
 from shapely.geometry import Point
@@ -47,3 +50,15 @@ def test_geometry_validation():
         match="Column 'geometry' failed element-wise validator number 0: failure cases",
     ):
         basin.Area(geometry=[Point([1.0, 2.0])])
+
+
+def test_column_rename():
+    df = pd.DataFrame({"edge_type": [1], "link_type": [2]})
+    _rename_column(df, "edge_type", "link_type")
+    assert_frame_equal(df, pd.DataFrame({"link_type": [1]}))
+    df = pd.DataFrame({"edge_type": [1]})
+    _rename_column(df, "edge_type", "link_type")
+    assert_frame_equal(df, pd.DataFrame({"link_type": [1]}))
+    df = pd.DataFrame({"link_type": [2]})
+    with pytest.raises(KeyError, match="\['edge_type'\] not found in axis"):
+        _rename_column(df, "edge_type", "link_type")