Skip to content

Commit

Permalink
Add Mistral support (#609)
Browse files Browse the repository at this point in the history

Co-authored-by: calpt <calpt@mail.de>
  • Loading branch information
KorventennFR and calpt authored Jul 20, 2024
1 parent fc929b4 commit 5cc7557
Show file tree
Hide file tree
Showing 15 changed files with 815 additions and 2 deletions.
31 changes: 31 additions & 0 deletions docs/classes/models/mistral.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Mistral
-----------------------------------------------------------------------------------------------------------------------

The Mistral model was proposed in `Mistral 7B <https://arxiv.org/abs/2310.06825>`__ by
Albert Q. Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas,
Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, Lélio Renard Lavaud, Marie-Anne Lachaux,
Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
It is a foundation language model with 7.3B parameters.

The abstract from the paper is the following:

*We introduce Mistral 7B, a 7-billion-parameter language model engineered for
superior performance and efficiency. Mistral 7B outperforms the best open 13B
model (Llama 2) across all evaluated benchmarks, and the best released 34B
model (Llama 1) in reasoning, mathematics, and code generation. Our model
leverages grouped-query attention (GQA) for faster inference, coupled with sliding
window attention (SWA) to effectively handle sequences of arbitrary length with a
reduced inference cost. We also provide a model fine-tuned to follow instructions,
Mistral 7B - Instruct, that surpasses Llama 2 13B - chat model both on human and
automated benchmarks. Our models are released under the Apache 2.0 license.*

Code: https://github.com/mistralai/mistral-src
Webpage: https://mistral.ai/news/announcing-mistral-7b/


MistralAdapterModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.MistralAdapterModel
:members:
:inherited-members: MistralPreTrainedModel
2 changes: 1 addition & 1 deletion docs/contributing/adding_adapters_to_a_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- Add `<model_type>AdapterModel` to the `ADAPTER_MODEL_MAPPING_NAMES` mapping in `src/adapters/models/auto/adapter_model.py` and to `src/adapters/__init__.py`.
- Define the classes to be added to Python's import structure in `src/adapters/models/<model_type>/__init__.py`. This will likely only be the `<model_type>AdapterModel`.
6. **Adapt the config classes:**
- Adapt the config class to the requirements of adapters in `src/transformers/adapters/wrappers/configuration.py`.
- Adapt the config class to the requirements of adapters in `src/adapters/wrappers/configuration.py`.
- There are some naming differences in the config attributes of different model architectures. The adapter implementation requires some additional attributes with a specific name to be available. These currently are `num_attention_heads`, `hidden_size`, `hidden_dropout_prob` and `attention_probs_dropout_prob` as in the `BertConfig` class.
If your model config does not provide these, add corresponding mappings to `CONFIG_CLASS_KEYS_MAPPING`.

Expand Down
1 change: 1 addition & 0 deletions docs/model_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The table below further shows which model architectures support which adaptation
| [GPT-J](classes/models/gptj.html) |||||||| ||
| [Llama](classes/models/llama.html) |||||||| ||
| [MBart](classes/models/mbart.html) |||||||| ||
| [Mistral](classes/models/mistral.html) |||||||| ||
| [MT5](classes/models/mt5.html) |||||||| ||
| [PLBart](classes/models/plbart.html) |||||||| ||
| [RoBERTa](classes/models/roberta.html) ||||||||||
Expand Down
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"models.gptj": ["GPTJAdapterModel"],
"models.llama": ["LlamaAdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mistral": ["MistralAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
"models.plbart": ["PLBartAdapterModel"],
"models.roberta": ["RobertaAdapterModel"],
Expand Down Expand Up @@ -217,6 +218,7 @@
from .models.gptj import GPTJAdapterModel
from .models.llama import LlamaAdapterModel
from .models.mbart import MBartAdapterModel
from .models.mistral import MistralAdapterModel
from .models.mt5 import MT5AdapterModel
from .models.plbart import PLBartAdapterModel
from .models.roberta import RobertaAdapterModel
Expand Down
1 change: 1 addition & 0 deletions src/adapters/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
"xlm-roberta",
"bert-generation",
"llama",
"mistral",
"electra",
"xmod",
],
Expand Down
17 changes: 17 additions & 0 deletions src/adapters/head_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,23 @@
},
"layers": [None, "qa_outputs"],
},
# Mistral
"MistralForSequenceClassification": {
"config": {
"head_type": "classification",
"layers": 1,
"dropout_prob": 0,
"activation_function": None,
"bias": False,
},
"layers": [None, "score"],
},
"MistralForCausalLM": {
"config": {
"head_type": "causal_lm",
},
"layers": ["lm_head"],
},
# Electra
"ElectraForTokenClassification": {
"config": {
Expand Down
2 changes: 2 additions & 0 deletions src/adapters/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .mistral.mixin_mistral import MistralModelAdapterMixin
from .plbart.mixin_plbart import (
PLBartDecoderAdaptersMixin,
PLBartDecoderWrapperAdaptersMixin,
Expand Down Expand Up @@ -94,4 +95,5 @@
"BertGenerationLayer": BertLayerAdaptersMixin,
"LlamaModel": LlamaModelAdapterMixin,
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
"MistralModel": MistralModelAdapterMixin,
}
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
("gptj", "GPTJAdapterModel"),
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mistral", "MistralAdapterModel"),
("mt5", "MT5AdapterModel"),
("plbart", "PLBartAdapterModel"),
("roberta", "RobertaAdapterModel"),
Expand Down
2 changes: 1 addition & 1 deletion src/adapters/models/llama/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@add_start_docstrings(
"""
The Llama Model that allows the loading of different heads dor different tasks. This enables a flexible use of the
The Llama Model that allows the loading of different heads for different tasks. This enables a flexible use of the
models and adpters. Since this class does classification on the last token, it requires to know the position of the
last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding
token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
Expand Down
39 changes: 39 additions & 0 deletions src/adapters/models/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["MistralAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import MistralAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
149 changes: 149 additions & 0 deletions src/adapters/models/mistral/adapter_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import logging

import torch

from transformers.models.mistral.modeling_mistral import MISTRAL_START_DOCSTRING, MistralModel, MistralPreTrainedModel
from transformers.utils import add_start_docstrings

from ...composition import adjust_tensors_for_parallel
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init


logger = logging.getLogger(__name__)


@add_start_docstrings(
"""
The Mistal Model that allows the loading of different heads for different tasks. This enables a flexible use of the
models and adpters. Since this class does classification on the last token, it requires to know the position of the
last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding
token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same
(take the last value in each row of the batch).
""",
MISTRAL_START_DOCSTRING,
)
class MistralAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MistralPreTrainedModel):
head_types = [
"classification",
"multilabel_classification",
"tagging",
"question_answering",
"causal_lm",
]

def __init__(self, config):
super().__init__(config)
self.model = MistralModel(config)
init(self.model)

self._init_head_modules()

self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None
self.post_init()

def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.model(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
return_dict=return_dict,
output_hidden_states=output_hidden_states,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

batch_size = outputs[0].shape[0]

if self.config.pad_token_id is None:
# TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this?
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
(sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

cls_logits = outputs[0][range(batch_size), sequence_lengths]

outputs = self.forward_head(
outputs,
head_name=head,
cls_output=cls_logits,
attention_mask=attention_mask,
return_dict=return_dict,
**kwargs,
)

return outputs

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
}
)
return model_inputs
46 changes: 46 additions & 0 deletions src/adapters/models/mistral/mixin_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


class MistralAttentionMixin:
def init_adapters(self, model_config, adapters_config):
self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q")
self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config)


class MistralDecoderLayerMixin:
def init_adapters(self, model_config, adapters_config):
# Wrap layers for LoRA
self.mlp.down_proj = LoRALinear.wrap(self.mlp.down_proj, "intermediate", model_config, adapters_config)
self.mlp.up_proj = LoRALinear.wrap(self.mlp.up_proj, "output", model_config, adapters_config)

self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")


class MistralModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False

def init_adapters(self, model_config, adapters_config):
super().init_adapters(model_config, adapters_config)

# Register hook for post embedding forward
self.embed_tokens.register_forward_hook(self.post_embedding_forward)

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.layers):
yield i, layer

def post_embedding_forward(self, module, args, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
# Prompt tuning not yet supported
return embedding_output
Loading

0 comments on commit 5cc7557

Please sign in to comment.