Pluggable Model Integration Interface #738

Open
wants to merge 34 commits into main from dev/interface
34 commits
6366685  Init pluggable interface (calpt, Aug 24, 2024)
f004a10  Test fixes (calpt, Aug 25, 2024)
3c4c791  doc (calpt, Aug 25, 2024)
9cef6c6  style (calpt, Aug 25, 2024)
2171409  Fix iter_layers & add tests (lenglaender, Aug 29, 2024)
89e7aa7  fix (calpt, Sep 15, 2024)
31c8b2a  Merge branch 'main' into dev/interface (calpt, Dec 23, 2024)
5305c36  wip: bottleneck (calpt, Dec 23, 2024)
0d623ef  Minimal working bottleneck plugin version (calpt, Dec 24, 2024)
5334521  style (calpt, Dec 24, 2024)
62c1f83  attr fix (calpt, Dec 24, 2024)
1f0d3ef  add emb training support (calpt, Dec 24, 2024)
6338a79  style (calpt, Dec 24, 2024)
3e90a3a  simple prompt tuning implementation (calpt, Dec 25, 2024)
ea85e43  Extended interface for more bottleneck support (calpt, Dec 25, 2024)
f4d7967  load_model() concurrency fix (calpt, Dec 25, 2024)
652b562  fix init for replaced classes (calpt, Dec 26, 2024)
e26c425  Merge branch 'main' into dev/interface (calpt, Dec 26, 2024)
05e8b2a  clean up adapters init (calpt, Dec 26, 2024)
b01cf6d  fixes (calpt, Dec 26, 2024)
535dd9c  Add `supports_adapter()` method (calpt, Jan 5, 2025)
7d346db  Rename AdapterType -> AdapterMethod. Test fixes. (calpt, Jan 6, 2025)
3a9e702  WIP: invertible adapters support (calpt, Jan 6, 2025)
074ca66  Add invertible output layer. Test fixes. (calpt, Jan 6, 2025)
5596bf1  Merge branch 'main' into dev/interface (calpt, Jan 9, 2025)
ec4ae1a  Merge branch 'main' into dev/interface (calpt, Jan 29, 2025)
c9098bd  Fix test after refactoring (calpt, Feb 3, 2025)
f7d59c0  remove code (calpt, Feb 3, 2025)
cbf74a9  style (calpt, Feb 3, 2025)
788bc8d  Save & load adapter interface with full model (calpt, Feb 8, 2025)
3d085fd  rename adapter_types -> adapter_methods (calpt, Feb 9, 2025)
f3c43a5  Add documentation (calpt, Feb 9, 2025)
9ae0afe  Update docs/model_overview.md (calpt, Feb 10, 2025)
788a76f  Update docs/plugin_interface.md (calpt, Feb 10, 2025)
add emb training support
calpt committed Dec 24, 2024
commit 1f0d3efd1955992d1a0b57f19ec2ebe00938fb53
src/adapters/wrappers/model.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -11,6 +11,7 @@
 from ..configuration import ModelAdaptersConfig
 from ..interface import AdapterModelInterface
 from ..model_mixin import (
+    EmbeddingAdaptersMixin,
     EmbeddingAdaptersWrapperMixin,
     ModelAdaptersMixin,
     ModelBaseAdaptersMixin,
@@ -63,7 +64,7 @@ def init(
     model_class_name = base_model.__class__.__name__
     model_class = type(
         model_class_name,
-        (EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin, base_model.__class__),
+        (EmbeddingAdaptersMixin, ModelBaseAdaptersMixin, base_model.__class__),
        {},
    )
    base_model.__class__ = model_class
```
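This swap is the core of the embedding-training support: `init()` assigns the dynamically built class back onto the base model itself, so the base model needs `EmbeddingAdaptersMixin` directly rather than the wrapper mixin that only forwards calls to a nested `base_model`. A rough sketch of what this enables through the pluggable interface; the `AdapterModelInterface` field names below are illustrative assumptions for a Gemma/Llama-style decoder, not taken from this diff:

```python
# Rough sketch, not verbatim from this PR: after adapters.init() swaps in the
# EmbeddingAdaptersMixin-based class, embedding management works on a custom
# model wired up via the pluggable interface.
import adapters
from adapters import AdapterModelInterface
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

# Assumed interface schema: each field names a submodule of the custom model.
interface = AdapterModelInterface(
    adapter_methods=["lora"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)
adapters.init(model, interface=interface)

model.add_adapter("my_adapter")
model.add_embeddings("my_emb", tokenizer)  # provided by EmbeddingAdaptersMixin
model.train_adapter("my_adapter", train_embeddings=True)
```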
tests/methods/base.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -250,7 +250,7 @@ def run_full_model_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

-    def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config):
+    def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
         if self.config_class not in ADAPTER_MODEL_MAPPING:
             self.skipTest("Does not support flex heads.")
         model = AutoAdapterModel.from_config(self.config())
```
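Giving `adapter_config` a `None` default lets shared test mixins call this helper without committing to a specific adapter method; `add_adapter` then falls back to the library's default config. A hypothetical call site, mirroring the embedding test below:

```python
# Hypothetical call site: no config passed, so the helper's default applies.
model = self._init_model_for_train_run("test", "dummy")
```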
tests/test_adapter_embeddings.py (15 changes: 8 additions & 7 deletions)

```diff
@@ -83,11 +83,10 @@ def test_training_embedding(self):
         tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-        model = AutoAdapterModel.from_config(self.config())
+        model = self._init_model_for_train_run("test", "dummy")

         model.add_embeddings("test", tokenizer)
         self.assertEqual(model.active_embeddings, "test")
-        model.add_adapter("test")
-        self.add_head(model, "test")
         model.train_adapter("test", train_embeddings=True)

         for k, v in filter_parameters(model, "adapters.test.").items():
@@ -103,7 +102,7 @@ def test_training_embedding(self):
         training_args = TrainingArguments(
             output_dir="./examples",
             do_train=True,
-            learning_rate=0.4,
+            learning_rate=1.0,
             max_steps=15,
             no_cuda=True,
             per_device_train_batch_size=2,
@@ -138,17 +137,19 @@ def test_training_embedding(self):
                 and "embed_tokens" not in k1
                 and "shared" not in k1
                 and "wte" not in k1
+                and "score" not in k1
             )
         )

     def test_reference_embedding(self):
-        model = AutoAdapterModel.from_config(self.config())  # self.get_model()
+        model = self.get_model()
         tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
         new_tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT")

         model.add_embeddings("test", new_tokenizer, "default", tokenizer)
+        model.to(torch_device)

         default_embedding = model.base_model.loaded_embeddings["default"]
         test_embedding = model.base_model.loaded_embeddings["test"]
@@ -163,8 +164,8 @@ def test_reference_embedding(self):
             if len(input_test) >= 5:
                 break

-        input_default = torch.tensor([input_default])
-        input_test = torch.tensor([input_test])
+        input_default = torch.tensor([input_default]).to(torch_device)
+        input_test = torch.tensor([input_test]).to(torch_device)

         default = default_embedding(input_default)
         test = test_embedding(input_test)
```
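The device moves matter because `test_reference_embedding` now runs the model and both embedding lookups on `torch_device`. The invariant it checks: since `add_embeddings("test", new_tokenizer, "default", tokenizer)` initializes the new embedding from a reference, tokens shared by both vocabularies should map to (near-)identical vectors. A condensed sketch of that check, reusing the test's variables; the shared-token copy semantics are my reading of `add_embeddings`, not spelled out in this diff:

```python
# Condensed sketch, reusing model/tokenizer/new_tokenizer from the test above.
# Assumes add_embeddings copied embedding rows for tokens in both vocabularies.
word = "the"  # hypothetical token present in both vocabularies
input_default = torch.tensor([[tokenizer.convert_tokens_to_ids(word)]]).to(torch_device)
input_test = torch.tensor([[new_tokenizer.convert_tokens_to_ids(word)]]).to(torch_device)

default_embedding = model.base_model.loaded_embeddings["default"]
test_embedding = model.base_model.loaded_embeddings["test"]

# The same token through either embedding should yield the same vector.
assert torch.allclose(default_embedding(input_default), test_embedding(input_test))
```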
tests/test_custom_interface.py (11 changes: 6 additions & 5 deletions)

```diff
@@ -4,13 +4,14 @@
 import torch

 import adapters
-from adapters import AdapterModelInterface, AdapterSetup, load_model
+from adapters import AdapterModelInterface, AdapterSetup, LoRAConfig, load_model
 from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
 from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
 from transformers.testing_utils import require_torch, torch_device

 from .methods import IA3TestMixin, LoRATestMixin, ReftTestMixin, create_twin_models
 from .test_adapter import AdapterTestBase, make_config
+from .test_adapter_embeddings import EmbeddingTestMixin


 class CustomInterfaceModelTestBase(AdapterTestBase):
@@ -57,7 +58,7 @@ class CustomInterfaceModelTest(
     # PromptTuningTestMixin,
     ReftTestMixin,
     # UniPELTTestMixin,
-    # EmbeddingTestMixin,
+    EmbeddingTestMixin,
     # AdapterFusionModelTestMixin,
     # CompabilityTestMixin,
     # ParallelAdapterInferenceTestMixin,
@@ -99,12 +100,12 @@ def run_full_model_load_test(self, adapter_config):
         self.assertEqual(len(output1), len(output2))
         self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

-    def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config):
+    def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
         model = Gemma2ForSequenceClassification(self.config())
         adapters.init(model, interface=self.adapter_interface)

-        model.add_adapter(trained_adapter_name, config=adapter_config)
-        model.add_adapter(frozen_adapter_name, config=adapter_config)
+        model.add_adapter(trained_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
+        model.add_adapter(frozen_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))

         return model
```
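Why the `or LoRAConfig(init_weights="bert")` fallback: the newly enabled `EmbeddingTestMixin` calls this helper without a config, so the helper needs a concrete default. My reading (an assumption, not stated in the diff) is that the "bert"-style init is chosen because LoRA's default initialization starts one decomposition matrix at zero, making the adapter a no-op at step 0 and change-detection assertions in training tests unreliable. A sketch of the pattern:

```python
# Sketch of the fallback pattern from the helper above. SeqBnConfig is just an
# illustrative alternative config from the adapters library.
from adapters import LoRAConfig, SeqBnConfig

def resolve_config(adapter_config=None):
    # An explicit config wins; otherwise use a LoRA config with non-zero
    # ("bert"-style) init so the adapter affects outputs immediately.
    return adapter_config or LoRAConfig(init_weights="bert")

resolve_config()               # -> LoRAConfig(init_weights="bert")
resolve_config(SeqBnConfig())  # -> the explicit SeqBnConfig
```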