Draft import structure and adapter model class
TimoImhof committed Jan 6, 2025
1 parent d6054cb commit d338105
Showing 2 changed files with 192 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/adapters/models/mllama/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["MllamaAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import MllamaAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
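
For context, _LazyModule defers importing adapter_model until one of the exported names is first accessed, which keeps the top-level adapters import cheap. A minimal sketch of that access pattern, assuming the package ends up importable as adapters.models.mllama and actually exports the name declared above (both depend on registration steps outside this commit):

    # Hypothetical sketch; the import path and exported name mirror the declarations
    # in this __init__.py and are not wired up by this commit alone.
    import adapters.models.mllama as mllama

    # Nothing from adapter_model.py has been loaded at this point.
    adapter_model_cls = mllama.MllamaAdapterModel  # first attribute access triggers the real import
    print(adapter_model_cls.__name__)
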
153 changes: 153 additions & 0 deletions src/adapters/models/mllama/adapter_model.py
@@ -0,0 +1,153 @@
import logging
from typing import List, Optional, Tuple, Union

import torch

from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mllama.modeling_mllama import MLLAMA_START_DOCSTRING, MllamaPreTrainedModel, MllamaTextModel
from transformers.utils import add_start_docstrings

from ...composition import adjust_tensors_for_parallel
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init


logger = logging.getLogger(__name__)


@add_start_docstrings(
"""
TODO
""",
MLLAMA_START_DOCSTRING,
)
class MllamaTextAdapterModel(
EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MllamaPreTrainedModel
):
head_types = [
"causal_lm",
] # TODO: "conditional_generation"

def __init__(self, config):
super().__init__(config)
        # Instantiate the wrapped text backbone and inject adapter support into it.
        self.model = MllamaTextModel(config)
        init(self.model)

self._init_head_modules()

self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None
self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.FloatTensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
kwargs["context"] = context
batch_size = outputs[0].shape[0]

if self.config.pad_token_id is None:
# TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this?
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
(sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

cls_logits = outputs[0][range(batch_size), sequence_lengths]

outputs = self.forward_head(
outputs,
head_name=head,
cls_output=cls_logits,
attention_mask=attention_mask,
return_dict=return_dict,
**kwargs,
)

return outputs

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
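        # With a populated KV cache, only the newest token needs to be fed through the model.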
if past_key_values:
input_ids = input_ids[:, -1:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}
)
return model_inputs
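
Once the draft is functional, usage would presumably mirror the other AdapterModel classes in the library; a rough sketch, where the import path, the final class name, and the checkpoint id are all assumptions rather than anything fixed by this commit:

    # Rough usage sketch; import path, class name, and checkpoint id are assumptions.
    from adapters.models.mllama.adapter_model import MllamaTextAdapterModel

    model = MllamaTextAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
    model.add_adapter("bottleneck", config="seq_bn")  # add a bottleneck adapter
    model.add_causal_lm_head("bottleneck")            # attach a matching causal LM head
    model.train_adapter("bottleneck")                 # freeze the backbone, train only the adapter
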
