Skip to content

Commit

Permalink
configure vera tests to #740
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-fong committed Jan 11, 2025
1 parent ebdf0a7 commit cd95c06
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/test_methods/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tests.test_methods.method_test_impl.peft.test_config_union import ConfigUnionAdapterTest
from tests.test_methods.method_test_impl.peft.test_ia3 import IA3TestMixin
from tests.test_methods.method_test_impl.peft.test_lora import LoRATestMixin
from tests.test_methods.method_test_impl.peft.test_vera import VeraTestMixin
from tests.test_methods.method_test_impl.peft.test_prefix_tuning import PrefixTuningTestMixin
from tests.test_methods.method_test_impl.peft.test_prompt_tuning import PromptTuningTestMixin
from tests.test_methods.method_test_impl.peft.test_reft import ReftTestMixin
Expand Down Expand Up @@ -191,6 +192,18 @@ class IA3(
if "IA3" not in excluded_tests:
test_classes["IA3"] = IA3

@require_torch
@pytest.mark.vera
class Vera(
model_test_base,
VeraTestMixin,
unittest.TestCase,
):
pass

if "Vera" not in excluded_tests:
test_classes["Vera"] = Vera

@require_torch
@pytest.mark.lora
class LoRA(
Expand Down
46 changes: 46 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_vera.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from adapters import VeraConfig
from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin
from transformers.testing_utils import require_torch


@require_torch
class VeraTestMixin(AdapterMethodBaseTestMixin):
def test_add_vera(self):
model = self.get_model()
self.run_add_test(model, VeraConfig(), ["loras.{name}."])

def test_leave_out_vara(self):
model = self.get_model()
self.run_leave_out_test(model, VeraConfig(), self.leave_out_layers)

def test_linear_average_vera(self):
model = self.get_model()
self.run_linear_average_test(model, VeraConfig(), ["loras.{name}."])

def test_delete_vera(self):
model = self.get_model()
self.run_delete_test(model, VeraConfig(), ["loras.{name}."])

def test_get_vera(self):
model = self.get_model()
n_layers = len(list(model.iter_layers()))
self.run_get_test(model, VeraConfig(intermediate_lora=True, output_lora=True), n_layers * 3)

def test_forward_vera(self):
model = self.get_model()
self.run_forward_test(model, VeraConfig(init_weights="vera", intermediate_lora=True, output_lora=True))

def test_load_vera(self):
self.run_load_test(VeraConfig())

def test_load_full_model_vera(self):
self.run_full_model_load_test(VeraConfig(init_weights="vera"))

def test_train_vera(self):
self.run_train_test(VeraConfig(init_weights="vera"), ["loras.{name}."])

def test_merge_vera(self):
self.run_merge_test(VeraConfig(init_weights="vera"))

def test_reset_vera(self):
self.run_reset_test(VeraConfig(init_weights="vera"))

0 comments on commit cd95c06

Please sign in to comment.