From cd95c0691da00478892e87c4a99ce975d5f98cbd Mon Sep 17 00:00:00 2001
From: julian fong <cyfongj@gmail.com>
Date: Sat, 11 Jan 2025 14:03:05 -0500
Subject: [PATCH] Configure Vera tests for #740

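Add a VeraTestMixin exercising add, leave-out, linear average, delete, get,
forward, load, full-model load, train, merge and reset for VeraConfig, and
register a Vera test class in the method test generator behind the `vera`
pytest marker.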
---
 tests/test_methods/generator.py               | 13 ++++++
 .../method_test_impl/peft/test_vera.py        | 46 +++++++++++++++++++
 2 files changed, 59 insertions(+)
 create mode 100644 tests/test_methods/method_test_impl/peft/test_vera.py

diff --git a/tests/test_methods/generator.py b/tests/test_methods/generator.py
index 3af8f82b7..cf3947aad 100644
--- a/tests/test_methods/generator.py
+++ b/tests/test_methods/generator.py
@@ -21,6 +21,7 @@
 from tests.test_methods.method_test_impl.peft.test_config_union import ConfigUnionAdapterTest
 from tests.test_methods.method_test_impl.peft.test_ia3 import IA3TestMixin
 from tests.test_methods.method_test_impl.peft.test_lora import LoRATestMixin
 from tests.test_methods.method_test_impl.peft.test_prefix_tuning import PrefixTuningTestMixin
 from tests.test_methods.method_test_impl.peft.test_prompt_tuning import PromptTuningTestMixin
 from tests.test_methods.method_test_impl.peft.test_reft import ReftTestMixin
+from tests.test_methods.method_test_impl.peft.test_vera import VeraTestMixin
@@ -191,6 +192,18 @@ class IA3(
     if "IA3" not in excluded_tests:
         test_classes["IA3"] = IA3
 
+    @require_torch
+    @pytest.mark.vera
+    class Vera(
+        model_test_base,
+        VeraTestMixin,
+        unittest.TestCase,
+    ):
+        pass
+
+    if "Vera" not in excluded_tests:
+        test_classes["Vera"] = Vera
+
     @require_torch
     @pytest.mark.lora
     class LoRA(
diff --git a/tests/test_methods/method_test_impl/peft/test_vera.py b/tests/test_methods/method_test_impl/peft/test_vera.py
new file mode 100644
index 000000000..1c70b6bfa
--- /dev/null
+++ b/tests/test_methods/method_test_impl/peft/test_vera.py
@@ -0,0 +1,46 @@
+from adapters import VeraConfig
+from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin
+from transformers.testing_utils import require_torch
+
+
+@require_torch
+class VeraTestMixin(AdapterMethodBaseTestMixin):
+    def test_add_vera(self):
+        model = self.get_model()
+        self.run_add_test(model, VeraConfig(), ["loras.{name}."])
+
+    def test_leave_out_vera(self):
+        model = self.get_model()
+        self.run_leave_out_test(model, VeraConfig(), self.leave_out_layers)
+
+    def test_linear_average_vera(self):
+        model = self.get_model()
+        self.run_linear_average_test(model, VeraConfig(), ["loras.{name}."])
+
+    def test_delete_vera(self):
+        model = self.get_model()
+        self.run_delete_test(model, VeraConfig(), ["loras.{name}."])
+
+    def test_get_vera(self):
+        model = self.get_model()
+        n_layers = len(list(model.iter_layers()))
+        self.run_get_test(model, VeraConfig(intermediate_lora=True, output_lora=True), n_layers * 3)
+
+    def test_forward_vera(self):
+        model = self.get_model()
+        self.run_forward_test(model, VeraConfig(init_weights="vera", intermediate_lora=True, output_lora=True))
+
+    def test_load_vera(self):
+        self.run_load_test(VeraConfig())
+
+    def test_load_full_model_vera(self):
+        self.run_full_model_load_test(VeraConfig(init_weights="vera"))
+
+    def test_train_vera(self):
+        self.run_train_test(VeraConfig(init_weights="vera"), ["loras.{name}."])
+
+    def test_merge_vera(self):
+        self.run_merge_test(VeraConfig(init_weights="vera"))
+
+    def test_reset_vera(self):
+        self.run_reset_test(VeraConfig(init_weights="vera"))
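
Note (not part of the patch): a minimal sketch of using the configuration these
tests exercise, assuming an installed `adapters` package and any compatible
checkpoint ("bert-base-uncased" is only a placeholder):

    from adapters import AutoAdapterModel, VeraConfig

    # Load a backbone with adapter support and attach a VeRA adapter using the
    # same init option the new tests cover.
    model = AutoAdapterModel.from_pretrained("bert-base-uncased")
    model.add_adapter("vera_adapter", config=VeraConfig(init_weights="vera"))

    # Freeze the backbone weights and train only the adapter parameters.
    model.train_adapter("vera_adapter")

With the `vera` pytest marker in place, the new tests should be selectable via
`pytest -m vera tests/test_methods`.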