From b9592c50aafe21ef905cb8041988bbf2485edb6e Mon Sep 17 00:00:00 2001
From: jankinf
Date: Fri, 2 Aug 2024 11:35:03 +0800
Subject: [PATCH] support deepseek-vl

---
 mmte/configs/models/deepseek/deepseek-7b.yaml |   0
 mmte/models/__init__.py                       |   1 +
 mmte/models/deepseek_chat.py                  | 105 +++
 .../models/deepseek_vl/models/clip_encoder.py | 242 +++++++
 .../models/image_processing_vlm.py            | 208 ++++++
 .../models/deepseek_vl/models/modeling_vlm.py | 170 +++++
 .../deepseek_vl/models/processing_vlm.py      | 390 ++++++++++
 mmte/models/deepseek_vl/models/projector.py   | 100 +++
 mmte/models/deepseek_vl/models/sam.py         | 593 +++++++++++++++
 mmte/models/deepseek_vl/models/siglip_vit.py  | 681 ++++++++++++++++++
 mmte/models/deepseek_vl/utils/conversation.py | 348 +++++++++
 mmte/models/deepseek_vl/utils/io.py           |  89 +++
 12 files changed, 2927 insertions(+)
 create mode 100644 mmte/configs/models/deepseek/deepseek-7b.yaml
 create mode 100644 mmte/models/deepseek_chat.py
 create mode 100644 mmte/models/deepseek_vl/models/clip_encoder.py
 create mode 100644 mmte/models/deepseek_vl/models/image_processing_vlm.py
 create mode 100644 mmte/models/deepseek_vl/models/modeling_vlm.py
 create mode 100644 mmte/models/deepseek_vl/models/processing_vlm.py
 create mode 100644 mmte/models/deepseek_vl/models/projector.py
 create mode 100644 mmte/models/deepseek_vl/models/sam.py
 create mode 100644 mmte/models/deepseek_vl/models/siglip_vit.py
 create mode 100644 mmte/models/deepseek_vl/utils/conversation.py
 create mode 100644 mmte/models/deepseek_vl/utils/io.py

diff --git a/mmte/configs/models/deepseek/deepseek-7b.yaml b/mmte/configs/models/deepseek/deepseek-7b.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/mmte/models/__init__.py b/mmte/models/__init__.py
index cedb218..98494c9 100644
--- a/mmte/models/__init__.py
+++ b/mmte/models/__init__.py
@@ -18,6 +18,7 @@
 from mmte.models.sharegpt4v_chat import ShareGPT4VChat
 from mmte.models.cogvlm_chat import CogVLMChat
 from mmte.models.phi3_chat import Phi3Chat
+from mmte.models.deepseek_chat import DeepSeekVL
 
 from typing import List
diff --git a/mmte/models/deepseek_chat.py b/mmte/models/deepseek_chat.py
new file mode 100644
index 0000000..c99056d
--- /dev/null
+++ b/mmte/models/deepseek_chat.py
@@ -0,0 +1,105 @@
+import torch
+from typing import List
+from transformers import AutoModelForCausalLM
+from mmte.utils.registry import registry
+from mmte.models.base import BaseChat, Response
+from .deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
+from .deepseek_vl.utils.io import load_pil_images
+
+
+@registry.register_chatmodel()
+class DeepSeekVL(BaseChat):
+    """
+    Chat class for the deepseek-vl-7b model.
+    """
+
+    # TODO: update model config
+    MODEL_CONFIG = {
+        "deepseek-7b": 'configs/models/deepseek/deepseek-7b.yaml',
+    }
+    model_family = list(MODEL_CONFIG.keys())
+
+    def __init__(self, model_id: str, device: str="cuda:0"):
+        super().__init__(model_id)
+        model_path = "deepseek-ai/deepseek-vl-7b-chat"
+        vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+        tokenizer = vl_chat_processor.tokenizer
+
+        vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+        vl_gpt = vl_gpt.to(torch.bfloat16).to(device).eval()
+
+        self.device = device
+        self.model = vl_gpt
+        self.tokenizer = tokenizer
+        self.vl_chat_processor = vl_chat_processor
+
+    @torch.no_grad()
+    def chat(self, messages: List, **generation_kwargs):
+        # TODO: handle the system message if one is provided.
+        for message in messages:
+            if message["role"] in ["system", "user", "assistant"]:
+                if message["role"] == "user":
+                    if isinstance(message["content"], dict):
+                        # multimodal
+                        image_path = message["content"]["image_path"]
+                        user_message = message["content"]["text"]
+                    else:
+                        image_path = None
+                        user_message = message["content"]
+                elif message["role"] == "assistant":
+                    # TODO: add assistant answer into the conversation
+                    pass
+            else:
+                raise ValueError("Unsupported role. Only system, user and assistant are supported.")
+
+        if image_path is not None:
+            conversation = [
+                {
+                    "role": "User",
+                    "content": "<image_placeholder>" + user_message,
+                    "images": [image_path]
+                },
+                {
+                    "role": "Assistant",
+                    "content": ""
+                }
+            ]
+        else:
+            conversation = [
+                {
+                    "role": "User",
+                    "content": user_message,
+                },
+                {
+                    "role": "Assistant",
+                    "content": ""
+                }
+            ]
+
+        pil_images = load_pil_images(conversation)
+        prepare_inputs = self.vl_chat_processor(
+            conversations=conversation,
+            images=pil_images,
+            force_batchify=True
+        ).to(self.model.device)
+
+        # run image encoder to get the image embeddings
+        inputs_embeds = self.model.prepare_inputs_embeds(**prepare_inputs)
+
+        # run the model to get the response
+        outputs = self.model.language_model.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=prepare_inputs.attention_mask,
+            pad_token_id=self.tokenizer.eos_token_id,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            max_new_tokens=generation_kwargs.get("max_new_tokens", 512),
+            do_sample=generation_kwargs.get("do_sample", False),
+            use_cache=True
+        )
+
+        output_text = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+
+        scores = None
+        return Response(self.model_id, output_text, scores, None)
+
diff --git a/mmte/models/deepseek_vl/models/clip_encoder.py b/mmte/models/deepseek_vl/models/clip_encoder.py
new file mode 100644
index 0000000..8695701
--- /dev/null
+++ b/mmte/models/deepseek_vl/models/clip_encoder.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
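# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch, not part of the patch itself.
# It shows how the DeepSeekVL chat class added above could be called; the
# message layout ("text" / "image_path") mirrors what chat() expects, and the
# image path is a hypothetical example.
from mmte.models import DeepSeekVL

model = DeepSeekVL(model_id="deepseek-7b", device="cuda:0")
messages = [
    {
        "role": "user",
        "content": {
            "text": "Describe this image.",
            "image_path": "path/to/example.jpg",  # hypothetical image path
        },
    }
]
response = model.chat(messages, max_new_tokens=128, do_sample=False)
print(response)
# ---------------------------------------------------------------------------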
+ +from typing import Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms +from einops import rearrange + +from .sam import create_sam_vit +from .siglip_vit import create_siglip_vit + + +class CLIPVisionTower(nn.Module): + def __init__( + self, + model_name: str = "siglip_large_patch16_384", + image_size: Union[Tuple[int, int], int] = 336, + select_feature: str = "patch", + select_layer: int = -2, + select_layers: list = None, + ckpt_path: str = "", + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + **kwargs, + ): + super().__init__() + + self.model_name = model_name + self.select_feature = select_feature + self.select_layer = select_layer + self.select_layers = select_layers + + vision_tower_params = { + "model_name": model_name, + "image_size": image_size, + "ckpt_path": ckpt_path, + "select_layer": select_layer, + } + vision_tower_params.update(kwargs) + self.vision_tower, self.forward_kwargs = self.build_vision_tower( + vision_tower_params + ) + + if pixel_mean is not None and pixel_std is not None: + image_norm = torchvision.transforms.Normalize( + mean=pixel_mean, std=pixel_std + ) + else: + image_norm = None + + self.image_norm = image_norm + + def build_vision_tower(self, vision_tower_params): + if self.model_name.startswith("siglip"): + self.select_feature = "same" + vision_tower = create_siglip_vit(**vision_tower_params) + forward_kwargs = dict() + + elif self.model_name.startswith("sam"): + vision_tower = create_sam_vit(**vision_tower_params) + forward_kwargs = dict() + + else: # huggingface + from transformers import CLIPVisionModel + + vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) + forward_kwargs = dict(output_hidden_states=True) + + return vision_tower, forward_kwargs + + def feature_select(self, image_forward_outs): + if isinstance(image_forward_outs, torch.Tensor): + # the output has been the self.select_layer"s features + image_features = image_forward_outs + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if self.select_feature == "patch": + # if the output has cls_token + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "same": + image_features = image_features + + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward(self, images): + """ + + Args: + images (torch.Tensor): [b, 3, H, W] + + Returns: + image_features (torch.Tensor): [b, n_patch, d] + """ + + if self.image_norm is not None: + images = self.image_norm(images) + + image_forward_outs = self.vision_tower(images, **self.forward_kwargs) + image_features = self.feature_select(image_forward_outs) + return image_features + + +class HybridVisionTower(nn.Module): + def __init__( + self, + high_res_cfg: Dict, + low_res_cfg: Dict, + freeze_high: bool = False, + freeze_low: bool = False, + concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple", + **ignore_kwargs, + ): + super().__init__() + + self.vision_tower_high = CLIPVisionTower(**high_res_cfg) + self.vision_tower_low = CLIPVisionTower(**low_res_cfg) + self.low_res_size = low_res_cfg["image_size"] + self.concat_type = concat_type + + self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024)) + self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024)) + + if freeze_high: + for p_name, p in 
self.vision_tower_high.named_parameters(): + p.requires_grad = False + self.vision_tower_high = self.vision_tower_high.eval() + else: + # train donwsamples and neck + for p_name, p in self.vision_tower_high.named_parameters(): + if "downsamples" in p_name or "neck" in p_name: + p.requires_grad = True + else: + p.requires_grad = False + + if freeze_low: + for p in self.vision_tower_low.parameters(): + p.requires_grad = False + self.vision_tower_low = self.vision_tower_low.eval() + + self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True) + + def forward(self, images: torch.Tensor): + """ + + Args: + images (torch.Tensor): [bs, 3, H, W] + + Returns: + res (torch.Tensor): [bs, t, c] + """ + + # [bs, c, h, w] + high_images = images + + # [bs, c, h_low, w_low] + low_images = self.resize(images) + + # separately run two vision towers + # run high_res vision tower + high_res = self.vision_tower_high(high_images) + # [bs, c, h, w] -> [bs, h*w, c] + high_res = rearrange(high_res, "b c h w -> b (h w) c") + # run low_res vision tower + low_res = self.vision_tower_low(low_images) + + if self.concat_type == "feature": + images_features = torch.cat([high_res, low_res], dim=-1) + elif self.concat_type == "sequence": + images_features = torch.cat([high_res, low_res], dim=1) + elif self.concat_type == "add": + images_features = high_res + low_res + elif self.concat_type == "tuple": + images_features = (high_res, low_res) + + else: + raise ValueError( + "Currently only support `feature`, `sequence`, `add` and `tuple` concat type." + ) + + return images_features + + +if __name__ == "__main__": + image_size = 1024 + x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda() + + high_res_cfg = dict( + model_name="sam_b_downsample", + select_feature="same", + image_size=image_size, + pixel_mean=(0.48145466, 0.4578275, 0.40821073), + pixel_std=(0.26862954, 0.26130258, 0.27577711), + select_layer=-1, + ckpt_path="", + ) + + low_res_cfg = dict( + model_name="siglip_large_patch16_384", + select_feature="same", + image_size=384, + pixel_mean=(0.5, 0.5, 0.5), + pixel_std=(0.5, 0.5, 0.5), + select_layer=-1, + ckpt_path="", + ) + + net = ( + HybridVisionTower( + high_res_cfg=high_res_cfg, + low_res_cfg=low_res_cfg, + freeze_high=True, + freeze_low=True, + concat_type="tuple", + ) + .bfloat16() + .cuda() + ) + high_x, low_x = net(x) + print(x.shape, high_x.shape, low_x.shape) diff --git a/mmte/models/deepseek_vl/models/image_processing_vlm.py b/mmte/models/deepseek_vl/models/image_processing_vlm.py new file mode 100644 index 0000000..367dee1 --- /dev/null +++ b/mmte/models/deepseek_vl/models/image_processing_vlm.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from typing import List, Tuple, Union + +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional +from PIL import Image +from transformers import AutoImageProcessor, PretrainedConfig +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_utils import to_numpy_array +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ImageType = Union[np.ndarray, torch.Tensor, Image.Image] +IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) +IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +class VLMImageProcessorConfig(PretrainedConfig): + model_type = "deepseek_vlm" + image_size: int + min_size: int + image_mean: Union[Tuple[float, float, float], List[float]] + image_std: Union[Tuple[float, float, float], List[float]] + rescale_factor: float + do_normalize: bool + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + self.image_size = image_size + self.min_size = min_size + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + + super().__init__(**kwargs) + + +class VLMImageProcessor(BaseImageProcessor): + model_input_names = ["pixel_values"] + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.rescale_factor = rescale_factor + self.image_mean = image_mean + self.image_std = image_std + self.min_size = min_size + self.do_normalize = do_normalize + + if image_mean is None: + self.background_color = (127, 127, 127) + else: + self.background_color = tuple([int(x * 255) for x in image_mean]) + + def resize(self, pil_img: Image) -> np.ndarray: + """ + + Args: + pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB + + Returns: + x (np.ndarray): [3, self.image_size, self.image_size] + """ + + width, height = pil_img.size + max_size = max(width, height) + + size = [ + max(int(height / max_size * self.image_size), self.min_size), + max(int(width / max_size * self.image_size), self.min_size), + ] + + if width <= 0 or height <= 0 or size[0] <= 0 
or size[1] <= 0: + print(f"orig size = {pil_img.size}, new size = {size}") + raise ValueError("Invalid size!") + + pil_img = torchvision.transforms.functional.resize( + pil_img, + size, + interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, + antialias=True, + ) + + pil_img = expand2square(pil_img, self.background_color) + x = to_numpy_array(pil_img) + + # [H, W, 3] -> [3, H, W] + x = np.transpose(x, (2, 0, 1)) + + return x + + def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: + # resize and pad to [self.image_size, self.image_size] + # then convert from [H, W, 3] to [3, H, W] + images: List[np.ndarray] = [self.resize(image) for image in images] + + # resacle from [0, 255] -> [0, 1] + images = [ + self.rescale( + image=image, + scale=self.rescale_factor, + input_data_format="channels_first", + ) + for image in images + ] + + # normalize + if self.do_normalize: + images = [ + self.normalize( + image=image, + mean=self.image_mean, + std=self.image_std, + input_data_format="channels_first", + ) + for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + @property + def default_shape(self): + return [3, self.image_size, self.image_size] + + +AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) + + +if __name__ == "__main__": + image_processor = VLMImageProcessor( + image_size=1024, + image_mean=IMAGENET_INCEPTION_MEAN, + image_std=IMAGENET_INCEPTION_STD, + do_normalize=True, + ) diff --git a/mmte/models/deepseek_vl/models/modeling_vlm.py b/mmte/models/deepseek_vl/models/modeling_vlm.py new file mode 100644 index 0000000..75389dd --- /dev/null +++ b/mmte/models/deepseek_vl/models/modeling_vlm.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
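# ---------------------------------------------------------------------------
# Editor's note: small PIL-only sketch, not part of the patch. It mimics the
# pad-to-square step performed by VLMImageProcessor / expand2square above:
# the shorter side is padded with the background colour so the image becomes
# square before it is converted to a [3, H, W] array.
from PIL import Image

img = Image.new("RGB", (640, 480), (255, 0, 0))   # stand-in for a real photo
background = (127, 127, 127)
side = max(img.size)
canvas = Image.new("RGB", (side, side), background)
canvas.paste(img, (0, (side - img.height) // 2))  # width > height: pad top/bottom
print(canvas.size)                                # -> (640, 640)
# ---------------------------------------------------------------------------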
+ +import torch +from attrdict import AttrDict +from einops import rearrange +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + LlamaConfig, + LlamaForCausalLM, + PreTrainedModel, +) +from transformers.configuration_utils import PretrainedConfig + +from .clip_encoder import CLIPVisionTower, HybridVisionTower +from .projector import MlpProjector + + +def model_name_to_cls(cls_name): + if "MlpProjector" in cls_name: + cls = MlpProjector + + elif "CLIPVisionTower" in cls_name: + cls = CLIPVisionTower + + elif "HybridVisionTower" in cls_name: + cls = HybridVisionTower + + else: + raise ValueError(f"class_name {cls_name} is invalid.") + + return cls + + +class VisionConfig(PretrainedConfig): + model_type = "vision" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class AlignerConfig(PretrainedConfig): + model_type = "aligner" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class MultiModalityConfig(PretrainedConfig): + model_type = "multi_modality" + vision_config: VisionConfig + aligner_config: AlignerConfig + language_config: LlamaConfig + + def __init__(self, **kwargs): + super().__init__(**kwargs) + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionConfig(**vision_config) + + aligner_config = kwargs.get("aligner_config", {}) + self.aligner_config = AlignerConfig(**aligner_config) + + language_config = kwargs.get("language_config", {}) + if isinstance(language_config, LlamaConfig): + self.language_config = language_config + else: + self.language_config = LlamaConfig(**language_config) + + +class MultiModalityPreTrainedModel(PreTrainedModel): + config_class = MultiModalityConfig + base_model_prefix = "multi_modality" + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + + +class MultiModalityCausalLM(MultiModalityPreTrainedModel): + def __init__(self, config: MultiModalityConfig): + super().__init__(config) + + vision_config = config.vision_config + vision_cls = model_name_to_cls(vision_config.cls) + self.vision_model = vision_cls(**vision_config.params) + + aligner_config = config.aligner_config + aligner_cls = model_name_to_cls(aligner_config.cls) + self.aligner = aligner_cls(aligner_config.params) + + language_config = config.language_config + self.language_model = LlamaForCausalLM(language_config) + + def prepare_inputs_embeds( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + images_seq_mask: torch.LongTensor, + images_emb_mask: torch.LongTensor, + **kwargs, + ): + """ + + Args: + input_ids (torch.LongTensor): [b, T] + pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] + images_seq_mask (torch.BoolTensor): [b, T] + images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] + + assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) + + Returns: + input_embeds (torch.Tensor): [b, T, D] + """ + + bs, n = pixel_values.shape[0:2] + images = rearrange(pixel_values, "b n c h w -> (b n) c h w") + # [b x n, T2, D] + images_embeds = self.aligner(self.vision_model(images)) + + # [b x n, T2, D] -> [b, n x T2, D] + images_embeds = rearrange(images_embeds, "(b n) 
t d -> b (n t) d", b=bs, n=n) + # [b, n, T2] -> [b, n x T2] + images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") + + # [b, T, D] + input_ids[input_ids < 0] = 0 # ignore the image embeddings + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + # replace with the image embeddings + inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] + + return inputs_embeds + + +AutoConfig.register("vision", VisionConfig) +AutoConfig.register("aligner", AlignerConfig) +AutoConfig.register("multi_modality", MultiModalityConfig) +AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) diff --git a/mmte/models/deepseek_vl/models/processing_vlm.py b/mmte/models/deepseek_vl/models/processing_vlm.py new file mode 100644 index 0000000..aef5095 --- /dev/null +++ b/mmte/models/deepseek_vl/models/processing_vlm.py @@ -0,0 +1,390 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
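# ---------------------------------------------------------------------------
# Editor's note: toy sketch, not part of the patch. It isolates the mask-based
# splice used in MultiModalityCausalLM.prepare_inputs_embeds above: positions
# flagged by images_seq_mask in the text embedding sequence are overwritten
# with the aligned image embeddings selected by images_emb_mask.
import torch

b, T, D, n_image_tokens = 1, 8, 4, 3
inputs_embeds = torch.zeros(b, T, D)                  # stand-in text embeddings
images_embeds = torch.ones(b, n_image_tokens, D)      # stand-in aligned image features
images_seq_mask = torch.zeros(b, T, dtype=torch.bool)
images_seq_mask[0, 2:5] = True                        # the 3 image-token slots
images_emb_mask = torch.ones(b, n_image_tokens, dtype=torch.bool)

inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0, :, 0])                         # -> 0, 0, 1, 1, 1, 0, 0, 0
# ---------------------------------------------------------------------------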
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import torch
+from PIL.Image import Image
+from transformers import LlamaTokenizerFast
+from transformers.processing_utils import ProcessorMixin
+
+from ..models.image_processing_vlm import VLMImageProcessor
+from ..utils.conversation import get_conv_template
+
+
+class DictOutput(object):
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    sft_format: str
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    num_image_tokens: torch.IntTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+@dataclass
+class BatchedVLChatProcessorOutput(DictOutput):
+    sft_format: List[str]
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    attention_mask: torch.Tensor
+    images_seq_mask: torch.BoolTensor
+    images_emb_mask: torch.BoolTensor
+
+    def to(self, device, dtype=torch.bfloat16):
+        self.input_ids = self.input_ids.to(device)
+        self.attention_mask = self.attention_mask.to(device)
+        self.images_seq_mask = self.images_seq_mask.to(device)
+        self.images_emb_mask = self.images_emb_mask.to(device)
+        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
+        return self
+
+
+class VLChatProcessor(ProcessorMixin):
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    attributes = ["image_processor", "tokenizer"]
+
+    system_prompt = (
+        "You are a helpful language and vision assistant. "
+        "You are able to understand the visual content that the user provides, "
+        "and assist the user with a variety of tasks using natural language."
+    )
+
+    def __init__(
+        self,
+        image_processor: VLMImageProcessor,
+        tokenizer: LlamaTokenizerFast,
+        image_tag: str = "<image_placeholder>",
+        num_image_tokens: int = 576,
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+        image_id = self.tokenizer.vocab.get(image_tag)
+        if image_id is None:
+            special_tokens = [image_tag]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+            print(f"Add image tag = {image_tag} to the tokenizer")
+
+        self.image_tag = image_tag
+        self.num_image_tokens = num_image_tokens
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.mask_prompt = mask_prompt
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            image_processor,
+            tokenizer,
+            image_tag,
+            num_image_tokens,
+            add_special_token,
+            sft_format,
+            mask_prompt,
+            ignore_id,
+            **kwargs,
+        )
+
+    def new_chat_template(self):
+        conv = get_conv_template(self.sft_format)
+        conv.set_system_message(self.system_prompt)
+        return conv
+
+    def apply_sft_template_for_multi_turn_prompts(
+        self,
+        conversations: List[Dict[str, str]],
+        sft_format: str = "deepseek",
+        system_prompt: str = "",
+    ):
+        """
+        Applies the SFT template to conversation.
+ + An example of conversation: + conversation = [ + { + "role": "User", + "content": " is Figure 1.\n is Figure 2.\nWhich image is brighter?", + "images": [ + "./multi-images/attribute_comparison_1.png", + "./multi-images/attribute_comparison_2.png" + ] + }, + { + "role": "Assistant", + "content": "" + } + ] + + Args: + conversations (List[Dict]): A conversation with a List of Dict[str, str] text. + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". + + Returns: + sft_prompt (str): The formatted text. + """ + + conv = get_conv_template(sft_format) + conv.set_system_message(system_prompt) + for message in conversations: + conv.append_message(message["role"], message["content"].strip()) + sft_prompt = conv.get_prompt().strip() + + return sft_prompt + + @property + def image_token(self): + return self.image_tag + + @property + def image_id(self): + image_id = self.tokenizer.vocab.get(self.image_tag) + return image_id + + @property + def pad_id(self): + pad_id = self.tokenizer.pad_token_id + if pad_id is None: + pad_id = self.tokenizer.eos_token_id + + return pad_id + + def add_image_token( + self, + image_indices: List[int], + input_ids: torch.LongTensor, + ): + """ + + Args: + image_indices (List[int]): [index_0, index_1, ..., index_j] + input_ids (torch.LongTensor): [N] + + Returns: + input_ids (torch.LongTensor): [N + image tokens] + num_image_tokens (torch.IntTensor): [n_images] + """ + + input_slices = [] + + start = 0 + for index in image_indices: + if self.add_special_token: + end = index + 1 + else: + end = index + + # original text tokens + input_slices.append(input_ids[start:end]) + + # add image tokens, and set the mask as False + input_slices.append( + self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long) + ) + start = index + 1 + + # the left part + input_slices.append(input_ids[start:]) + + # concat all slices + input_ids = torch.cat(input_slices, dim=0) + num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices)) + + return input_ids, num_image_tokens + + def process_one( + self, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert ( + prompt is None or conversations is None + ), "prompt and conversations cannot be used at the same time." 
+ + if prompt is None: + # apply sft format + sft_format = self.apply_sft_template_for_multi_turn_prompts( + conversations=conversations, + sft_format=self.sft_format, + system_prompt=self.system_prompt, + ) + else: + sft_format = prompt + + # tokenize + input_ids = self.tokenizer.encode(sft_format) + input_ids = torch.LongTensor(input_ids) + + # add image tokens to the input_ids + image_token_mask: torch.BoolTensor = input_ids == self.image_id + image_indices = image_token_mask.nonzero() + input_ids, num_image_tokens = self.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # load images + images_outputs = self.image_processor(images, return_tensors="pt") + + prepare = VLChatProcessorOutput( + sft_format=sft_format, + input_ids=input_ids, + pixel_values=images_outputs.pixel_values, + num_image_tokens=num_image_tokens, + ) + + return prepare + + def __call__( + self, + *, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + force_batchify: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + force_batchify (bool): force batchify the inputs; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one( + prompt=prompt, conversations=conversations, images=images + ) + + if force_batchify: + prepare = self.batchify([prepare]) + + return prepare + + def batchify( + self, prepare_list: List[VLChatProcessorOutput] + ) -> BatchedVLChatProcessorOutput: + """ + Preprocesses the inputs for multimodal inference. + + Args: + prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. + + Returns: + BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. 
+ """ + + batch_size = len(prepare_list) + sft_format = [] + n_images = [] + seq_lens = [] + for prepare in prepare_list: + n_images.append(len(prepare.num_image_tokens)) + seq_lens.append(len(prepare)) + + input_token_max_len = max(seq_lens) + max_n_images = max(1, max(n_images)) + + batched_input_ids = torch.full( + (batch_size, input_token_max_len), self.pad_id + ).long() # FIXME + batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() + batched_pixel_values = torch.zeros( + (batch_size, max_n_images, *self.image_processor.default_shape) + ).float() + batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() + batched_images_emb_mask = torch.zeros( + (batch_size, max_n_images, self.num_image_tokens) + ).bool() + + for i, prepare in enumerate(prepare_list): + input_ids = prepare.input_ids + seq_len = len(prepare) + n_image = len(prepare.num_image_tokens) + # left-padding + batched_attention_mask[i, -seq_len:] = 1 + batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) + batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id + + if n_image > 0: + batched_pixel_values[i, :n_image] = prepare.pixel_values + for j, n_image_tokens in enumerate(prepare.num_image_tokens): + batched_images_emb_mask[i, j, :n_image_tokens] = True + + sft_format.append(prepare.sft_format) + + batched_prepares = BatchedVLChatProcessorOutput( + input_ids=batched_input_ids, + attention_mask=batched_attention_mask, + pixel_values=batched_pixel_values, + images_seq_mask=batched_images_seq_mask, + images_emb_mask=batched_images_emb_mask, + sft_format=sft_format, + ) + + return batched_prepares diff --git a/mmte/models/deepseek_vl/models/projector.py b/mmte/models/deepseek_vl/models/projector.py new file mode 100644 index 0000000..15f4ca3 --- /dev/null +++ b/mmte/models/deepseek_vl/models/projector.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +from typing import Tuple, Union + +import torch +import torch.nn as nn +from attrdict import AttrDict + + +class MlpProjector(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + if cfg.projector_type == "identity": + modules = nn.Identity() + + elif cfg.projector_type == "linear": + modules = nn.Linear(cfg.input_dim, cfg.n_embed) + + elif cfg.projector_type == "mlp_gelu": + mlp_depth = cfg.get("depth", 1) + modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + mlp_depth = cfg.get("depth", 1) + self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise ValueError(f"Unknown projector type: {cfg.projector_type}") + + self.layers = modules + + def forward( + self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] + ): + """ + + Args: + x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, + then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); + otherwise it is the feature from the single vision encoder. + + Returns: + x (torch.Tensor): [b, s, c] + """ + + if isinstance(x_or_tuple, tuple): + # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + high_x, low_x = x_or_tuple + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = torch.concat([high_x, low_x], dim=-1) + else: + x = x_or_tuple + + return self.layers(x) + + +if __name__ == "__main__": + cfg = AttrDict( + input_dim=1024, + n_embed=2048, + depth=2, + projector_type="low_high_hybrid_split_mlp_gelu", + ) + inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024)) + + m = MlpProjector(cfg) + out = m(inputs) + print(out.shape) diff --git a/mmte/models/deepseek_vl/models/sam.py b/mmte/models/deepseek_vl/models/sam.py new file mode 100644 index 0000000..31159a9 --- /dev/null +++ b/mmte/models/deepseek_vl/models/sam.py @@ -0,0 +1,593 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + downsample_channels: Tuple[int, ...] = (512, 1024), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + downsample_channels (list): Channels for downsampling layers. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + in_channels = out_chans + downsamples = [] + for i in range(len(downsample_channels)): + out_channels = downsample_channels[i] + downsamples.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ) + ) + in_channels = out_channels + self.downsamples = nn.Sequential(*downsamples) + + self.sam_hd = True + if self.sam_hd: + self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1)) + # self.neck_hd = nn.Linear(embed_dim, embed_dim) + self.neck_hd = copy.deepcopy(self.neck) + # self.downsamples_hd = copy.deepcopy(self.downsamples) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + global_features = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if self.sam_hd and blk.window_size == 0: + global_features.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x_dtype = x.dtype + x = F.interpolate( + x.float(), size=(96, 96), mode="bilinear", align_corners=False + ).to(x_dtype) + x = self.downsamples(x) + + if self.sam_hd: + first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2)) + x_dtype = first_global_feature.dtype + first_global_feature = F.interpolate( + first_global_feature.float(), + size=(96, 96), + mode="bilinear", + align_corners=False, + ) + first_global_feature = self.downsamples(first_global_feature.to(x_dtype)) + x = x + first_global_feature * self.hd_alpha_downsamples + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + def do_attention(q, k, v): + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + + return x + + # from haiscale.utils import on_demand_checkpoint + # x = on_demand_checkpoint(do_attention, q, k, v) + x = do_attention(q, k, v) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +@dataclass +class SAMViTCfg: + image_size: Union[Tuple[int, int], int] = 1024 + width: int = 1024 + layers: int = 23 + heads: int = 16 + patch_size: int = 16 + window_size: int = 14 + prompt_embed_dim: int = 256 + global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23) + downsample_channels: Union[List[int], Tuple[int]] = (512, 1024) + + +SAM_MODEL_CONFIG = { + "sam_vit_b": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": (), + }, + "sam_b_downsample": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": (512, 1024), + }, + "sam_vit_l": { + "width": 1024, + "layers": 24, + "heads": 16, + "global_attn_indexes": [5, 11, 17, 23], + "downsample_channels": (), + }, + "sam_vit_h": { + "width": 1280, + "layers": 32, + "heads": 16, + "global_attn_indexes": [7, 15, 23, 31], + "downsample_channels": (), + }, +} + + +def create_sam_vit( + model_name: str = "sam_b_downsample", + image_size: int = 1024, + ckpt_path: str = "", + **kwargs, +): + assert ( + model_name in SAM_MODEL_CONFIG.keys() + ), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}" + + sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name]) + image_encoder = ImageEncoderViT( + depth=sam_cfg.layers, + embed_dim=sam_cfg.width, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=sam_cfg.heads, + patch_size=sam_cfg.patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=sam_cfg.global_attn_indexes, + window_size=14, + out_chans=sam_cfg.prompt_embed_dim, + downsample_channels=sam_cfg.downsample_channels, + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path) + image_encoder.load_state_dict(state_dict, strict=False) + print(f"SAM-ViT restores from {ckpt_path}") + + return image_encoder + + +if __name__ == "__main__": + x = torch.zeros(2, 3, 1024, 1024).bfloat16() + # x.permute(0, 3, 1, 2) + net = create_sam_vit().bfloat16() + out = net(x) + print(x.shape, out.shape) diff --git a/mmte/models/deepseek_vl/models/siglip_vit.py b/mmte/models/deepseek_vl/models/siglip_vit.py new file mode 100644 index 0000000..a93707b --- /dev/null +++ b/mmte/models/deepseek_vl/models/siglip_vit.py @@ 
-0,0 +1,681 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, +) +from timm.models._manipulate import checkpoint_seq, named_apply + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. 
The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = 
nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. 
+ """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return 
{"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if not self.ignore_head: + x = self.forward_head(x) + return x + + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 336, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + ckpt_path: str = "", + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + model = VisionTransformer( + img_size=image_size, + patch_size=vision_cfg.patch_size, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + ignore_head=kwargs.get("ignore_head", True), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path, map_location="cpu") + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print( + f"SigLIP-ViT restores 
from {ckpt_path},\n" + f"\tincompatible_keys:', {incompatible_keys}." + ) + + return model diff --git a/mmte/models/deepseek_vl/utils/conversation.py b/mmte/models/deepseek_vl/utils/conversation.py new file mode 100644 index 0000000..711512f --- /dev/null +++ b/mmte/models/deepseek_vl/utils/conversation.py @@ -0,0 +1,348 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +""" +From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Dict, List + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + DeepSeek = auto() + PLAIN = auto() + ALIGNMENT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: List[str] = (("USER", "ASSISTANT"),) + # All messages. Each item is (role, message). 
+ messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: str = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + + if self.sep_style == SeparatorStyle.DeepSeek: + seps = [self.sep, self.sep2] + if system_prompt == "" or system_prompt is None: + ret = "" + else: + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if type(message) is tuple: # multimodal message + message, _ = message + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += message + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + elif self.sep_style == SeparatorStyle.ALIGNMENT: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += "\n" + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_prompt_for_current_round(self, content=None): + """Get current round formatted question prompt during sft training""" + if self.sep_style == SeparatorStyle.PLAIN: + formatted_question = "\n" + elif self.sep_style == SeparatorStyle.DeepSeek: + formatted_question = ( + f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:" + ) + else: + raise ValueError(f"Unsupported sep_style: {self.sep_style}") + return formatted_question + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def reset_message(self): + """Reset a new message.""" + self.messages = [] + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. 
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + system_prompt = self.system_template.format(system_message=self.system_message) + ret = [{"role": "system", "content": system_prompt}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# llava_llama2 template +register_conv_template( + Conversation( + name="llava_llama2", + system_message="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + +# llama2 template +# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + + +# deepseek template +register_conv_template( + Conversation( + name="deepseek", + system_template="{system_message}", + # system_message="You are a helpful assistant. 
Please answer truthfully and write out your " + # "thinking step by step to be sure you get the right answer.", + system_message="", + roles=("User", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DeepSeek, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_token_ids=[100001], + stop_str=["User:", "<|end▁of▁sentence|>"], + ) +) + +register_conv_template( + Conversation( + name="plain", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[""], + ) +) + + +register_conv_template( + Conversation( + name="alignment", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.ALIGNMENT, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[""], + ) +) + + +if __name__ == "__main__": + # print("Llama-2 template:") + # conv = get_conv_template("llama-2") + # conv.set_system_message("You are a helpful, respectful and honest assistant.") + # conv.append_message(conv.roles[0], "Hello!") + # conv.append_message(conv.roles[1], "Hi!") + # conv.append_message(conv.roles[0], "How are you?") + # conv.append_message(conv.roles[1], None) + # print(conv.get_prompt()) + + # print("\n") + + print("deepseek template:") + conv = get_conv_template("deepseek") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi! This is Tony.") + conv.append_message(conv.roles[0], "Who are you?") + conv.append_message(conv.roles[1], "I am a helpful assistant.") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/mmte/models/deepseek_vl/utils/io.py b/mmte/models/deepseek_vl/utils/io.py new file mode 100644 index 0000000..497271f --- /dev/null +++ b/mmte/models/deepseek_vl/utils/io.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +import json +from typing import Dict, List + +import PIL.Image +import torch +import base64 +import io +from transformers import AutoModelForCausalLM + +from ..models import MultiModalityCausalLM, VLChatProcessor + + +def load_pretrained_model(model_path: str): + vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) + tokenizer = vl_chat_processor.tokenizer + + vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True + ) + vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + + return tokenizer, vl_chat_processor, vl_gpt + + +def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: + """ + + Support file path or base64 images. + + Args: + conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : + [ + { + "role": "User", + "content": "\nExtract all information from this image and convert them into markdown format.", + "images": ["./examples/table_datasets.png"] + }, + {"role": "Assistant", "content": ""}, + ] + + Returns: + pil_images (List[PIL.Image.Image]): the list of PIL images. + + """ + + pil_images = [] + + for message in conversations: + if "images" not in message: + continue + + for image_data in message["images"]: + if image_data.startswith("data:image"): + # Image data is in base64 format + _, image_data = image_data.split(",", 1) + image_bytes = base64.b64decode(image_data) + pil_img = PIL.Image.open(io.BytesIO(image_bytes)) + else: + # Image data is a file path + pil_img = PIL.Image.open(image_data) + pil_img = pil_img.convert("RGB") + pil_images.append(pil_img) + + return pil_images + + +def load_json(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data
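
Usage note (not part of the diff): below is a minimal sketch of how the vendored conversation template and the `load_pil_images` helper introduced above fit together. It assumes the `mmte.models.deepseek_vl` package added by this patch is importable from the repo root, and it reuses the hypothetical image path from the `load_pil_images` docstring; nothing here is required by the patch itself.

    # Minimal sketch: build a "deepseek"-style prompt and load the images it
    # references. The import paths follow the package layout added by this
    # patch; the image path is the illustrative one from the docstring.
    from mmte.models.deepseek_vl.utils.conversation import get_conv_template
    from mmte.models.deepseek_vl.utils.io import load_pil_images

    conv = get_conv_template("deepseek")          # copy of the registered template
    conv.append_message(conv.roles[0], "Describe this table.")
    conv.append_message(conv.roles[1], None)      # leave the Assistant turn open
    prompt = conv.get_prompt()                    # -> "User: ...\n\nAssistant:"

    conversation = [
        {
            "role": "User",
            "content": "Describe this table.",
            "images": ["./examples/table_datasets.png"],  # file path or base64 data URI
        },
        {"role": "Assistant", "content": ""},
    ]
    pil_images = load_pil_images(conversation)    # list of RGB PIL.Image objects

    print(prompt)
    print(len(pil_images))

The conversation dicts above use the same schema as the `load_pil_images` docstring, which is also the format consumed by the processor elsewhere in this patch.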