Commit

Fix PP issue
billishyahao committed Aug 7, 2024
1 parent d29eb87 commit 738e3e1
Showing 2 changed files with 109 additions and 170 deletions.
2 changes: 1 addition & 1 deletion examples_deepspeed/finetune_hf_llama/finetune_llama.sh
@@ -2,7 +2,7 @@ DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json
DATASET_PATH=./examples_deepspeed/finetune_hf_llama/alpaca_data.json
# dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json

HF_LLAMA_PATH=/yahao/llama-2-7b-hf/
HF_LLAMA_PATH=/data/llama-2-7b-hf/
# weights link: https://huggingface.co/huggyllama/llama-7b

MICRO_BATCH_SIZE=16
277 changes: 108 additions & 169 deletions tools/hf2megads_weight_converter.py
@@ -3,6 +3,7 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import print_rank_0, get_tokenizer, get_args
from megatron.core import mpu
@@ -18,6 +19,9 @@
from megatron.training import get_optimizer_param_scheduler
from deepspeed.runtime.utils import see_memory_usage
import deepspeed
import copy
from pathlib import Path



def add_extra_args(parser):
@@ -126,6 +130,7 @@ def __init__(self, ds_model, hf_model, args, config):
# align layer number
self.ds_model = ds_model
self.hf_model = hf_model
self.hf_dict = {} # for handling pp case when converting mds => hf
self.config = config

self.offset_num = 2
@@ -165,101 +170,7 @@ def _embedding_refactor(self, pname, p):
)
return new_w

def _embedding_refactor_to_hf(self, pname, p):

if pname == f"{self.mega_lm_head_wnum}.lm_head.weight":
hf_w = self.hf_model.lm_head.weight
elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight":
hf_w = self.hf_model.model.embed_tokens.weight

ds_w = p
with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

hf_w.data.copy_(ds_w_all_rank[:hf_w.shape[0], :])

def _direct_refactor_to_hf(self, pname, p, hf_layer=None, subname=None):
ds_w = p
if pname in [f"{self.mega_norm_wnum}.weight"]:
hf_w = self.hf_model.model.norm.weight
elif subname in ["input_layernorm.weight"]:
hf_w = self.hf_model.model.layers[hf_layer].input_layernorm.weight
elif subname in ["post_attention_layernorm.weight"]:
hf_w = self.hf_model.model.layers[hf_layer].post_attention_layernorm.weight

hf_w.data.copy_(ds_w)

def _attn_dense_refactor_to_hf(self, pname, p, hf_layer, subname):
ds_w = p

if subname == "self_attention.dense.weight":
hf_w = self.hf_model.model.layers[hf_layer].self_attn.o_proj.weight
elif subname == "mlp.dense_4h_to_h.weight":
hf_w = self.hf_model.model.layers[hf_layer].mlp.down_proj.weight

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_last_dim(ds_w)

hf_w.data.copy_(ds_w_all_rank)

def _mlphto4h_dense_refactor_to_hf(self, pname, p, hf_layer):
ds_w = p
hf_g = self.hf_model.model.layers[hf_layer].mlp.gate_proj.weight
hf_u = self.hf_model.model.layers[hf_layer].mlp.up_proj.weight

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

if torch.distributed.get_rank() == 0:
print(f"yahao-dbg: hf_g: {hf_g.shape}, ds_w_all_rank: {ds_w_all_rank.shape}")

ds_w_shape = ds_w_all_rank.shape

ds_w_all_rank = ds_w_all_rank.reshape(self.tp_size, 2, -1, ds_w_shape[-1])

hf_g.data.copy_(ds_w_all_rank[:, 0, :, :].reshape(-1, ds_w_shape[-1]))
hf_u.data.copy_(ds_w_all_rank[:, 1, :, :].reshape(-1, ds_w_shape[-1]))


def _qkv_refactor_to_hf(self, pname, p, hf_layer):

ds_w = p

#TODO(yahao): Need to check function if we are using only 1 GPU.

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

# [np/8, 3, 128, 6144]

# [np, 3, 128, 6144]

# hf_w = self.hf_model.model.layers[hf_layer].attention.query_key_value.weight
hf_q = self.hf_model.model.layers[hf_layer].self_attn.q_proj.weight
hf_k = self.hf_model.model.layers[hf_layer].self_attn.k_proj.weight
hf_v = self.hf_model.model.layers[hf_layer].self_attn.v_proj.weight

oldshape = hf_q.shape

hidden_size = oldshape[-1]

hidden_size_per_attention_head = divide(hidden_size,
self.config.num_attention_heads)

num_attention_heads_per_partition = divide(self.config.num_attention_heads,
self.tp_size)

newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)


if torch.distributed.get_rank() == 0:
print(f"yahao-dbg: qkv refactor : oldshape: {oldshape}, newshape: {newshape}")

ds_w_out = ds_w_all_rank.reshape(*newshape)

hf_q.data.copy_(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
hf_k.data.copy_(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
hf_v.data.copy_(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))


def _direct_refactor(self, pname, p, hf_layer=None, subname=None):
@@ -273,6 +184,7 @@ def _direct_refactor(self, pname, p, hf_layer=None, subname=None):
f"mega-ds:{pname,p.data.shape}<--hf{hf_name,} {hf_w.shape}")
return new_w


def _qkv_refactor(self, pname, p, hf_layer):
hf_wq_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight"
hf_wk_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight"
@@ -362,12 +274,10 @@ def _mlphto4h1_refactor(self, pname, p, hf_layer, subname):
)
return new_w

def refactor(self):
def transform_from_hf_to_megds(self):
assert self.is_refactored == False
new_w = None
for pname, p in self.ds_model.named_parameters():
if torch.distributed.get_rank() == 0:
print(f"yahao-dbg: pname: {pname}, p.shape: {p.shape}")

if pname in [
f"{self.mega_emb_wnum}.word_embeddings.weight",
@@ -406,14 +316,86 @@ def refactor(self):
new_w = None
self.is_refactored = True


def _embedding_refactor_to_hf(self, pname, ds_w):
if pname == f"{self.mega_lm_head_wnum}.lm_head.weight":
hf_w = self.hf_model.lm_head.weight
hf_w_name = "lm_head.weight"
elif pname == f"{self.mega_emb_wnum}.word_embeddings.weight":
hf_w = self.hf_model.model.embed_tokens.weight
hf_w_name = "model.embed_tokens.weight"

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank[:hf_w.shape[0], :])
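
# --- Illustrative sketch (not part of hf2megads_weight_converter.py) ---
# Single-process toy version of the embedding export above: Megatron pads the
# vocabulary so it divides evenly across tensor-parallel ranks, so after the
# shards are gathered along the first dim, the padding rows beyond the true HF
# vocab size are dropped. The padding formula below is a simplification.
import torch

tp_size, hf_vocab, hidden = 2, 5, 3                               # toy sizes
padded_vocab = ((hf_vocab + tp_size - 1) // tp_size) * tp_size    # 6 rows after padding

full = torch.randn(padded_vocab, hidden)
shards = full.chunk(tp_size, dim=0)          # what each TP rank would hold
gathered = torch.cat(shards, dim=0)          # stands in for _gather_along_first_dim
exported = gathered[:hf_vocab, :]            # truncate padding, as in the copy above
assert torch.equal(exported, full[:hf_vocab])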

def _direct_refactor_to_hf(self, pname, ds_w, hf_layer=None, subname=None):
if pname in [f"{self.mega_norm_wnum}.weight"]:
hf_w = self.hf_model.model.norm.weight
hf_w_name = "model.norm.weight"
elif subname in ["input_layernorm.weight"]:
hf_w = self.hf_model.model.layers[hf_layer].input_layernorm.weight
hf_w_name = f"model.layers.{hf_layer}.input_layernorm.weight"
elif subname in ["post_attention_layernorm.weight"]:
hf_w = self.hf_model.model.layers[hf_layer].post_attention_layernorm.weight
hf_w_name = f"model.layers.{hf_layer}.post_attention_layernorm.weight"

self.hf_dict[hf_w_name] = copy.deepcopy(ds_w)

def _attn_dense_refactor_to_hf(self, pname, ds_w, hf_layer, subname):
if subname == "self_attention.dense.weight":
hf_w = self.hf_model.model.layers[hf_layer].self_attn.o_proj.weight
hf_w_name = f"model.layers.{hf_layer}.self_attn.o_proj.weight"
elif subname == "mlp.dense_4h_to_h.weight":
hf_w = self.hf_model.model.layers[hf_layer].mlp.down_proj.weight
hf_w_name = f"model.layers.{hf_layer}.mlp.down_proj.weight"

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_last_dim(ds_w)

self.hf_dict[hf_w_name] = copy.deepcopy(ds_w_all_rank)

def _mlphto4h_dense_refactor_to_hf(self, pname, ds_w, hf_layer):
hf_g_name = f"model.layers.{hf_layer}.mlp.gate_proj.weight"
hf_u_name = f"model.layers.{hf_layer}.mlp.up_proj.weight"

with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

ds_w_shape = ds_w_all_rank.shape
ds_w_all_rank = ds_w_all_rank.reshape(self.tp_size, 2, -1, ds_w_shape[-1])
self.hf_dict[hf_g_name] = copy.deepcopy(ds_w_all_rank[:, 0, :, :].reshape(-1, ds_w_shape[-1]))
self.hf_dict[hf_u_name] = copy.deepcopy(ds_w_all_rank[:, 1, :, :].reshape(-1, ds_w_shape[-1]))
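
# --- Illustrative sketch (not part of hf2megads_weight_converter.py) ---
# Toy, single-process view of the split above: each TP rank stores the fused
# dense_h_to_4h weight as [gate_shard; up_shard], so after gathering the rank
# shards, a reshape to (tp_size, 2, -1, hidden) separates gate (index 0) from
# up (index 1) and restores the HF gate_proj / up_proj layout.
import torch

tp_size, ffn_hidden, hidden = 2, 8, 4        # toy sizes
per_rank = ffn_hidden // tp_size

gate = torch.randn(ffn_hidden, hidden)
up = torch.randn(ffn_hidden, hidden)
rank_shards = [torch.cat([gate[r * per_rank:(r + 1) * per_rank],
                          up[r * per_rank:(r + 1) * per_rank]], dim=0)
               for r in range(tp_size)]

gathered = torch.cat(rank_shards, dim=0)     # stands in for _gather_along_first_dim
split = gathered.reshape(tp_size, 2, -1, hidden)
assert torch.equal(split[:, 0].reshape(-1, hidden), gate)
assert torch.equal(split[:, 1].reshape(-1, hidden), up)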


def _qkv_refactor_to_hf(self, pname, ds_w, hf_layer):
with torch.no_grad():
ds_w_all_rank = tensor_parallel.mappings._gather_along_first_dim(ds_w)

hf_q = self.hf_model.model.layers[hf_layer].self_attn.q_proj.weight
hf_k = self.hf_model.model.layers[hf_layer].self_attn.k_proj.weight
hf_v = self.hf_model.model.layers[hf_layer].self_attn.v_proj.weight
hf_q_name = f"model.layers.{hf_layer}.self_attn.q_proj.weight"
hf_k_name = f"model.layers.{hf_layer}.self_attn.k_proj.weight"
hf_v_name = f"model.layers.{hf_layer}.self_attn.v_proj.weight"
oldshape = hf_q.shape
hidden_size = oldshape[-1]
hidden_size_per_attention_head = divide(hidden_size,
self.config.num_attention_heads)
num_attention_heads_per_partition = divide(self.config.num_attention_heads,
self.tp_size)
newshape = (self.tp_size, num_attention_heads_per_partition, 3, hidden_size_per_attention_head, hidden_size)
ds_w_out = ds_w_all_rank.reshape(*newshape)
self.hf_dict[hf_q_name] = copy.deepcopy(ds_w_out[:, :, 0, :, :].reshape(-1, oldshape[-1]))
self.hf_dict[hf_k_name] = copy.deepcopy(ds_w_out[:, :, 1, :, :].reshape(-1, oldshape[-1]))
self.hf_dict[hf_v_name] = copy.deepcopy(ds_w_out[:, :, 2, :, :].reshape(-1, oldshape[-1]))
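
# --- Illustrative sketch (not part of hf2megads_weight_converter.py) ---
# Toy, single-process view of the QKV split above, assuming the per-head
# (q, k, v) interleaving implied by the reshape: gathering the TP shards and
# reshaping to (tp_size, heads_per_partition, 3, head_dim, hidden) exposes
# q / k / v at indices 0 / 1 / 2 of the third axis.
import torch

tp_size, num_heads, head_dim = 2, 4, 3       # toy sizes
hidden = num_heads * head_dim
heads_per_rank = num_heads // tp_size

q = torch.randn(num_heads * head_dim, hidden)
k = torch.randn(num_heads * head_dim, hidden)
v = torch.randn(num_heads * head_dim, hidden)

def rank_shard(r):
    # per head, stack q, k, v rows consecutively within the rank
    rows = []
    for h in range(r * heads_per_rank, (r + 1) * heads_per_rank):
        sl = slice(h * head_dim, (h + 1) * head_dim)
        rows += [q[sl], k[sl], v[sl]]
    return torch.cat(rows, dim=0)

gathered = torch.cat([rank_shard(r) for r in range(tp_size)], dim=0)
out = gathered.reshape(tp_size, heads_per_rank, 3, head_dim, hidden)
assert torch.equal(out[:, :, 0].reshape(-1, hidden), q)
assert torch.equal(out[:, :, 1].reshape(-1, hidden), k)
assert torch.equal(out[:, :, 2].reshape(-1, hidden), v)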


def transform_from_megads_to_hf(self):
use_gqa = True if self.num_attention_heads != self.num_key_value_heads else False

for pname, p in self.ds_model.named_parameters():

if torch.distributed.get_rank() == 0:
print(f"yahao-dbg: ds_model pname: {pname} , p.shape: {p.shape}")

if pname in [
f"{self.mega_emb_wnum}.word_embeddings.weight",
f"{self.mega_lm_head_wnum}.lm_head.weight",
@@ -432,7 +414,7 @@ def transform_from_megads_to_hf(self):
if not use_gqa:
self._qkv_refactor_to_hf(pname, p, hf_layer)
else:
#TODO(yahao): Not impl yet ...
#TODO(billishyahao): Not impl yet ...
assert False
elif subname in ["mlp.dense_h_to_4h.weight"]:
self._mlphto4h_dense_refactor_to_hf(pname, p, hf_layer)
@@ -449,9 +431,6 @@ def transform_from_megads_to_hf(self):
else:
print(f"Unrecognized weight type: {pname}")
raise ValueError(f"Unrecognized weight type: {pname}")

# if torch.distributed.get_rank() == 0:
# print("yahao debug save last #2 stage: " + str(self.hf_model.state_dict()['lm.embed_in.weight']))
self.is_refactored = True

def record_mapping_info(self, record_msg):
@@ -473,61 +452,6 @@ def inorder_show_record(self):
torch.distributed.barrier()


def convert_hf_to_mega_ds():
"""Build the model."""
args = get_args()
print_rank_0(f'building model ...')
see_memory_usage(f"Before Building Model", force=True)

config = core_transformer_config_from_args(args)
with deepspeed.zero.Init(
data_parallel_group=mpu.get_data_parallel_group(),
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed and not args.no_pipeline_parallel:
model = GPTModelPipe(config, num_tokentypes=0, parallel_output=True)
else:
raise NotImplementedError("Not implemented")

see_memory_usage(f"After Building Model", force=True)
if torch.distributed.get_rank() < 2:
print(f"{torch.distributed.get_rank()} {model}")

# load and initialize HF weight dict
# print hf weights list & mega-ds weights list
hf_ckpt_dir = args.origin_hf_ckpt_dir
hf_ckpt_num_of_shards = args.hf_ckpt_num_shards
loaded = load_and_print_hf_weight(hf_ckpt_dir, hf_ckpt_num_of_shards)
print_distinct_weights(model)

# refactor weight from hf to mega-ds

cur_refactor = refactor(model, loaded, args, config)
cur_refactor.refactor()
cur_refactor.inorder_show_record()

del loaded

unwrapped_model = unwrap_model([model], (torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

#init model and save
print_rank_0(f"before deepspeed init")
ds_engine, _, _, _ = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
lr_scheduler=opt_param_scheduler,
mpu=mpu if args.no_pipeline_parallel else None)
print_rank_0(f"after deepspeed init")

print_rank_0(f"mega-ds checkpoint will be saved in {args.save}")
save_checkpoint(0, [ds_engine], optimizer, opt_param_scheduler)
print_rank_0(f"save checkpoint completed")

def load_hf_weights(args, no_init):
if args.load_mode == 'torchbin':
assert no_init == False, "only work with init"
@@ -538,6 +462,7 @@ def load_hf_weights(args, no_init):
elif args.load_mode == 'auto':
return load_and_print_hf_weight_auto(args.hf_ckpt_dir, no_init)


def convert_ckpt():
"""Build the model."""
args = get_args()
@@ -560,12 +485,7 @@ def convert_ckpt():
if torch.distributed.get_rank() < 2:
print(f"{torch.distributed.get_rank()} {ds_model}")

# load and initialize HF weight dict
# print hf weights list & mega-ds weights list


# 'torchbin', 'safetensor', 'auto'

hf_model = load_hf_weights(args, no_init=args.to_hf_ckpt)

# print_distinct_weights(hf_model)
@@ -582,24 +502,43 @@

if args.to_hf_ckpt:
load_checkpoint([ds_engine], None, None, load_only_weights=True)
print(f"Load deepspeed actual checkpoint completed")
print_rank_0(f"completed to load deepspeed actual checkpoint")

# refactor weight from hf to mega-ds and vice versa

cur_refactor = refactor(ds_model, hf_model, args, config)
if args.to_hf_ckpt:
cur_refactor.transform_from_megads_to_hf()
else:
cur_refactor.refactor()
cur_refactor.transform_from_hf_to_megds()
# cur_refactor.inorder_show_record()

if args.to_hf_ckpt:

save_path = "/yahao/Megatron-DeepSpeed/test-llama-7b-hf/"
if not os.path.exists(save_path):
Path(save_path).mkdir(parents=True, exist_ok=True)
ckpt_per_pp_path = os.path.join(save_path, f"model_pp{mpu.get_pipeline_model_parallel_rank()}.pt")
torch.save(cur_refactor.hf_dict, ckpt_per_pp_path)

if torch.distributed.is_initialized():
torch.distributed.barrier()

print_rank_0(f"hf checkpoint will be saved in {save_path}/release ")
if mpu.is_pipeline_last_stage():
## doing checkpoint merging and saving...
# hf_model.tie_weights()

all_wei = {}
for pprank in range(mpu.get_pipeline_model_parallel_world_size()):
ckpt_per_pp_path = os.path.join(save_path, f"model_pp{pprank}.pt")
partial_wei = torch.load(ckpt_per_pp_path)
all_wei = all_wei | partial_wei

hf_model.load_state_dict(all_wei)

# mega-ds checkpoint will be saved in args.save
print_rank_0(f"mega-ds checkpoint will be saved in /yahao/Megatron-DeepSpeed/test-llama-7b-hf ")
hf_model.save_pretrained("/yahao/Megatron-DeepSpeed/test-llama-7b-hf", safe_serialization=True)
hf_model.save_pretrained(os.path.join(save_path, "release"), safe_serialization=True)
else:
print_rank_0(f"mega-ds checkpoint will be saved in {args.save}")
save_checkpoint(0, [ds_engine], None, None)
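
# --- Illustrative sketch (not part of hf2megads_weight_converter.py) ---
# Minimal stand-alone version of the pipeline-parallel export pattern above:
# every PP rank saves only the HF weights it owns to model_pp{rank}.pt, and
# the last stage merges the partial files with a dict union (Python 3.9+)
# before calling save_pretrained. Paths and shard contents here are made up.
import os
import torch

save_path = "/tmp/hf_export"                 # assumed scratch directory
os.makedirs(save_path, exist_ok=True)

shards = {                                   # stand-ins for per-rank hf_dict
    0: {"model.embed_tokens.weight": torch.zeros(4, 2)},
    1: {"model.norm.weight": torch.ones(2), "lm_head.weight": torch.zeros(4, 2)},
}
for pp_rank, hf_dict in shards.items():
    torch.save(hf_dict, os.path.join(save_path, f"model_pp{pp_rank}.pt"))

all_wei = {}
for pp_rank in range(len(shards)):
    all_wei = all_wei | torch.load(os.path.join(save_path, f"model_pp{pp_rank}.pt"))
print(sorted(all_wei.keys()))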
