
Commit

add emb training support
calpt committed Dec 24, 2024
1 parent 62c1f83 commit 1f0d3ef
Showing 4 changed files with 17 additions and 14 deletions.
src/adapters/wrappers/model.py: 3 changes (2 additions & 1 deletion)
@@ -11,6 +11,7 @@
from ..configuration import ModelAdaptersConfig
from ..interface import AdapterModelInterface
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
ModelAdaptersMixin,
ModelBaseAdaptersMixin,
@@ -63,7 +64,7 @@ def init(
model_class_name = base_model.__class__.__name__
model_class = type(
model_class_name,
(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin, base_model.__class__),
(EmbeddingAdaptersMixin, ModelBaseAdaptersMixin, base_model.__class__),
{},
)
base_model.__class__ = model_class
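
For context: init() dynamically subclasses the wrapped model, so replacing EmbeddingAdaptersWrapperMixin with EmbeddingAdaptersMixin puts the embedding management methods on the wrapped model class itself, which is what makes embedding training available for models set up this way. A minimal usage sketch of what this enables; model, tokenizer, and interface (an AdapterModelInterface for the architecture) are assumed to exist and are not part of this commit:

import adapters

adapters.init(model, interface=interface)             # wrapped class now carries EmbeddingAdaptersMixin

model.add_embeddings("custom", tokenizer)             # register a new embedding matrix (EmbeddingAdaptersMixin)
assert model.active_embeddings == "custom"
model.add_adapter("custom")
model.train_adapter("custom", train_embeddings=True)  # train the adapter and the active embedding together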
tests/methods/base.py: 2 changes (1 addition & 1 deletion)
@@ -250,7 +250,7 @@ def run_full_model_load_test(self, adapter_config):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config):
def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
if self.config_class not in ADAPTER_MODEL_MAPPING:
self.skipTest("Does not support flex heads.")
model = AutoAdapterModel.from_config(self.config())
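
Making adapter_config optional lets shared test mixins that do not depend on a particular adapter method (such as the embedding tests) reuse this helper, while method-specific tests keep passing a config explicitly. For example, with the names used in these tests:

# method-specific tests: pass the config under test
model = self._init_model_for_train_run("test", "dummy", adapter_config)

# embedding training test: rely on the default
model = self._init_model_for_train_run("test", "dummy")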
tests/test_adapter_embeddings.py: 15 changes (8 additions & 7 deletions)
@@ -83,11 +83,10 @@ def test_training_embedding(self):
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoAdapterModel.from_config(self.config())
model = self._init_model_for_train_run("test", "dummy")

model.add_embeddings("test", tokenizer)
self.assertEqual(model.active_embeddings, "test")
model.add_adapter("test")
self.add_head(model, "test")
model.train_adapter("test", train_embeddings=True)

for k, v in filter_parameters(model, "adapters.test.").items():
@@ -103,7 +102,7 @@ def test_training_embedding(self):
training_args = TrainingArguments(
output_dir="./examples",
do_train=True,
learning_rate=0.4,
learning_rate=1.0,
max_steps=15,
no_cuda=True,
per_device_train_batch_size=2,
@@ -138,17 +137,19 @@ def test_training_embedding(self):
and "embed_tokens" not in k1
and "shared" not in k1
and "wte" not in k1
and "score" not in k1
)
)

def test_reference_embedding(self):
model = AutoAdapterModel.from_config(self.config()) # self.get_model()
model = self.get_model()
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
new_tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT")

model.add_embeddings("test", new_tokenizer, "default", tokenizer)
model.to(torch_device)

default_embedding = model.base_model.loaded_embeddings["default"]
test_embedding = model.base_model.loaded_embeddings["test"]
@@ -163,8 +164,8 @@ def test_reference_embedding(self):
if len(input_test) >= 5:
break

input_default = torch.tensor([input_default])
input_test = torch.tensor([input_test])
input_default = torch.tensor([input_default]).to(torch_device)
input_test = torch.tensor([input_test]).to(torch_device)

default = default_embedding(input_default)
test = test_embedding(input_test)
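
Taken together, test_training_embedding now builds its model through the shared _init_model_for_train_run helper, trains with a larger learning rate so that embedding updates are clearly measurable, and additionally excludes "score" (the Gemma2 sequence classification head) from the parameter-name filter; test_reference_embedding moves the model and the token id tensors onto torch_device. Condensed, the training path now looks roughly like this (the Trainer invocation and the final assertions are paraphrased from the surrounding test, not shown verbatim in the hunks above):

model = self._init_model_for_train_run("test", "dummy")
model.add_embeddings("test", tokenizer)
model.train_adapter("test", train_embeddings=True)

training_args = TrainingArguments(
    output_dir="./examples",
    do_train=True,
    learning_rate=1.0,
    max_steps=15,
    no_cuda=True,
    per_device_train_batch_size=2,
)
# a Trainer run follows; afterwards the test compares parameters before and after training,
# with names containing "embed_tokens", "shared", "wte", and now "score" excluded from the filter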
tests/test_custom_interface.py: 11 changes (6 additions & 5 deletions)
@@ -4,13 +4,14 @@
import torch

import adapters
from adapters import AdapterModelInterface, AdapterSetup, load_model
from adapters import AdapterModelInterface, AdapterSetup, LoRAConfig, load_model
from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.testing_utils import require_torch, torch_device

from .methods import IA3TestMixin, LoRATestMixin, ReftTestMixin, create_twin_models
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_embeddings import EmbeddingTestMixin


class CustomInterfaceModelTestBase(AdapterTestBase):
@@ -57,7 +58,7 @@ class CustomInterfaceModelTest(
# PromptTuningTestMixin,
ReftTestMixin,
# UniPELTTestMixin,
# EmbeddingTestMixin,
EmbeddingTestMixin,
# AdapterFusionModelTestMixin,
# CompabilityTestMixin,
# ParallelAdapterInferenceTestMixin,
@@ -99,12 +100,12 @@ def run_full_model_load_test(self, adapter_config):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config):
def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
model = Gemma2ForSequenceClassification(self.config())
adapters.init(model, interface=self.adapter_interface)

model.add_adapter(trained_adapter_name, config=adapter_config)
model.add_adapter(frozen_adapter_name, config=adapter_config)
model.add_adapter(trained_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
model.add_adapter(frozen_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))

return model
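
Enabling EmbeddingTestMixin here means the embedding tests above now also run against the custom-interface Gemma2 setup, and because that mixin calls _init_model_for_train_run without a config, the helper needs the LoRA fallback added in this hunk. A condensed sketch of what the fallback produces (names as in this diff):

model = Gemma2ForSequenceClassification(self.config())
adapters.init(model, interface=self.adapter_interface)

# adapter_config is None, so both adapters fall back to LoRA:
model.add_adapter("test", config=LoRAConfig(init_weights="bert"))
model.add_adapter("dummy", config=LoRAConfig(init_weights="bert"))

Presumably, init_weights="bert" is chosen so that both LoRA matrices start from small random values rather than zeros and the adapter affects model outputs from the first step; the commit itself does not state the reasoning.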
