-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge-weights.sh
executable file
·75 lines (67 loc) · 2.12 KB
/
merge-weights.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env bash
#
# merge-weights.sh -- seed a new run directory from an existing checkpoint and
# launch train.py on it with the flags listed at the bottom of this file.
#
# Usage:
#   merge-weights.sh <run-dir>/<checkpoint-file>
#
# The name of the checkpoint's parent directory is split on '-'; field 1 must
# be the MoE type (gs|st) and field 2 the model size (base|large), e.g.
#   models/moe-gs-base/checkpoint_best.pt
#
# Outputs (relative to the current working directory):
#   models/merged-<type>-<size>/checkpoint_last.pt  copy of the input checkpoint
#   models/merged-<type>-<size>/logs/               copy of this script + train log
#
# Exit status:
#   128  missing argument or unrecognised <type>/<size> (value kept from the
#        original script so existing callers see the same code)
#   1    the checkpoint file does not exist
#   else the status of the first failing command; pipefail makes a train.py
#        failure visible even though its output is piped through tee.

set -euo pipefail

# Print the expected invocation on stderr and abort.
usage() {
  printf 'Usage: %s <path>/<prefix>-{gs,st}-{base,large}/<checkpoint-file>\n' \
    "${0##*/}" >&2
  exit 128
}

# ---------------------------------------------------------------------------
# Fixed training hyper-parameters
# ---------------------------------------------------------------------------
readonly TOTAL_UPDATES=20000
readonly PEAK_LR=0.00005
readonly EXPERT_DROPOUT=0.25
readonly SAVE_INTERVAL=1
readonly SEED=77
readonly DATA_WORKERS=0
readonly DATA_DIR=data/fairseq-aux-bin

# ---------------------------------------------------------------------------
# Derive MoE type and model size from the checkpoint location
# ---------------------------------------------------------------------------
RAW_INPUT=${1:-}
[[ -n "$RAW_INPUT" ]] || usage

# Drop every trailing '/' so dirname/basename see the checkpoint itself.
CKPT_PATH=$RAW_INPUT
while [[ "$CKPT_PATH" == */ ]]; do
  CKPT_PATH=${CKPT_PATH%/}
done

FOLDER_NAME=$(basename -- "$(dirname -- "$CKPT_PATH")")

# Split on '-' (runs of dashes collapse, as before) without glob expansion.
read -r -a COMPS <<< "${FOLDER_NAME//-/ }"
moe_type=${COMPS[1]:-}
size=${COMPS[2]:-}

case "$moe_type" in
  gs)
    v_moe_type="aux_gshard"
    top_k=2
    ;;
  st)
    v_moe_type="aux_switch2"
    top_k=1
    ;;
  *)
    usage
    ;;
esac

# MAX_TOKENS * UPDATE_FREQ is ~2.1M for both sizes (21845*96 vs 8192*256).
# NOTE(review): presumably this keeps the effective batch constant -- confirm.
case "$size" in
  base)
    MAX_TOKENS=21845
    UPDATE_FREQ=96
    ;;
  large)
    MAX_TOKENS=8192
    UPDATE_FREQ=256
    ;;
  *)
    usage
    ;;
esac

if [[ ! -f "$CKPT_PATH" ]]; then
  printf 'error: checkpoint not found: %s\n' "$CKPT_PATH" >&2
  exit 1
fi

# ---------------------------------------------------------------------------
# Prepare the output directory
# ---------------------------------------------------------------------------
SAVE_PATH="models/merged-${moe_type}-${size}"
MOE_ARCH="t5-moe-v1.1-${size}"

mkdir -p -- "$SAVE_PATH/logs"

# NOTE(review): the copy is named checkpoint_last.pt and --save-dir points at
# the same directory; presumably train.py resumes from it -- confirm.
cp -- "$CKPT_PATH" "$SAVE_PATH/checkpoint_last.pt"

# Keep a copy of this script next to the logs. Use the path the script was
# invoked with ($0) so this also works when started from another directory.
cp -- "$0" "$SAVE_PATH/logs/"

# %T expands to HH:MM:SS, so the log file name contains colons.
timestamp=$(date "+%Y%m%d_%T")

# ---------------------------------------------------------------------------
# Launch training
# ---------------------------------------------------------------------------
train_cmd=(
  python train.py "$DATA_DIR"
  --seed "$SEED"
  --num-workers "$DATA_WORKERS"
  --max-epoch 3 --freeze-non-MoE --share-expert-gelu --merge-backbone
  --moe-freq 1 --moe-type "$v_moe_type" --aux-weight 0.1
  --gate-logits --gate-hidden-dims "[384]" --gate-class-dim 28
  --moe-location decoder --expert-dropout "$EXPERT_DROPOUT"
  --num-experts 7 --share-gate --gate-top-n "$top_k" --gate-capacity "[4,6]"
  --skip-invalid-size-inputs-valid-test
  --task aux_translation -s ori -t cor -x edit
  --criterion label_smoothed_cross_entropy_for_moe --moe-weight 1
  --arch "$MOE_ARCH"
  --reset-optimizer --reset-lr-scheduler --reset-dataloader --reset-meters
  --max-source-positions 128
  --max-target-positions 128
  --optimizer adafactor
  --lr "$PEAK_LR"
  --update-freq "$UPDATE_FREQ"
  --max-tokens "$MAX_TOKENS"
  --save-dir "$SAVE_PATH"
  --disable-validation
  --max-tokens-valid "$MAX_TOKENS"
  --save-interval-updates "$SAVE_INTERVAL"
  --max-update "$TOTAL_UPDATES" --log-format simple --log-interval 1
)

# Mirror stdout+stderr to the console and to a timestamped log file.
"${train_cmd[@]}" 2>&1 | tee "$SAVE_PATH/logs/train-${timestamp}.log"