From d338105ad72d3f9e0a809123f843ab8c67b25fe2 Mon Sep 17 00:00:00 2001
From: Timo Imhof
Date: Mon, 6 Jan 2025 22:00:26 +0100
Subject: [PATCH] Draft import structure and adapter model class

---
 src/adapters/models/mllama/__init__.py      |  39 +++++
 src/adapters/models/mllama/adapter_model.py | 153 ++++++++++++++++++++
 2 files changed, 192 insertions(+)
 create mode 100644 src/adapters/models/mllama/__init__.py
 create mode 100644 src/adapters/models/mllama/adapter_model.py

diff --git a/src/adapters/models/mllama/__init__.py b/src/adapters/models/mllama/__init__.py
new file mode 100644
index 000000000..12ff0ddd9
--- /dev/null
+++ b/src/adapters/models/mllama/__init__.py
@@ -0,0 +1,39 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2020 The Adapter-Hub Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "adapter_model": ["MllamaTextAdapterModel"],
+}
+
+
+if TYPE_CHECKING:
+    from .adapter_model import MllamaTextAdapterModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
diff --git a/src/adapters/models/mllama/adapter_model.py b/src/adapters/models/mllama/adapter_model.py
new file mode 100644
index 000000000..8812a2ecd
--- /dev/null
+++ b/src/adapters/models/mllama/adapter_model.py
@@ -0,0 +1,153 @@
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.mllama.modeling_mllama import MLLAMA_START_DOCSTRING, MllamaPreTrainedModel, MllamaTextModel
+from transformers.utils import add_start_docstrings
+
+from ...composition import adjust_tensors_for_parallel
+from ...heads import ModelWithFlexibleHeadsAdaptersMixin
+from ...model_mixin import EmbeddingAdaptersWrapperMixin
+from ...wrappers import init
+
+
+logger = logging.getLogger(__name__)
+
+
+@add_start_docstrings(
+    """
+    The Mllama text model that allows the flexible loading of prediction heads and adapters for different tasks.
+    """,
+    MLLAMA_START_DOCSTRING,
+)
+class MllamaTextAdapterModel(
+    EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MllamaPreTrainedModel
+):
+    head_types = [
+        "causal_lm",
+    ]  # TODO: "conditional_generation"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MllamaTextModel(config)
+        init(self.model)
+
+        self._init_head_modules()
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cross_attention_states: Optional[torch.FloatTensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        head=None,
+        output_adapter_gating_scores=False,
+        output_adapter_fusion_attentions=False,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs, context = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            output_adapter_gating_scores=output_adapter_gating_scores,
+            output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+            output_context=True,
+        )
+        kwargs["context"] = context
+        batch_size = outputs[0].shape[0]
+
+        if self.config.pad_token_id is None:
+            # TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this?
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                (sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths)
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        cls_logits = outputs[0][range(batch_size), sequence_lengths]
+
+        outputs = self.forward_head(
+            outputs,
+            head_name=head,
+            cls_output=cls_logits,
+            attention_mask=attention_mask,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
+            }
+        )
+        return model_inputs
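
For reference (not part of the patch): a minimal usage sketch of the new head model, assuming the remaining Mllama integration pieces in `adapters` (adapter mixins and model/head registration) are in place. The tiny config values, the adapter name `demo_adapter`, and the head name `demo_head` are illustrative placeholders; `seq_bn` stands in for any supported adapter config.

    from transformers.models.mllama.configuration_mllama import MllamaTextConfig

    from adapters.models.mllama.adapter_model import MllamaTextAdapterModel

    # Tiny, randomly initialized text config purely for illustration; real use
    # would load a pretrained Mllama checkpoint instead.
    config = MllamaTextConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        cross_attention_layers=[1],
    )
    model = MllamaTextAdapterModel(config)

    # Add a bottleneck adapter and a causal LM head ("causal_lm" is the only
    # head type declared in this draft), then put the adapter into training mode.
    model.add_adapter("demo_adapter", config="seq_bn")
    model.add_causal_lm_head("demo_head")
    model.train_adapter("demo_adapter")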