Re-implement MllamaAdapterModel
TimoImhof committed Jan 13, 2025
1 parent 7507e1e commit d2a28d8
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/adapters/models/mllama/adapter_model.py
@@ -16,6 +16,7 @@
 )
 from transformers.utils import add_start_docstrings
 
+from ...context import AdapterSetup
 from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
 from ...wrappers import init
@@ -192,4 +193,37 @@ def forward(
         output_adapter_fusion_attentions=False,
         **kwargs,
     ):
-        pass
+
+        outputs, context = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            aspect_ratio_mask=aspect_ratio_mask,
+            aspect_ratio_ids=aspect_ratio_ids,
+            attention_mask=attention_mask,
+            cross_attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+            output_context=True,
+        )
+        kwargs["context"] = context
+
+        if head or AdapterSetup.get_context_head_setup() or self.active_head:
+            head_outputs = self.forward_head(
+                outputs,
+                head_name=head,
+                attention_mask=attention_mask,
+                return_dict=return_dict,
+                **kwargs,
+            )
+            return head_outputs
+        return outputs
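
The re-implemented forward delegates to the wrapped base model with output_context=True, stashes the returned adapter context into kwargs, and dispatches to forward_head whenever a head is passed explicitly, set via AdapterSetup, or active on the model; otherwise the raw base-model outputs are returned. As context, here is a minimal usage sketch (not part of the commit): the checkpoint id, adapter name, and config string are illustrative assumptions, and whether a causal LM head is registered for Mllama is likewise assumed rather than confirmed by this diff.

# Minimal usage sketch, assuming the standard adapters flexible-head API.
# The checkpoint id, adapter name, and head availability are assumptions.
from adapters import MllamaAdapterModel

model = MllamaAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

# Add a bottleneck adapter and (assumed) causal LM head, then activate them
# so that forward() routes through forward_head() as shown in the diff above.
model.add_adapter("mllama_demo", config="seq_bn")
model.add_causal_lm_head("mllama_demo")
model.set_active_adapters("mllama_demo")
model.train_adapter("mllama_demo")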
