From 6d39941f826f1cad061cb414fd3fea9d2f5be572 Mon Sep 17 00:00:00 2001
From: Timo Imhof
Date: Mon, 13 Jan 2025 17:47:42 +0100
Subject: [PATCH] Fix typos

---
 src/adapters/models/__init__.py            | 24 ++++++++++++----------
 src/adapters/models/auto/adapter_model.py  |  2 +-
 src/adapters/models/mllama/mixin_mllama.py | 11 +++++-----
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py
index 52adf5030..2b7c55fc3 100644
--- a/src/adapters/models/__init__.py
+++ b/src/adapters/models/__init__.py
@@ -20,6 +20,18 @@
 from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
 from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
 from .mistral.mixin_mistral import MistralModelAdapterMixin
+from .mllama.mixin_mllama import (
+    MllamaAdaptersMixin,
+    MllamaCrossAttentionDecoderLayerAdaptersMixin,
+    MllamaSelfAttentionDecoderLayerAdaptersMixin,
+    MllamaTextCrossAttentionAdaptersMixin,
+    MllamaTextModelAdaptersMixin,
+    MllamaTextSelfAttentionAdaptersMixin,
+    MllamaVisionAttentionAdaptersMixin,
+    MllamaVisionEncoderAdaptersMixin,
+    MllamaVisionEncoderLayerAdaptersMixin,
+    MllamaVisionModelAdaptersMixin,
+)
 from .plbart.mixin_plbart import (
     PLBartDecoderAdaptersMixin,
     PLBartDecoderWrapperAdaptersMixin,
@@ -42,17 +54,6 @@
 )
 from .xmod.mixin_xmod import XmodModelAdaptersMixin
 
-from .mllama.mixin_mllama import (
-    MllamaCrossAttentionDecoderLayerAdaptersMixin,
-    MllamaSelfAttentionDecoderLayerAdaptersMixin,
-    MllamaTextCrossAttentionAdaptersMixin,
-    MllamaTextModelAdaptersMixin,
-    MllamaTextSelfAttentionAdaptersMixin,
-    MllamaVisionAttentionAdaptersMixin,
-    MllamaVisionEncoderAdaptersMixin,
-    MllamaVisionEncoderLayerAdaptersMixin,
-    MllamaVisionModelAdaptersMixin,
-)
 
 # IMPORTANT: Only add classes to this mapping that are not copied into the adapters package
 MODEL_MIXIN_MAPPING = {
@@ -121,6 +122,7 @@
     "LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
     "MistralModel": MistralModelAdapterMixin,
     # Mulitmodal Llama
+    "MllamaModel": MllamaAdaptersMixin,
     "MllamaVisionModel": MllamaVisionModelAdaptersMixin,
     "MllamaTextModel": MllamaTextModelAdaptersMixin,
     "MllamaVisionEncoder": MllamaVisionEncoderAdaptersMixin,
diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py
index 5f2497ff8..9921b1f87 100644
--- a/src/adapters/models/auto/adapter_model.py
+++ b/src/adapters/models/auto/adapter_model.py
@@ -24,7 +24,7 @@
         ("llama", "LlamaAdapterModel"),
         ("mbart", "MBartAdapterModel"),
         ("mistral", "MistralAdapterModel"),
-        ("mllama", "MllamaAdapterModel")
+        ("mllama", "MllamaAdapterModel"),
         ("mt5", "MT5AdapterModel"),
         ("plbart", "PLBartAdapterModel"),
         ("roberta", "RobertaAdapterModel"),
diff --git a/src/adapters/models/mllama/mixin_mllama.py b/src/adapters/models/mllama/mixin_mllama.py
index 0f629ddc3..27b4aa7cb 100644
--- a/src/adapters/models/mllama/mixin_mllama.py
+++ b/src/adapters/models/mllama/mixin_mllama.py
@@ -5,6 +5,7 @@
 from ...methods.reft import ReftLayer, hook_fn
 from ...model_mixin import (
     EmbeddingAdaptersMixin,
+    EmbeddingAdaptersWrapperMixin,
     InvertibleAdaptersMixin,
     InvertibleAdaptersWrapperMixin,
     ModelBaseAdaptersMixin,
@@ -65,7 +66,7 @@ def init_adapters(self, model_config, adapters_config):
         super().init_adapters(model_config, adapters_config)
 
         # Register hook for post embedding forward
-        self.embeddings.register_forward_hook(self.post_embedding_forward)
+        self.embed_tokens.register_forward_hook(self.post_embedding_forward)
 
     def post_embedding_forward(self, module, args, embedding_output):
         embedding_output = self.invertible_adapters_forward(embedding_output)
@@ -77,23 +78,23 @@ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
             yield i, layer
 
 
-class MllamaAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
+class MllamaAdaptersMixin(EmbeddingAdaptersWrapperMixin, InvertibleAdaptersWrapperMixin, ModelBaseAdaptersMixin):
     """
     Adds adapters to the MLLaMA model, handling both vision and text components.
     """
 
-    invertible_adapters_base_name = "language_model"  # Changed from text_model to match MLLaMA's naming
+    invertible_adapters_base_name = "language_model"
     support_prompt_tuning = False
 
     def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
         layer_idx = 0
 
         # First iterate through vision model's local transformer layers
-        for _, layer in enumerate(self.vision_model.iter_layers()):
+        for _, layer in self.vision_model.iter_layers():
             yield layer_idx, layer
             layer_idx += 1
 
-        for _, layer in enumerate(self.language_model.layers):
+        for _, layer in self.language_model.iter_layers():
             yield layer_idx, layer
             layer_idx += 1