Commit d29eb87
add support for converting checkpoints from HF to MDS
billishyahao committed Jul 28, 2024
1 parent fc989b8 commit d29eb87
Showing 4 changed files with 315 additions and 31 deletions.
examples_deepspeed/finetune_hf_llama/ds_config.json (8 changes: 1 addition & 7 deletions)
@@ -1,11 +1,5 @@
 {
   "train_batch_size" : 256,
   "train_micro_batch_size_per_gpu": 16,
-  "steps_per_print": 100,
-  "zero_optimization": {
-    "stage": 0
-  },
-  "bf16": {
-    "enabled": true
-  }
+  "steps_per_print": 1
 }
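The checked-in config thus shrinks to a minimal stub; presumably this is fine because finetune_llama.sh overwrites $DS_CONFIG at runtime through its own heredoc (visible as context in the next file). For reference, the file after this change, reproduced as a heredoc in the script's own style:

# Sketch: regenerate the post-commit ds_config.json (contents taken from the
# diff above; the script later rewrites this file with a fuller config anyway).
cat <<EOT > examples_deepspeed/finetune_hf_llama/ds_config.json
{
  "train_batch_size" : 256,
  "train_micro_batch_size_per_gpu": 16,
  "steps_per_print": 1
}
EOT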
examples_deepspeed/finetune_hf_llama/finetune_llama.sh (23 changes: 17 additions & 6 deletions)
@@ -1,8 +1,8 @@
 DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json
-DATASET_PATH=./alpaca_data.json
+DATASET_PATH=./examples_deepspeed/finetune_hf_llama/alpaca_data.json
 # dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json
 
-HF_LLAMA_PATH=/data/llama-7b/
+HF_LLAMA_PATH=/yahao/llama-2-7b-hf/
 # weights link: https://huggingface.co/huggyllama/llama-7b
 
 MICRO_BATCH_SIZE=16
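The dataset now lives under the example directory rather than the repo root. A sketch for fetching it into the new expected path (the raw-file URL is derived from the blob link in the comment above):

# Download the Alpaca dataset to where DATASET_PATH now points.
wget -O examples_deepspeed/finetune_hf_llama/alpaca_data.json \
    https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json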
@@ -44,11 +44,20 @@ cat <<EOT > $DS_CONFIG
 EOT
 
 
-covert_args="deepspeed tools/hf2megads_weight_converter.py \
+covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \
 --hf-ckpt-num-shards 2 \
---origin-hf-ckpt-dir $HF_LLAMA_PATH \
+--hf-ckpt-dir $HF_LLAMA_PATH \
+--load-mode auto \
 --save $MEGA_DS_LLAMA_PATH"
 
+covert_mds2hf_args="deepspeed tools/hf2megads_weight_converter.py \
+--hf-ckpt-num-shards 2 \
+--hf-ckpt-dir $HF_LLAMA_PATH \
+--load-mode auto \
+--to-hf-ckpt \
+--load $MEGA_DS_LLAMA_PATH \
+--save $HF_LLAMA_PATH'-hf-out' "
+
 finetune_args="deepspeed finetune_llama.py \
 --load $MEGA_DS_LLAMA_PATH"
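The converter's --origin-hf-ckpt-dir flag is renamed to --hf-ckpt-dir, a --load-mode auto flag is added, and the new --to-hf-ckpt flag selects the reverse (MDS-to-HF) direction. Expanded with the variables above, the two invocations look roughly like this (a sketch; $MEGA_DS_LLAMA_PATH is defined earlier in the script, outside this diff):

# HF -> MDS: import the 2-shard HF checkpoint and save it in
# Megatron-DeepSpeed format.
deepspeed tools/hf2megads_weight_converter.py \
    --hf-ckpt-num-shards 2 \
    --hf-ckpt-dir "$HF_LLAMA_PATH" \
    --load-mode auto \
    --save "$MEGA_DS_LLAMA_PATH"

# MDS -> HF: --to-hf-ckpt reverses the direction; the result is written
# under the HF path with an '-hf-out' suffix appended.
deepspeed tools/hf2megads_weight_converter.py \
    --hf-ckpt-num-shards 2 \
    --hf-ckpt-dir "$HF_LLAMA_PATH" \
    --load-mode auto \
    --to-hf-ckpt \
    --load "$MEGA_DS_LLAMA_PATH" \
    --save "${HF_LLAMA_PATH}-hf-out"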

@@ -98,8 +107,10 @@ comm_args="--tensor-model-parallel-size $TP \
 --no-gradient-accumulation-fusion \
 --repeated-dataloader"
 
-if [ "$1" = "convert" ]; then
-task_args="$covert_args"
+if [ "$1" = "convert_hf2mds" ]; then
+task_args="$covert_hf2mds_args"
+elif [ "$1" = "convert_mds2hf" ]; then
+task_args="$covert_mds2hf_args"
 else
 task_args="$finetune_args"
 fi
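With this dispatch in place, each mode of the example script is selected by its first argument. A minimal usage sketch, assuming the default paths above and a repository-root working directory:

# HF checkpoint -> Megatron-DeepSpeed checkpoint
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds

# Megatron-DeepSpeed checkpoint -> HF checkpoint
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_mds2hf

# Any other (or no) first argument falls through to finetuning
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh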
megatron/global_vars.py (1 change: 1 addition & 0 deletions)
@@ -175,6 +175,7 @@ def _set_wandb_writer(args):
               'project or experiment name provided, '
               'therefore WANDB logs will be written '
               'according to random generated project or experiment name.', flush=True)
+        return
 
     try:
         import wandb

The added return exits _set_wandb_writer right after the warning, so when no wandb project or experiment name is provided the wandb import and writer setup below are skipped entirely.
[Diff for the fourth changed file — likely tools/hf2megads_weight_converter.py, which carries the bulk of the 315 additions — did not load.]
