Parallel Transformers #61

Closed · wants to merge 8 commits
71 changes: 71 additions & 0 deletions models/transformers/modify_mixtral.py
@@ -0,0 +1,71 @@
from transformers.models.mixtral.modeling_mixtral import (
MixtralRotaryEmbedding,
MixtralAttention,
MixtralBLockSparseTop2MLP,
ACT2FN,
Optional,
)
from transformers.utils import logging
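# AxoNN's intra_layer.Linear is used as a drop-in replacement for torch.nn.Linear
# so that each projection weight is sharded across AxoNN's intra-layer (tensor)
# parallel group instead of being replicated on every rank.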
from axonn.intra_layer import Linear

logger = logging.get_logger(__name__)


def modified_attention_init(self, config, layer_idx: Optional[int] = None):
super(MixtralAttention, self).__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` "
"is not recommended and will lead to errors during the forward call, "
"if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)

self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.attention_dropout = config.attention_dropout

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} & `num_heads`: {self.num_heads})."
)
self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

self.rotary_emb = MixtralRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)


def modified_mlp_init(self, config):
super(MixtralBLockSparseTop2MLP, self).__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size

self.w1 = Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = Linear(self.hidden_dim, self.ffn_dim, bias=False)

self.act_fn = ACT2FN[config.hidden_act]


def monkey_patch_mixtral_with_axonn():
MixtralAttention.__init__ = modified_attention_init
MixtralBLockSparseTop2MLP.__init__ = modified_mlp_init
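For reference, the patch only takes effect if it is applied before the model is constructed, since it swaps the `__init__` methods that `from_pretrained` will call. A minimal usage sketch (the checkpoint name is illustrative, and AxoNN's intra-layer parallel groups are assumed to be initialized beforehand):

from transformers import AutoModelForCausalLM
from modify_mixtral import monkey_patch_mixtral_with_axonn

monkey_patch_mixtral_with_axonn()  # patch the attention/MLP constructors first
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")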
113 changes: 113 additions & 0 deletions models/transformers/modify_qwen.py
@@ -0,0 +1,113 @@
# import sys
# sys.path.append("/nfshomes/jwendlan/Qwen-7B")
# The import below requires a local copy of Qwen-7B's modeling_qwen.py on the
# Python path (e.g. by uncommenting the two lines above).
from modeling_qwen import QWenAttention, FlashSelfAttention, QWenMLP
import torch.nn as nn
from axonn.intra_layer import Linear
import torch
import math
import warnings
import pathlib
try:
    from flash_attn import flash_attn_unpadded_func  # flash-attn 1.x API
except ImportError:  # flash-attn 2.x renamed it to flash_attn_varlen_func
    flash_attn_unpadded_func = None


def modified_attention_init(self, config):
super(QWenAttention, self).__init__()

self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.seq_length = config.seq_length

self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads

self.use_flash_attn = config.use_flash_attn
self.scale_attn_weights = True

self.projection_size = config.kv_channels * config.num_attention_heads

assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)

self.c_attn = Linear(config.hidden_size, 3 * self.projection_size)

self.c_proj = Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)

self.is_fp32 = not (config.bf16 or config.fp16)
if (
self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
):
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=config.attn_dropout_prob
)
self.bf16 = config.bf16

self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn

logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.register_buffer("logn_tensor", logn_tensor, persistent=False)

self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.softmax_in_fp32 = (
config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
)
self.use_cache_quantization = (
config.use_cache_quantization
if hasattr(config, "use_cache_quantization")
else False
)
self.use_cache_kernel = (
config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
)
cache_dtype = torch.float
if self.bf16:
cache_dtype = torch.bfloat16
elif config.fp16:
cache_dtype = torch.float16
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)

if config.use_cache_quantization and config.use_cache_kernel:
        # Check that the cache kernel source files exist before attempting the import
module_root = pathlib.Path(__file__).parent
src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
if any(not (module_root / src).is_file() for src in src_files):
warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
self.cache_kernels = None
else:
try:
from .cpp_kernels import cache_autogptq_cuda_256

self.cache_kernels = cache_autogptq_cuda_256
except ImportError:
warnings.warn("Failed to import KV cache kernels.")
self.cache_kernels = None


def modified_mlp_init(self, config):
super(QWenMLP, self).__init__()
self.w1 = Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
ff_dim_in = config.intermediate_size // 2
self.c_proj = Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)


def monkey_patch_qwen_with_axonn():
QWenAttention.__init__ = modified_attention_init
QWenMLP.__init__ = modified_mlp_init
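One caveat worth noting: the patch modifies the QWenAttention / QWenMLP classes imported from the local modeling_qwen module, so the model has to be built from that same module; loading via AutoModelForCausalLM with trust_remote_code=True would typically import a separate cached copy of the modeling code that the patch never touches. A minimal sketch, assuming the local file exposes QWenLMHeadModel as in the Qwen-7B repository (checkpoint name is illustrative, AxoNN assumed initialized):

from modeling_qwen import QWenLMHeadModel  # same local file the patch targets
from modify_qwen import monkey_patch_qwen_with_axonn

monkey_patch_qwen_with_axonn()  # patch QWenAttention / QWenMLP before building the model
model = QWenLMHeadModel.from_pretrained("Qwen/Qwen-7B")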