From d0f6bc358df6399c5e3c7b535145c07df31e3544 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Mon, 9 Oct 2023 14:02:24 -0400 Subject: [PATCH] test: check weights match (#4649) * test: check weights match * sort weights for sparse * lint --- test/run_tests_model_gen_and_load.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/run_tests_model_gen_and_load.py b/test/run_tests_model_gen_and_load.py index 2315d976d71..aff79e99163 100644 --- a/test/run_tests_model_gen_and_load.py +++ b/test/run_tests_model_gen_and_load.py @@ -179,7 +179,15 @@ def load_model( weights_dir = working_dir / "test_weights" weights_dir.mkdir(parents=True, exist_ok=True) weight_file = str(weights_dir / f"weights_{test_id}.json") + + # Load old weights old_weights = json.load(open(weight_file)) + + # Sort the weights inside both dictionaries based on 'index' + new_weights["weights"].sort(key=lambda x: x["index"]) + old_weights["weights"].sort(key=lambda x: x["index"]) + + # Assert if both sorted weights are equal assert new_weights == old_weights vw.finish() except Exception as e: