From d2a28d8114c6a4cd1157dde86fe0df5c817cb341 Mon Sep 17 00:00:00 2001
From: Timo Imhof
Date: Mon, 13 Jan 2025 16:43:08 +0100
Subject: [PATCH] Re-implement MllamaAdapterModel

---
 src/adapters/models/mllama/adapter_model.py | 36 ++++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/adapters/models/mllama/adapter_model.py b/src/adapters/models/mllama/adapter_model.py
index 73f845105..eb475554c 100644
--- a/src/adapters/models/mllama/adapter_model.py
+++ b/src/adapters/models/mllama/adapter_model.py
@@ -16,6 +16,7 @@
 )
 from transformers.utils import add_start_docstrings
 
+from ...context import AdapterSetup
 from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
@@ -192,4 +193,37 @@ def forward(
         output_adapter_fusion_attentions=False,
         **kwargs,
     ):
-        pass
+
+        outputs, context = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            aspect_ratio_mask=aspect_ratio_mask,
+            aspect_ratio_ids=aspect_ratio_ids,
+            attention_mask=attention_mask,
+            cross_attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+            output_context=True,
+        )
+        kwargs["context"] = context
+
+        if head or AdapterSetup.get_context_head_setup() or self.active_head:
+            head_outputs = self.forward_head(
+                outputs,
+                head_name=head,
+                attention_mask=attention_mask,
+                return_dict=return_dict,
+                **kwargs,
+            )
+            return head_outputs
+        return outputs
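
Usage sketch (reviewer illustration, not part of the patch): assuming the rest of the Mllama head support in this PR, the re-implemented forward() can be exercised roughly as below. It routes the base-model outputs through forward_head() when a head is passed or active, and returns the raw base-model outputs otherwise. The checkpoint, adapter, and head names are placeholders, and the setup calls follow the library's general flexible-head API rather than anything Mllama-specific confirmed here.

    import torch
    from PIL import Image
    from transformers import AutoProcessor
    from adapters import AutoAdapterModel

    # Placeholder checkpoint; any Mllama checkpoint supported by the library should work.
    checkpoint = "meta-llama/Llama-3.2-11B-Vision"
    processor = AutoProcessor.from_pretrained(checkpoint)
    model = AutoAdapterModel.from_pretrained(checkpoint)

    # Illustrative adapter and head names.
    model.add_adapter("demo_adapter")
    model.add_causal_lm_head("demo_head")
    model.set_active_adapters("demo_adapter")

    image = Image.new("RGB", (560, 560))  # dummy image standing in for real input
    inputs = processor(images=image, text="<|image|>Describe the scene.", return_tensors="pt")

    with torch.no_grad():
        # With a head given (or an active head set), forward() dispatches to
        # forward_head(); without one, it returns the base-model outputs unchanged.
        outputs = model(**inputs, head="demo_head")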