Skip to content

Commit

Permalink
Fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoImhof committed Jan 13, 2025
1 parent d2a28d8 commit 6d39941
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
24 changes: 13 additions & 11 deletions src/adapters/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .mistral.mixin_mistral import MistralModelAdapterMixin
from .mllama.mixin_mllama import (
MllamaAdaptersMixin,
MllamaCrossAttentionDecoderLayerAdaptersMixin,
MllamaSelfAttentionDecoderLayerAdaptersMixin,
MllamaTextCrossAttentionAdaptersMixin,
MllamaTextModelAdaptersMixin,
MllamaTextSelfAttentionAdaptersMixin,
MllamaVisionAttentionAdaptersMixin,
MllamaVisionEncoderAdaptersMixin,
MllamaVisionEncoderLayerAdaptersMixin,
MllamaVisionModelAdaptersMixin,
)
from .plbart.mixin_plbart import (
PLBartDecoderAdaptersMixin,
PLBartDecoderWrapperAdaptersMixin,
Expand All @@ -42,17 +54,6 @@
)
from .xmod.mixin_xmod import XmodModelAdaptersMixin

from .mllama.mixin_mllama import (
MllamaCrossAttentionDecoderLayerAdaptersMixin,
MllamaSelfAttentionDecoderLayerAdaptersMixin,
MllamaTextCrossAttentionAdaptersMixin,
MllamaTextModelAdaptersMixin,
MllamaTextSelfAttentionAdaptersMixin,
MllamaVisionAttentionAdaptersMixin,
MllamaVisionEncoderAdaptersMixin,
MllamaVisionEncoderLayerAdaptersMixin,
MllamaVisionModelAdaptersMixin,
)

# IMPORTANT: Only add classes to this mapping that are not copied into the adapters package
MODEL_MIXIN_MAPPING = {
Expand Down Expand Up @@ -121,6 +122,7 @@
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
"MistralModel": MistralModelAdapterMixin,
# Mulitmodal Llama
"MllamaModel": MllamaAdaptersMixin,
"MllamaVisionModel": MllamaVisionModelAdaptersMixin,
"MllamaTextModel": MllamaTextModelAdaptersMixin,
"MllamaVisionEncoder": MllamaVisionEncoderAdaptersMixin,
Expand Down
2 changes: 1 addition & 1 deletion src/adapters/models/auto/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mistral", "MistralAdapterModel"),
("mllama", "MllamaAdapterModel")
("mllama", "MllamaAdapterModel"),
("mt5", "MT5AdapterModel"),
("plbart", "PLBartAdapterModel"),
("roberta", "RobertaAdapterModel"),
Expand Down
11 changes: 6 additions & 5 deletions src/adapters/models/mllama/mixin_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ...methods.reft import ReftLayer, hook_fn
from ...model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
Expand Down Expand Up @@ -65,7 +66,7 @@ def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)

# Register hook for post embedding forward
self.embeddings.register_forward_hook(self.post_embedding_forward)
self.embed_tokens.register_forward_hook(self.post_embedding_forward)

def post_embedding_forward(self, module, args, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
Expand All @@ -77,23 +78,23 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
yield i, layer


class MllamaAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
class MllamaAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
"""
Adds adapters to the MLLaMA model, handling both vision and text components.
"""

invertible_adapters_base_name = "language_model" # Changed from text_model to match MLLaMA's naming
invertible_adapters_base_name = "language_model"
support_prompt_tuning = False

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
layer_idx = 0

# First iterate through vision model's local transformer layers
for _, layer in enumerate(self.vision_model.iter_layers()):
for _, layer in self.vision_model.iter_layers():
yield layer_idx, layer
layer_idx += 1

for _, layer in enumerate(self.language_model.layers):
for _, layer in self.language_model.iter_layers():
yield layer_idx, layer
layer_idx += 1

Expand Down

0 comments on commit 6d39941

Please sign in to comment.