feat: llama
LutingWang committed Feb 5, 2025
1 parent 993c85c commit 8cd9bb3
Showing 7 changed files with 119 additions and 10 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -26,6 +26,7 @@
"LVIS",
"MDETR",
"multihead",
"multimodal",
"Numbersbatch",
"ODKD",
"preds",
8 changes: 6 additions & 2 deletions docs/source/pretrained/gemma.py
@@ -1,7 +1,11 @@
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GemmaForCausalLM,
+    GemmaTokenizerFast,
+)
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma import GemmaForCausalLM, GemmaTokenizerFast
 
 import todd

102 changes: 102 additions & 0 deletions docs/source/pretrained/llama.py
@@ -0,0 +1,102 @@
import os
import random

import torch
from PIL import Image
from transformers import (
AutoTokenizer,
BatchEncoding,
MllamaForConditionalGeneration,
MllamaProcessor,
PreTrainedTokenizerFast,
)

import todd


class Chatbot:
PRETRAINED = 'pretrained/llama/Llama-3.2-11B-Vision-Instruct'

def __init__(self) -> None:
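        # A fast tokenizer is enough for the text-only chat() path.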
tokenizer = AutoTokenizer.from_pretrained(self.PRETRAINED)
self._tokenizer: PreTrainedTokenizerFast = tokenizer

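        # The processor pairs the tokenizer with image preprocessing
        # for the multimodal chat_multimodal() path.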
processor = MllamaProcessor.from_pretrained(self.PRETRAINED)
self._processor = processor

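        # device_map='auto' spreads the 11B weights across available devices;
        # torch_dtype='auto' keeps the checkpoint's native precision.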
model = MllamaForConditionalGeneration.from_pretrained(
self.PRETRAINED,
device_map='auto',
torch_dtype='auto',
)
self._model = model

def __call__(self, inputs: BatchEncoding) -> str:
if todd.Store.cuda: # pylint: disable=using-constant-test
inputs = inputs.to('cuda')

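        # The prompt length marks where newly generated tokens begin.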
input_ids: torch.Tensor = inputs['input_ids']
_, input_length = input_ids.shape

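        # generate() echoes the prompt, so decode only the tokens after it.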
output_ids = self._model.generate(**inputs, max_new_tokens=1024)
generated_ids = output_ids[0, input_length:]
generated_text = self._processor.decode(
generated_ids,
skip_special_tokens=True,
)

return generated_text

def chat(self, text: str) -> str:
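        # Text-only chat: the tokenizer's chat template formats the
        # conversation and returns tensors ready for generation.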
conversation = [dict(role='user', content=text)]
inputs = self._tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors='pt',
return_dict=True,
)
return self(inputs)

def chat_multimodal(self, images: list[Image.Image], text: str) -> str:
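        # One image placeholder per exemplar, followed by the text query.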
content = [dict(type='image') for _ in images]
content.append(dict(type='text', text=text))
conversation = [dict(role='user', content=content)]
input_text = self._processor.apply_chat_template(
conversation,
add_generation_prompt=True,
)

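        # The chat template already inserted special tokens,
        # so the processor must not add them again.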
inputs = self._processor(
images,
input_text,
add_special_tokens=False,
return_tensors='pt',
)
return self(inputs)


def main() -> None:
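    # Sample five exemplar images from the local `cat` directory.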
filenames = os.listdir('cat')
sampled_filenames = random.sample(filenames, 5)
images = [
Image.open(os.path.join('cat', filename))
for filename in sampled_filenames
]

chatbot = Chatbot()

response = chatbot.chat("What is AI?")
todd.logger.info(response)

caption = chatbot.chat_multimodal(
images,
"The images are exemplars of a category. "
"Can you guess what category it is? "
"Answer with a template of the form: A photo of <object>. "
"Example: A photo of cat.",
)
todd.logger.info(caption)


if __name__ == '__main__':
main()
1 change: 1 addition & 0 deletions docs/source/pretrained/llama.rst
@@ -6,4 +6,5 @@ LLaMA
    root=pretrained/llama
    mkdir -p ${root} && cd ${root}
    git clone git@hf.co:meta-llama/Llama-3.2-11B-Vision
+   git clone git@hf.co:meta-llama/Llama-3.2-11B-Vision-Instruct
    cd ../..
10 changes: 5 additions & 5 deletions docs/source/pretrained/qwen.py
@@ -6,9 +6,7 @@
 from bs4 import BeautifulSoup, NavigableString
 from ebooklib import ITEM_DOCUMENT, epub
 from tqdm import tqdm
-from transformers import BatchEncoding
-from transformers.cache_utils import DynamicCache
-from transformers.models.qwen2 import Qwen2ForCausalLM, Qwen2TokenizerFast
+from transformers import DynamicCache, Qwen2ForCausalLM, Qwen2TokenizerFast

import todd

@@ -51,7 +49,7 @@ def __call__(self, text: str) -> str:
         message = Message(role='user', content=text)
         self._conversation.append(message)
 
-        inputs: BatchEncoding = self._tokenizer.apply_chat_template(
+        inputs = self._tokenizer.apply_chat_template(
             self._conversation,
             add_generation_prompt=True,
             return_tensors='pt',
@@ -98,7 +96,9 @@ def _translate_text(self, text: str) -> str | None:

     def _translate_item(self, item: epub.EpubItem) -> None:
         soup = BeautifulSoup(item.content, 'html.parser')
-        texts: list[NavigableString] = soup.body.find_all(string=True)
+        body = soup.body
+        assert body is not None
+        texts: list[NavigableString] = body.find_all(string=True)
         for text in tqdm(texts, leave=False):
             translation = self._translate_text(text)
             if translation is not None and translation.strip():
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -56,7 +56,7 @@ optional = [
 'regex',
 'sentencepiece',
 'tqdm',
-'transformers',
+'transformers==4.48.0',
 'WeTextProcessing',
 ]
 dev = [
5 changes: 3 additions & 2 deletions todd/tasks/image_classification/models/ram.py
@@ -14,7 +14,8 @@
 from torch import nn
 from torchvision.models import SwinTransformer
 from torchvision.models.swin_transformer import ShiftedWindowAttention
-from transformers.models.bert import BertConfig, BertLayer, BertModel
+from transformers import BertConfig, BertModel
+from transformers.models.bert import BertLayer
 from transformers.models.bert.modeling_bert import BertSelfAttention
 
 from todd.models.modules import PretrainedMixin
@@ -256,7 +257,7 @@ def forward(  # pylint: disable=arguments-differ
         category_embedding = self._in_linear(category_embedding)
         category_embedding = category_embedding.relu()
 
-        embedding, *_ = super().forward(
+        embedding, *_ = super().forward(  # pylint: disable=no-member
             inputs_embeds=category_embedding,
             encoder_hidden_states=x,
         )
