From cd95c0691da00478892e87c4a99ce975d5f98cbd Mon Sep 17 00:00:00 2001 From: julian fong <cyfongj@gmail.com> Date: Sat, 11 Jan 2025 14:03:05 -0500 Subject: [PATCH] configure vera tests to #740 --- tests/test_methods/generator.py | 13 ++++++ .../method_test_impl/peft/test_vera.py | 46 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 tests/test_methods/method_test_impl/peft/test_vera.py diff --git a/tests/test_methods/generator.py b/tests/test_methods/generator.py index 3af8f82b7..cf3947aad 100644 --- a/tests/test_methods/generator.py +++ b/tests/test_methods/generator.py @@ -21,6 +21,7 @@ from tests.test_methods.method_test_impl.peft.test_config_union import ConfigUnionAdapterTest from tests.test_methods.method_test_impl.peft.test_ia3 import IA3TestMixin from tests.test_methods.method_test_impl.peft.test_lora import LoRATestMixin +from tests.test_methods.method_test_impl.peft.test_vera import VeraTestMixin from tests.test_methods.method_test_impl.peft.test_prefix_tuning import PrefixTuningTestMixin from tests.test_methods.method_test_impl.peft.test_prompt_tuning import PromptTuningTestMixin from tests.test_methods.method_test_impl.peft.test_reft import ReftTestMixin @@ -191,6 +192,18 @@ class IA3( if "IA3" not in excluded_tests: test_classes["IA3"] = IA3 + @require_torch + @pytest.mark.vera + class Vera( + model_test_base, + VeraTestMixin, + unittest.TestCase, + ): + pass + + if "Vera" not in excluded_tests: + test_classes["Vera"] = Vera + @require_torch @pytest.mark.lora class LoRA( diff --git a/tests/test_methods/method_test_impl/peft/test_vera.py b/tests/test_methods/method_test_impl/peft/test_vera.py new file mode 100644 index 000000000..1c70b6bfa --- /dev/null +++ b/tests/test_methods/method_test_impl/peft/test_vera.py @@ -0,0 +1,46 @@ +from adapters import VeraConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin +from transformers.testing_utils import require_torch + + +@require_torch +class VeraTestMixin(AdapterMethodBaseTestMixin): + def test_add_vera(self): + model = self.get_model() + self.run_add_test(model, VeraConfig(), ["loras.{name}."]) + + def test_leave_out_vara(self): + model = self.get_model() + self.run_leave_out_test(model, VeraConfig(), self.leave_out_layers) + + def test_linear_average_vera(self): + model = self.get_model() + self.run_linear_average_test(model, VeraConfig(), ["loras.{name}."]) + + def test_delete_vera(self): + model = self.get_model() + self.run_delete_test(model, VeraConfig(), ["loras.{name}."]) + + def test_get_vera(self): + model = self.get_model() + n_layers = len(list(model.iter_layers())) + self.run_get_test(model, VeraConfig(intermediate_lora=True, output_lora=True), n_layers * 3) + + def test_forward_vera(self): + model = self.get_model() + self.run_forward_test(model, VeraConfig(init_weights="vera", intermediate_lora=True, output_lora=True)) + + def test_load_vera(self): + self.run_load_test(VeraConfig()) + + def test_load_full_model_vera(self): + self.run_full_model_load_test(VeraConfig(init_weights="vera")) + + def test_train_vera(self): + self.run_train_test(VeraConfig(init_weights="vera"), ["loras.{name}."]) + + def test_merge_vera(self): + self.run_merge_test(VeraConfig(init_weights="vera")) + + def test_reset_vera(self): + self.run_reset_test(VeraConfig(init_weights="vera"))