From bd220f2bfb755500e411b70348c878e4948ebd15 Mon Sep 17 00:00:00 2001 From: Martijn Visser Date: Tue, 30 Jan 2024 12:32:30 +0100 Subject: [PATCH 1/3] Add roundtrip test for model files --- python/ribasim/tests/conftest.py | 5 +++++ python/ribasim/tests/test_io.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/python/ribasim/tests/conftest.py b/python/ribasim/tests/conftest.py index f6dbb730d..2f0f3ae9f 100644 --- a/python/ribasim/tests/conftest.py +++ b/python/ribasim/tests/conftest.py @@ -37,3 +37,8 @@ def discrete_control_of_pid_control() -> ribasim.Model: @pytest.fixture() def level_setpoint_with_minmax() -> ribasim.Model: return ribasim_testmodels.level_setpoint_with_minmax_model() + + +@pytest.fixture() +def trivial() -> ribasim.Model: + return ribasim_testmodels.trivial_model() diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index e86b9456b..eb714eabe 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -1,3 +1,5 @@ +import sqlite3 + import pandas as pd import pytest import ribasim @@ -129,3 +131,33 @@ def test_sort(level_setpoint_with_minmax, tmp_path): table_loaded = model_loaded.discrete_control.condition assert table_loaded.df.iloc[0]["greater_than"] == 5.0 __assert_equal(table.df, table_loaded.df) + + +def test_roundtrip(trivial, tmp_path): + model1 = trivial + model1dir = tmp_path / "model1" + model2dir = tmp_path / "model2" + # read a model and then write it to a different path + model1.write(model1dir / "ribasim.toml") + model2 = ribasim.Model(filepath=model1dir / "ribasim.toml") + model2.write(model2dir / "ribasim.toml") + + assert (model1dir / "database.gpkg").is_file() + assert (model2dir / "database.gpkg").is_file() + + # gpkg_contents contains a last_change column that causes a binary diff + # remove that table so we can check if the rest is the same + for modeldir in [model1dir, model2dir]: + conn = sqlite3.connect(modeldir / "database.gpkg") + cursor = conn.cursor() + cursor.execute("DROP TABLE gpkg_contents") + cursor.execute("VACUUM") + conn.commit() + conn.close() + + assert (model1dir / "ribasim.toml").read_text() == ( + model2dir / "ribasim.toml" + ).read_text() + assert (model1dir / "database.gpkg").read_bytes() == ( + model2dir / "database.gpkg" + ).read_bytes() From 3923cf047325b32009d0d12738fe0719a0e5e2f8 Mon Sep 17 00:00:00 2001 From: Martijn Visser Date: Fri, 9 Feb 2024 17:32:27 +0100 Subject: [PATCH 2/3] Leave the network out of the nodes dict. It is not a node. Just a network. --- python/ribasim/ribasim/model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index b22b9a2df..4f1188409 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -232,6 +232,7 @@ def _write_toml(self, fn: FilePath): return fn def _save(self, directory: DirectoryPath, input_dir: DirectoryPath): + self.network._save(directory, input_dir) for sub in self.nodes().values(): sub._save(directory, input_dir) @@ -239,7 +240,7 @@ def nodes(self): return { k: getattr(self, k) for k in self.model_fields.keys() - if isinstance(getattr(self, k), NodeModel) + if isinstance(getattr(self, k), NodeModel) and k != "network" } def children(self): @@ -279,9 +280,6 @@ def validate_model_node_ids(self): for node in self.nodes().values(): nodetype = node.get_input_type() - if nodetype == "Network": - # skip the reference - continue node_ids_data = set(node.node_ids()) node_ids_network = set( self.network.node.df.loc[self.network.node.df["type"] == nodetype].index From 909deb44870f1fcfeb6b577a3c8afdc28836f6be Mon Sep 17 00:00:00 2001 From: Martijn Visser Date: Fri, 9 Feb 2024 17:52:47 +0100 Subject: [PATCH 3/3] Compare the dataframes instead --- python/ribasim/tests/test_io.py | 50 ++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index bdf50c733..b87d5b0af 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -1,5 +1,3 @@ -import sqlite3 - import pandas as pd import pytest import ribasim @@ -10,14 +8,18 @@ from ribasim import Node, Pump -def __assert_equal(a: DataFrame, b: DataFrame) -> None: +def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: """Like pandas.testing.assert_frame_equal, but ignoring the index.""" if a is None and b is None: return True - # TODO support assert basic == model, ignoring the index for all but node - a = a.reset_index(drop=True) - b = b.reset_index(drop=True) + if is_network: + # We set this on write, needed for GeoPackage. + a.index.name = "fid" + a.index.name = "fid" + else: + a = a.reset_index(drop=True) + b = b.reset_index(drop=True) # avoid comparing datetime64[ns] with datetime64[ms] if "time" in a: @@ -40,8 +42,12 @@ def test_basic(basic, tmp_path): index_a = model_orig.network.node.df.index.to_numpy(int) index_b = model_loaded.network.node.df.index.to_numpy(int) assert_array_equal(index_a, index_b) - __assert_equal(model_orig.network.node.df, model_loaded.network.node.df) - __assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df) + __assert_equal( + model_orig.network.node.df, model_loaded.network.node.df, is_network=True + ) + __assert_equal( + model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True + ) assert model_loaded.basin.time.df is None @@ -58,8 +64,12 @@ def test_basic_transient(basic_transient, tmp_path): model_orig.write(tmp_path / "basic_transient/ribasim.toml") model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml") - __assert_equal(model_orig.network.node.df, model_loaded.network.node.df) - __assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df) + __assert_equal( + model_orig.network.node.df, model_loaded.network.node.df, is_network=True + ) + __assert_equal( + model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True + ) time = model_loaded.basin.time assert model_orig.basin.time.df.time[0] == time.df.time[0] @@ -145,19 +155,13 @@ def test_roundtrip(trivial, tmp_path): assert (model1dir / "database.gpkg").is_file() assert (model2dir / "database.gpkg").is_file() - # gpkg_contents contains a last_change column that causes a binary diff - # remove that table so we can check if the rest is the same - for modeldir in [model1dir, model2dir]: - conn = sqlite3.connect(modeldir / "database.gpkg") - cursor = conn.cursor() - cursor.execute("DROP TABLE gpkg_contents") - cursor.execute("VACUUM") - conn.commit() - conn.close() - assert (model1dir / "ribasim.toml").read_text() == ( model2dir / "ribasim.toml" ).read_text() - assert (model1dir / "database.gpkg").read_bytes() == ( - model2dir / "database.gpkg" - ).read_bytes() + + # check if all tables are the same + __assert_equal(model1.network.node.df, model2.network.node.df, is_network=True) + __assert_equal(model1.network.edge.df, model2.network.edge.df, is_network=True) + for node1, node2 in zip(model1.nodes().values(), model2.nodes().values()): + for table1, table2 in zip(node1.tables(), node2.tables()): + __assert_equal(table1.df, table2.df)