Skip to content

Commit

Permalink
Introduce new parameter to differentiate between redundant and not supported tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoImhof committed Jan 19, 2025
1 parent cdd81f9 commit 3738ada
Show file tree
Hide file tree
Showing 21 changed files with 40 additions and 36 deletions.
34 changes: 18 additions & 16 deletions tests/test_methods/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,23 @@

def generate_method_tests(
model_test_base,
excluded_tests=[],
redundant=[],
not_supported=[],
) -> dict:
"""
Generates a set of method test classes for a given model test base.
Args:
model_test_base (type): The base class for the model tests.
excluded_tests (list, optional): A list of test classes to exclude.
redundant (list, optional): A list of redundant tests to exclude. Defaults to [].
not_supported (list, optional): A list of tests that are not supported for the model. Defaults to [].
Returns:
dict: A dictionary mapping test class names to the generated test classes.
"""
test_classes = {}

if "Core" not in excluded_tests:
if "Core" not in redundant and "Core" not in not_supported:

@require_torch
@pytest.mark.core
Expand All @@ -60,7 +62,7 @@ class Core(

test_classes["Core"] = Core

if "Heads" not in excluded_tests:
if "Heads" not in redundant and "Heads" not in not_supported:

@require_torch
@pytest.mark.heads
Expand All @@ -73,7 +75,7 @@ class Heads(

test_classes["Heads"] = Heads

if "Embeddings" not in excluded_tests:
if "Embeddings" not in redundant and "Embeddings" not in not_supported:

@require_torch
@pytest.mark.embeddings
Expand All @@ -86,7 +88,7 @@ class Embeddings(

test_classes["Embeddings"] = Embeddings

if "Composition" not in excluded_tests:
if "Composition" not in redundant and "Composition" not in not_supported:

@require_torch
@pytest.mark.composition
Expand All @@ -100,7 +102,7 @@ class Composition(

test_classes["Composition"] = Composition

if "ClassConversion" not in excluded_tests:
if "ClassConversion" not in redundant and "ClassConversion" not in not_supported:

@require_torch
class ClassConversion(
Expand All @@ -112,7 +114,7 @@ class ClassConversion(

test_classes["ClassConversion"] = ClassConversion

if "PrefixTuning" not in excluded_tests:
if "PrefixTuning" not in redundant and "PrefixTuning" not in not_supported:

@require_torch
@pytest.mark.prefix_tuning
Expand All @@ -125,7 +127,7 @@ class PrefixTuning(

test_classes["PrefixTuning"] = PrefixTuning

if "PromptTuning" not in excluded_tests:
if "PromptTuning" not in redundant and "PromptTuning" not in not_supported:

@require_torch
@pytest.mark.prompt_tuning
Expand All @@ -138,7 +140,7 @@ class PromptTuning(

test_classes["PromptTuning"] = PromptTuning

if "ReFT" not in excluded_tests:
if "ReFT" not in redundant and "ReFT" not in not_supported:

@require_torch
@pytest.mark.reft
Expand All @@ -151,7 +153,7 @@ class ReFT(

test_classes["ReFT"] = ReFT

if "UniPELT" not in excluded_tests:
if "UniPELT" not in redundant and "UniPELT" not in not_supported:

@require_torch
@pytest.mark.unipelt
Expand All @@ -164,7 +166,7 @@ class UniPELT(

test_classes["UniPELT"] = UniPELT

if "Compacter" not in excluded_tests:
if "Compacter" not in redundant and "Compacter" not in not_supported:

@require_torch
@pytest.mark.compacter
Expand All @@ -177,7 +179,7 @@ class Compacter(

test_classes["Compacter"] = Compacter

if "Bottleneck" not in excluded_tests:
if "Bottleneck" not in redundant and "Bottleneck" not in not_supported:

@require_torch
@pytest.mark.bottleneck
Expand All @@ -190,7 +192,7 @@ class Bottleneck(

test_classes["Bottleneck"] = Bottleneck

if "IA3" not in excluded_tests:
if "IA3" not in redundant and "IA3" not in not_supported:

@require_torch
@pytest.mark.ia3
Expand All @@ -203,7 +205,7 @@ class IA3(

test_classes["IA3"] = IA3

if "LoRA" not in excluded_tests:
if "LoRA" not in redundant and "LoRA" not in not_supported:

@require_torch
@pytest.mark.lora
Expand All @@ -216,7 +218,7 @@ class LoRA(

test_classes["LoRA"] = LoRA

if "ConfigUnion" not in excluded_tests:
if "ConfigUnion" not in redundant and "ConfigUnion" not in not_supported:

@require_torch
@pytest.mark.config_union
Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class AlbertAdapterTestBase(TextAdapterTestBase):
leave_out_layers = [0]


method_tests = generate_method_tests(AlbertAdapterTestBase, excluded_tests=["Heads"])
method_tests = generate_method_tests(AlbertAdapterTestBase, not_supported=["Heads"])

for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
Expand Down
4 changes: 3 additions & 1 deletion tests/test_methods/test_on_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class BartAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "facebook/bart-base"


method_tests = generate_method_tests(BartAdapterTestBase, excluded_tests=["PromptTuning"])
method_tests = generate_method_tests(
BartAdapterTestBase, not_supported=["PromptTuning"], redundant=["ConfigUnion", "Embeddings"]
)

for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BeitAdapterTestBase(VisionAdapterTestBase):
feature_extractor_name = "microsoft/beit-base-patch16-224-pt22k"


method_tests = generate_method_tests(BeitAdapterTestBase, excluded_tests=["Composition", "Embeddings"])
method_tests = generate_method_tests(BeitAdapterTestBase, not_supported=["Composition", "Embeddings"])

for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_clip/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_load_adapter_setup(self):

method_tests = generate_method_tests(
model_test_base=CLIPAdapterTestBase,
excluded_tests=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_clip/test_textmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CLIPTextAdapterTestBase(TextAdapterTestBase):

method_tests = generate_method_tests(
model_test_base=CLIPTextAdapterTestBase,
excluded_tests=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CLIPTextWithProjectionAdapterTestBase(TextAdapterTestBase):

method_tests = generate_method_tests(
model_test_base=CLIPTextWithProjectionAdapterTestBase,
excluded_tests=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_clip/test_visionmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class CLIPVisionAdapterTestBase(VisionAdapterTestBase):

method_tests = generate_method_tests(
model_test_base=CLIPVisionAdapterTestBase,
excluded_tests=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class CLIPVisionWithProjectionAdapterTestBase(VisionAdapterTestBase):

method_tests = generate_method_tests(
model_test_base=CLIPVisionWithProjectionAdapterTestBase,
excluded_tests=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"],
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_output_adapter_fusion_attentions(self):

test_methods = generate_method_tests(
EncoderDecoderAdapterTestBase,
excluded_tests=["Heads", "ConfigUnion", "Embeddings", "Composition", "PromptTuning", "ClassConversion"],
not_supported=["Heads", "ConfigUnion", "Embeddings", "Composition", "PromptTuning", "ClassConversion"],
)

for test_class_name, test_class in test_methods.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_parallel_training_lora(self):
self.skipTest("Not supported for GPT2")


method_tests = generate_method_tests(GPT2AdapterTestBase, excluded_tests=["PromptTuning"])
method_tests = generate_method_tests(GPT2AdapterTestBase, not_supported=["PromptTuning"])

for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LlamaAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "openlm-research/open_llama_13b"


method_tests = generate_method_tests(LlamaAdapterTestBase, excluded_tests=["PromptTuning"])
method_tests = generate_method_tests(LlamaAdapterTestBase, not_supported=["PromptTuning"])

for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
Expand Down
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_parallel_training_lora(self):


method_tests = generate_method_tests(
MBartAdapterTestBase, excluded_tests=["ConfigUnion", "Embeddings", "PromptTuning"]
MBartAdapterTestBase, redundant=["ConfigUnion", "Embeddings"], not_supported=["PromptTuning"]
)
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MistralAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "HuggingFaceH4/zephyr-7b-beta"


test_methods = generate_method_tests(MistralAdapterTestBase, excluded_tests=["PromptTuning", "ConfigUnion"])
test_methods = generate_method_tests(MistralAdapterTestBase, not_supported=["PromptTuning", "ConfigUnion"])

for test_class_name, test_class in test_methods.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_mt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MT5AdapterTestBase(TextAdapterTestBase):
tokenizer_name = "google/mt5-base"


method_tests = generate_method_tests(MT5AdapterTestBase, excluded_tests=["PromptTuning", "ConfigUnion"])
method_tests = generate_method_tests(MT5AdapterTestBase, not_supported=["PromptTuning", "ConfigUnion"])

for test_name, test_class in method_tests.items():
globals()[test_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_plbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PLBartAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "uclanlp/plbart-base"


method_tests = generate_method_tests(PLBartAdapterTestBase, excluded_tests=["PromptTuning"])
method_tests = generate_method_tests(PLBartAdapterTestBase, not_supported=["PromptTuning"])

for test_name, test_class in method_tests.items():
globals()[test_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ class T5AdapterTestBase(TextAdapterTestBase):
tokenizer_name = "t5-base"


method_tests = generate_method_tests(T5AdapterTestBase, excluded_tests=["ConfigUnion", "PromptTuning"])
method_tests = generate_method_tests(T5AdapterTestBase, not_supported=["ConfigUnion", "PromptTuning"])
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class ViTAdapterTestBase(VisionAdapterTestBase):
feature_extractor_name = "google/vit-base-patch16-224-in21k"


method_tests = generate_method_tests(ViTAdapterTestBase, excluded_tests=["ConfigUnion", "Embeddings", "Composition"])
method_tests = generate_method_tests(ViTAdapterTestBase, not_supported=["ConfigUnion", "Embeddings", "Composition"])
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def test_parallel_training_lora(self):
self.skipTest("Not supported for Whisper")


method_tests = generate_method_tests(WhisperAdapterTestBase, excluded_tests=["PromptTuning"])
method_tests = generate_method_tests(WhisperAdapterTestBase, not_supported=["PromptTuning"])
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_xlm_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ class XLMRobertaAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "xlm-roberta-base"


method_tests = generate_method_tests(XLMRobertaAdapterTestBase, excluded_tests=["ConfigUnion", "Embeddings"])
method_tests = generate_method_tests(XLMRobertaAdapterTestBase, redundant=["ConfigUnion", "Embeddings"])
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class
2 changes: 1 addition & 1 deletion tests/test_methods/test_on_xmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ class XmodAdapterTestBase(TextAdapterTestBase):
tokenizer_name = "xlm-roberta-base"


method_tests = generate_method_tests(XmodAdapterTestBase, excluded_tests=["ConfigUnion", "Embeddings"])
method_tests = generate_method_tests(XmodAdapterTestBase, not_supported=["ConfigUnion", "Embeddings"])
for test_class_name, test_class in method_tests.items():
globals()[test_class_name] = test_class

0 comments on commit 3738ada

Please sign in to comment.