[WIP] Support Mllama #777

Draft · wants to merge 67 commits into base: main
Commits (67)
93bee1a
Make generation tests generic
TimoImhof Sep 16, 2024
f51cfdb
Merge remote-tracking branch 'origin/main' into dev/test-refactoring
TimoImhof Oct 16, 2024
7e65e82
Draft Refactoring AdapterTestBase
TimoImhof Oct 28, 2024
793cbe5
Merge branch 'adapter-hub:main' into dev/test-refactoring
TimoImhof Oct 30, 2024
65c3fb7
Replace import class names
TimoImhof Oct 30, 2024
afdcfdd
Merge branch 'dev/test-refactoring' of https://github.com/TimoImhof/a…
TimoImhof Oct 30, 2024
ee6166c
Base refactoring:
TimoImhof Nov 1, 2024
630b722
remove redundant imports
TimoImhof Nov 1, 2024
0d3577f
Add pytest markers and respective pytest commands
TimoImhof Nov 1, 2024
1300856
Add draft of README
TimoImhof Nov 1, 2024
78387db
Refactoring:
TimoImhof Nov 5, 2024
83d3b32
Fix make quality
TimoImhof Nov 5, 2024
5e8e1b8
Add gpt2 tests
TimoImhof Nov 5, 2024
53eb0b9
Fix config union and head tests
TimoImhof Nov 7, 2024
1dbd412
Fix paths and imports
TimoImhof Nov 7, 2024
cf4f6a7
remove accidentally added prompt tuning from gpt2 and make style
TimoImhof Nov 7, 2024
b390d61
Revert PromptTuning changes
TimoImhof Nov 7, 2024
2193aee
Revert "Revert PromptTuning changes"
TimoImhof Nov 7, 2024
f555484
Re-add missing adapter model tests
TimoImhof Nov 7, 2024
8dccda2
Refactoring:
TimoImhof Nov 7, 2024
c665948
Introduce generic test creator function
TimoImhof Nov 8, 2024
fb425b6
Re-add beit adapter method tests
TimoImhof Nov 8, 2024
225439c
Refactor & Re-add bertgeneration and bert
TimoImhof Nov 9, 2024
09f9cdc
Re-add clip tests
TimoImhof Nov 11, 2024
7934350
Re-add:
TimoImhof Nov 11, 2024
5f55935
Add more models
TimoImhof Nov 21, 2024
147c8af
Re-add whisper
TimoImhof Nov 27, 2024
b2979ce
Changes:
TimoImhof Dec 16, 2024
ffd21a9
Add debug statements and only execute failing test
TimoImhof Dec 18, 2024
0dba87c
Add verbose information
TimoImhof Dec 18, 2024
c333467
check package versions
TimoImhof Dec 18, 2024
aac4038
More debugging statements
TimoImhof Dec 18, 2024
0f4c9b6
Merge branch 'adapter-hub:main' into dev/test-refactoring
TimoImhof Dec 22, 2024
9ac515c
Merge branch 'adapter-hub:main' into dev/test-refactoring
TimoImhof Dec 23, 2024
4af10df
Fix failing test:
TimoImhof Dec 23, 2024
dbd4965
Update README
TimoImhof Dec 24, 2024
d1a4a09
Merge branch 'main' of https://github.com/TimoImhof/adapters into dev…
TimoImhof Dec 27, 2024
c516464
Fix hf version and clip tests
TimoImhof Dec 27, 2024
d338105
Draft import structure and adapter model class
TimoImhof Jan 6, 2025
dc5dc6d
Update gitignore for development
TimoImhof Jan 7, 2025
7b46d58
More thorough draft of adapter model
TimoImhof Jan 8, 2025
32609c3
Draft mllama adapter mixins
TimoImhof Jan 8, 2025
87c0998
Merge branch 'adapter-hub:main' into dev/test-refactoring
TimoImhof Jan 8, 2025
77135cf
Merge branch 'adapter-hub:main' into dev/mllama
TimoImhof Jan 8, 2025
2c80a5c
Polish:
TimoImhof Jan 8, 2025
be69f0a
Merge branch 'main' into dev/test-refactoring
TimoImhof Jan 8, 2025
f1b1136
Merge branch 'main' into dev/test-refactoring
TimoImhof Jan 8, 2025
919d72a
Merge branch 'adapter-hub:main' into dev/mllama
TimoImhof Jan 9, 2025
b680012
Draft import structure and adapter model class
TimoImhof Jan 6, 2025
c32b08a
Update gitignore for development
TimoImhof Jan 7, 2025
4b38180
More thorough draft of adapter model
TimoImhof Jan 8, 2025
8f9298a
Draft mllama adapter mixins
TimoImhof Jan 8, 2025
a122a75
Merge branch 'dev/mllama' of https://github.com/TimoImhof/adapters in…
TimoImhof Jan 9, 2025
d67692d
Fix import structure
TimoImhof Jan 10, 2025
958b2c6
Reuse mixin implementations
TimoImhof Jan 10, 2025
7507e1e
Create MllamaModel class and adjust mixins accordingly
TimoImhof Jan 13, 2025
d2a28d8
Re-implement MllamaAdapterModel
TimoImhof Jan 13, 2025
6d39941
Fix typos
TimoImhof Jan 13, 2025
4c153b4
Draft adapter attention classes
TimoImhof Jan 14, 2025
a52154f
Progress:
TimoImhof Jan 16, 2025
2ec0b35
save links for useful resources
TimoImhof Jan 16, 2025
88f6230
Integrate CLIP into refactored test structure
TimoImhof Jan 16, 2025
39878e6
Merge branch 'dev/test-refactoring' into dev/mllama
TimoImhof Jan 17, 2025
a75846a
Progress:
TimoImhof Jan 17, 2025
7b970c9
Add mllama model tests
TimoImhof Jan 17, 2025
40871e5
Adapt VisionEncoder forward pre hook
TimoImhof Jan 19, 2025
e177ecb
Merge branch 'main' into dev/mllama
TimoImhof Jan 28, 2025
5 changes: 4 additions & 1 deletion .gitignore
@@ -176,4 +176,7 @@ scripts/git-strip-merge
tests/backwards_compatibility/Ref_Out

# backwards compatibility
model_outputs
model_outputs

# TODO: remove after mllama dev
explore_mllama
3 changes: 3 additions & 0 deletions setup.cfg
@@ -49,6 +49,9 @@ use_parentheses = True
[flake8]
ignore = E203, E501, E731, E741, W503, W605
max-line-length = 119
per-file-ignores =
tests/test_methods/generator.py: F401, F403, F405
tests/test_methods/test_*.py:F403,F405

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -111,6 +111,7 @@
"models.llama": ["LlamaAdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mistral": ["MistralAdapterModel"],
"models.mllama": ["MllamaAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
"models.plbart": ["PLBartAdapterModel"],
"models.roberta": ["RobertaAdapterModel"],
@@ -222,6 +223,7 @@
from .models.llama import LlamaAdapterModel
from .models.mbart import MBartAdapterModel
from .models.mistral import MistralAdapterModel
from .models.mllama import MllamaAdapterModel
from .models.mt5 import MT5AdapterModel
from .models.plbart import PLBartAdapterModel
from .models.roberta import RobertaAdapterModel
10 changes: 10 additions & 0 deletions src/adapters/head_utils.py
@@ -788,6 +788,16 @@
},
"layers": ["proj_out"],
},
"MllamaForConditionalGeneration": {
"config": {
"head_type": "causal_lm",
"layers": 1,
"activation_function": None,
"layer_norm": False,
"bias": False,
},
"layers": ["language_model.lm_head"],
},
}


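The new entry describes the static head of MllamaForConditionalGeneration (a single bias-free linear layer at language_model.lm_head) so that checkpoints saved with that architecture can be converted into an equivalent flexible "causal_lm" head when loaded into the new MllamaAdapterModel. A minimal sketch of that conversion path, with the checkpoint name as a placeholder:

from adapters import MllamaAdapterModel

# Loading a MllamaForConditionalGeneration checkpoint into the flex-head adapter model;
# the config above is meant to let adapters recreate language_model.lm_head as a
# "causal_lm" head with one layer, no activation, no layer norm, and no bias.
model = MllamaAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
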
1 change: 1 addition & 0 deletions src/adapters/methods/prefix_tuning.py
@@ -153,6 +153,7 @@ def __init__(self, model_config: PretrainedConfig, adapters_config: ModelAdapter
self.prefix_tunings = nn.ModuleDict()

def indicate_prefix(self, prefix_name: str, location_key: str, **kwargs):
"""Indicate that a Prefix Tuning module should be added to the indicated layer."""
if prefix_name not in self.prefix_counts:
self.prefix_counts[prefix_name] = {location_key: {"count": 1, **kwargs}}
elif location_key not in self.prefix_counts[prefix_name]:
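For reference, indicate_prefix only records bookkeeping before the actual prefix tuning modules are created. A rough illustration of the resulting structure, with placeholder names and kwargs:

# A first call such as indicate_prefix("my_prefix", "self_prefix", n_heads=8)
# takes the branch shown above and leaves:
# self.prefix_counts == {"my_prefix": {"self_prefix": {"count": 1, "n_heads": 8}}}
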
25 changes: 25 additions & 0 deletions src/adapters/models/__init__.py
@@ -20,6 +20,19 @@
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .mistral.mixin_mistral import MistralModelAdapterMixin
from .mllama.mixin_mllama import (
MllamaAdaptersMixin,
MllamaCrossAttentionDecoderLayerAdaptersMixin,
MllamaForConditionalGenerationWithHeadsAdaptersMixin,
MllamaSelfAttentionDecoderLayerAdaptersMixin,
MllamaTextCrossAttentionAdaptersMixin,
MllamaTextModelAdaptersMixin,
MllamaTextSelfAttentionAdaptersMixin,
MllamaVisionAttentionAdaptersMixin,
MllamaVisionEncoderAdaptersMixin,
MllamaVisionEncoderLayerAdaptersMixin,
MllamaVisionModelAdaptersMixin,
)
from .plbart.mixin_plbart import (
PLBartDecoderAdaptersMixin,
PLBartDecoderWrapperAdaptersMixin,
@@ -109,4 +122,16 @@
"WhisperForAudioClassification": WhisperForAudioClassificationWithHeadsMixin,
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
"MistralModel": MistralModelAdapterMixin,
# Multimodal Llama
"MllamaForConditionalGeneration": MllamaForConditionalGenerationWithHeadsAdaptersMixin,
"MllamaModel": MllamaAdaptersMixin,
"MllamaVisionModel": MllamaVisionModelAdaptersMixin,
"MllamaTextModel": MllamaTextModelAdaptersMixin,
"MllamaVisionEncoder": MllamaVisionEncoderAdaptersMixin,
"MllamaVisionAttention": MllamaVisionAttentionAdaptersMixin,
"MllamaTextSelfAttention": MllamaTextSelfAttentionAdaptersMixin,
"MllamaTextCrossAttention": MllamaTextCrossAttentionAdaptersMixin,
"MllamaVisionEncoderLayer": MllamaVisionEncoderLayerAdaptersMixin,
"MllamaSelfAttentionDecoderLayer": MllamaSelfAttentionDecoderLayerAdaptersMixin,
"MllamaCrossAttentionDecoderLayer": MllamaCrossAttentionDecoderLayerAdaptersMixin,
}
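The mapping above is what the adapters wrapper consults when patching adapter support into loaded Transformers classes by class name. A minimal sketch of wrapping a plain Transformers Mllama model, assuming a placeholder checkpoint name:

import adapters
from transformers import MllamaForConditionalGeneration

# Hypothetical checkpoint; adapters.init() attaches the Mllama*AdaptersMixin classes
# registered above to the matching (sub)modules of the loaded model.
model = MllamaForConditionalGeneration.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
adapters.init(model)

model.add_adapter("mllama_bn", config="seq_bn")
model.set_active_adapters("mllama_bn")
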
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -24,6 +24,7 @@
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mistral", "MistralAdapterModel"),
("mllama", "MllamaAdapterModel"),
("mt5", "MT5AdapterModel"),
("plbart", "PLBartAdapterModel"),
("roberta", "RobertaAdapterModel"),
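With the "mllama" entry registered, AutoAdapterModel should resolve Mllama checkpoints to the new adapter model class. A minimal sketch, again with a placeholder checkpoint name:

from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
print(type(model).__name__)  # expected: MllamaAdapterModel
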
232 changes: 232 additions & 0 deletions src/adapters/models/mllama/adapter_model.py
@@ -0,0 +1,232 @@
import logging
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mllama.modeling_mllama import (
MLLAMA_START_DOCSTRING,
MllamaPreTrainedModel,
MllamaTextModel,
MllamaVisionModel,
_prepare_cross_attention_mask,
)
from transformers.utils import add_start_docstrings

from ...context import AdapterSetup
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init


logger = logging.getLogger(__name__)


class MllamaModel(MllamaPreTrainedModel):
"""
Base MLLaMA model that provides the fundamental architecture combining vision and text.
This serves as the foundation for the specialized adapter model version.
"""

def __init__(self, config):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
self.vision_output_dim = config.vision_config.vision_output_dim
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaTextModel._from_config(config.text_config)
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
)
self.post_init()

def get_input_embeddings(self):
return self.language_model.get_input_embeddings()

def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def get_output_embeddings(self):
return self.language_model.get_output_embeddings()

def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)

def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)

def get_decoder(self):
return self.language_model.get_decoder()

def tie_weights(self):
return self.language_model.tie_weights()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:

# Establish parameter values
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# Check invalid argument combinations
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

# If image is provided compute cross_attention_states
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size
)

# Compute cross_attention_mask
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
)

return outputs


@add_start_docstrings(MLLAMA_START_DOCSTRING)
class MllamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MllamaPreTrainedModel):

head_types = [
"causal_lm",
]

def __init__(self, config):
super().__init__(config)

self.model = MllamaModel(config)
init(self.model)

self._init_head_modules()
self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):

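# Run the wrapped base model; output_context=True additionally returns the adapter
# ForwardContext, which is handed to the prediction head via kwargs below.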
outputs, context = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
aspect_ratio_mask=aspect_ratio_mask,
aspect_ratio_ids=aspect_ratio_ids,
attention_mask=attention_mask,
cross_attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
kwargs["context"] = context

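# With the default num_logits_to_keep == 0, the slice [-0:] keeps the full sequence for the head.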
hidden_states = outputs[0]
head_input_states = hidden_states[:, -num_logits_to_keep:, :]

if head or AdapterSetup.get_context_head_setup() or self.active_head:
head_outputs = self.forward_head(
head_input_states,
head_name=head,
attention_mask=attention_mask,
labels=labels,
return_dict=return_dict,
**kwargs,
)
return head_outputs
return outputs
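
As a rough end-to-end sketch of how the new adapter model is intended to be used (checkpoint and adapter names are placeholders, and the interface may still change while this PR is a draft):

from adapters import MllamaAdapterModel

model = MllamaAdapterModel.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
model.add_adapter("mllama_adapter", config="seq_bn")
model.add_causal_lm_head("mllama_adapter")
# Freeze the base model weights and activate the adapter for training.
model.train_adapter("mllama_adapter")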