Marlin MoE integration #2

Closed
wants to merge 111 commits
Changes from 1 commit
Commits
111 commits
5a2ab25
Moving branch to a different repo
ElizaWszola Aug 2, 2024
b39dba4
clean up the CPU code
ElizaWszola Aug 2, 2024
b0c4671
Fix build issues
ElizaWszola Aug 2, 2024
e5c1a81
Refactoring for maintainability
DhruvaBansal00 Aug 7, 2024
7da678e
Fixing tests
DhruvaBansal00 Aug 7, 2024
641696b
Addressing repacking comment
DhruvaBansal00 Aug 8, 2024
3cef667
gptq -> marlin renaming
DhruvaBansal00 Aug 8, 2024
a6710af
Undo formatting changes
DhruvaBansal00 Aug 8, 2024
e29107f
Final formatting change
DhruvaBansal00 Aug 8, 2024
099d61e
Switching to mixtral file for quantized mixtral
DhruvaBansal00 Aug 12, 2024
bdf6bdc
Bug fixes
DhruvaBansal00 Aug 12, 2024
19c5c59
is quantized change
DhruvaBansal00 Aug 12, 2024
3b7cc60
debug stat
DhruvaBansal00 Aug 12, 2024
d2c4754
replace weight name with param name
DhruvaBansal00 Aug 12, 2024
f579cb2
typo
DhruvaBansal00 Aug 12, 2024
79394eb
debug
DhruvaBansal00 Aug 12, 2024
ec75f4e
more debug
DhruvaBansal00 Aug 12, 2024
91ca970
only relevant logging
DhruvaBansal00 Aug 12, 2024
1b9d5bb
log
DhruvaBansal00 Aug 12, 2024
ec06719
log
DhruvaBansal00 Aug 12, 2024
71d82e1
removing qzero weights
DhruvaBansal00 Aug 12, 2024
d3465d0
Qzeros in expert mapping
DhruvaBansal00 Aug 12, 2024
226ee26
Debug
DhruvaBansal00 Aug 12, 2024
21d7d27
Load qzero
DhruvaBansal00 Aug 12, 2024
2dabb4b
rm 2x
DhruvaBansal00 Aug 12, 2024
6366976
Mapping for scales
DhruvaBansal00 Aug 12, 2024
d63c096
rm logging
DhruvaBansal00 Aug 12, 2024
360fef4
Adding layer-wise logging
DhruvaBansal00 Aug 12, 2024
c23d616
shard ids
DhruvaBansal00 Aug 12, 2024
8d81d14
Loading qzero correctly
DhruvaBansal00 Aug 12, 2024
22e1aa7
List operand
DhruvaBansal00 Aug 12, 2024
81e01f3
If clause
DhruvaBansal00 Aug 12, 2024
dcfd32d
Able to load layers
DhruvaBansal00 Aug 12, 2024
f04cbea
Setting load quant to false
DhruvaBansal00 Aug 12, 2024
a56821d
Disabling logging
DhruvaBansal00 Aug 12, 2024
7f961c6
Removing *2 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
4a6c7ff
*4 in marlin moe repack
DhruvaBansal00 Aug 13, 2024
e6cd286
bits
DhruvaBansal00 Aug 13, 2024
90241c4
*4
DhruvaBansal00 Aug 13, 2024
67409e9
intermediate size
DhruvaBansal00 Aug 13, 2024
539032e
repeat keyword
DhruvaBansal00 Aug 13, 2024
57b1cbe
hidden size
DhruvaBansal00 Aug 13, 2024
87f1dd4
intermediate size back
DhruvaBansal00 Aug 13, 2024
4c073c2
permute scales w3
DhruvaBansal00 Aug 13, 2024
d732493
*2
DhruvaBansal00 Aug 13, 2024
fdc22c4
log
DhruvaBansal00 Aug 13, 2024
272822e
shape as 2
DhruvaBansal00 Aug 13, 2024
3ce045e
test
DhruvaBansal00 Aug 13, 2024
c4ba477
Increasing to 4 and changing assert
DhruvaBansal00 Aug 13, 2024
2ea8370
logging
DhruvaBansal00 Aug 13, 2024
8287025
marlin moe repack change
DhruvaBansal00 Aug 13, 2024
53b23b9
mult qweight shape by pack factor
DhruvaBansal00 Aug 13, 2024
bc40786
Potential support for 8 bit
DhruvaBansal00 Aug 13, 2024
bea13de
undo change
DhruvaBansal00 Aug 13, 2024
a3a9114
qzeros
DhruvaBansal00 Aug 13, 2024
eb916f9
switching traffic to mixtral quant
DhruvaBansal00 Aug 13, 2024
017d6f8
compat
DhruvaBansal00 Aug 13, 2024
eb9c087
Passing intermediate tensor into mixtral in quant file
DhruvaBansal00 Aug 13, 2024
ea3cf18
Removing intermediate tensors from forward
DhruvaBansal00 Aug 13, 2024
4f6b4ca
load weights from quant
DhruvaBansal00 Aug 13, 2024
7ec27d9
Mixtral load weights change:
DhruvaBansal00 Aug 13, 2024
aa1fe77
none shard id change
DhruvaBansal00 Aug 13, 2024
ae8fb15
Use class from mixtral_quant
DhruvaBansal00 Aug 15, 2024
b863981
Removing lora from mixtral model init
DhruvaBansal00 Aug 15, 2024
5556d28
Adding empty intermediate tensors
DhruvaBansal00 Aug 15, 2024
c484a37
Building quantMixtralModel
DhruvaBansal00 Aug 15, 2024
0344e72
fused moe test
DhruvaBansal00 Aug 15, 2024
8c8b3fa
Lora enabled mixtral
DhruvaBansal00 Aug 15, 2024
dff59cd
LoRAMixtralModel compat
DhruvaBansal00 Aug 15, 2024
33f7e51
remove prefix
DhruvaBansal00 Aug 15, 2024
fdba917
use fused moe
DhruvaBansal00 Aug 15, 2024
780471e
remove org num embeddings
DhruvaBansal00 Aug 15, 2024
c0970f1
pass use fused moe into decoder
DhruvaBansal00 Aug 15, 2024
6a1a838
Mixtral for causal lm load func
DhruvaBansal00 Aug 15, 2024
5c3e857
Copying over quant mixtral
DhruvaBansal00 Aug 15, 2024
8d327de
Passing prefix
DhruvaBansal00 Aug 15, 2024
d337aea
Weight load
DhruvaBansal00 Aug 15, 2024
379f3e8
Weight load back
DhruvaBansal00 Aug 15, 2024
a5d356e
Load with name not weight name
DhruvaBansal00 Aug 15, 2024
62c0135
params dict should load from old name
DhruvaBansal00 Aug 15, 2024
d23c00c
logging name and params
DhruvaBansal00 Aug 15, 2024
6dda447
log expert params map
DhruvaBansal00 Aug 15, 2024
67ce7b6
parity with prev commits
DhruvaBansal00 Aug 15, 2024
bd933c9
Adding qzeros to mapping
DhruvaBansal00 Aug 15, 2024
77cd095
Remove log
DhruvaBansal00 Aug 15, 2024
529191e
Remove is quantized
DhruvaBansal00 Aug 15, 2024
2450543
Assume fused true
DhruvaBansal00 Aug 15, 2024
8cba45e
rm fused true
DhruvaBansal00 Aug 15, 2024
10940a5
Switching to mixtral moe
DhruvaBansal00 Aug 15, 2024
895ffbe
Precision changes
DhruvaBansal00 Aug 15, 2024
e54b2e4
Cleanup
DhruvaBansal00 Aug 15, 2024
b4f23dc
Mixtral quant parity:
DhruvaBansal00 Aug 15, 2024
d59fe3b
fixing tests
DhruvaBansal00 Aug 15, 2024
0d9cbdc
Tests working and correctness verified
DhruvaBansal00 Aug 15, 2024
112aa40
Formatting
DhruvaBansal00 Aug 15, 2024
1ca9098
Moving single marlin alongside fused marlin
DhruvaBansal00 Aug 19, 2024
4d41425
Removing unused imports
DhruvaBansal00 Aug 19, 2024
4907f43
single marlin moe import
DhruvaBansal00 Aug 19, 2024
8f4648c
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 20, 2024
8225037
Merge branch 'marlin-moe-integration' into gptq-marlin-refactor
ElizaWszola Aug 20, 2024
315e3b6
Unify shard_id to be of str w[1-3] format
ElizaWszola Aug 21, 2024
34bb5b0
Merge pull request #4 from DhruvaBansal00/gptq-marlin-refactor
ElizaWszola Aug 22, 2024
fd4bb21
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 22, 2024
7956a69
Unfused codepath for non-supported quant_types
ElizaWszola Aug 26, 2024
2511f78
uint8b128 support
ElizaWszola Aug 28, 2024
f875842
Merge branch 'main' into marlin-moe-integration
ElizaWszola Aug 29, 2024
d8feb8d
Cleanup, compressed tensors compatibility
ElizaWszola Aug 29, 2024
3676621
update todo
ElizaWszola Aug 29, 2024
75e3dd5
Fix merge
ElizaWszola Aug 30, 2024
a5f5a74
bad paste
ElizaWszola Aug 30, 2024
e305306
GPTQFusedMoE layer
ElizaWszola Sep 2, 2024
Mixtral quant parity:
DhruvaBansal00 committed Aug 15, 2024
commit b4f23dc6b0a4677fc3bf137576afe93e25b2184b
vllm/model_executor/models/mixtral_quant.py: 241 changes (0 additions & 241 deletions)
@@ -361,75 +361,6 @@ def forward(
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states



class LoRAMixtralModel(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, use_fused_moe=True, cache_config=cache_config, quant_config=quant_config
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    fall_back_to_pt_during_load = False

@@ -488,178 +419,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "v_proj", "v"),
]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if self.use_fused_moe:
                    if ("block_sparse_moe.experts." in name
                            and ".w1." not in name and ".w2." not in name
                            and ".w3." not in name
                            and name not in params_dict):
                        continue

                    if (".qzeros" in name):
                        continue

                    shard_id = None
                    expert_id = 0

                    has_any_numbered = (".qweight" in name or ".scales" in name
                                        or ".g_idx" in name)
                    if (has_any_numbered and (".w1." in name)):
                        name = name.replace(".w1.", ".w13_")
                        shard_id = 0
                    if (has_any_numbered and (".w2." in name)):
                        name = name.replace(".w2.", ".w2_")
                        shard_id = 0
                    if (has_any_numbered and (".w3." in name)):
                        name = name.replace(".w3.", ".w13_")
                        shard_id = 1

                    exp_string = re.search(r"\.experts\.\d+.", name)
                    if exp_string:
                        exp_string = exp_string.group(0)
                        expert_id = int(exp_string.split(".")[2])
                        name = name.replace(exp_string, ".experts.")

                else:
                    if ("block_sparse_moe.experts." in name
                            and name not in params_dict):
                        continue

                param = params_dict[name]

                if self.use_fused_moe and shard_id is not None:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight, name, shard_id,
                                  expert_id, True)
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
class LoRAEnabledMixtralForCausalLM(nn.Module, SupportsLoRA):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn)
        self.model = LoRAMixtralModel(
            config=config, cache_config=cache_config, quant_config=quant_config, lora_config=lora_config, prefix="model"
        )
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids, positions, kv_caches, attn_metadata, intermediate_tensors
        )
        return hidden_states

    def compute_logits(
        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
    ) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
                "residual": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name: