run_parallel.sh
#!/bin/bash
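# Launch training runs for one of [sac, pebble, surf, rune, mrn, rime] over a set of
# random seeds in parallel: each seed is started as a background process on the GPU
# assigned to it in the gpu_ids map below, and its output is redirected to a .log file
# in the current directory.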
run_command() {
    # parameters of the simulated teachers
    eps_mistake=0.3
    eps_skip=0.0
    eps_equal=0.0
    teacher_gamma=1.0
    # which algorithm to run, one of [sac, pebble, surf, rune, mrn, rime]
    algorithm='rime'
    # per-algorithm parameters that change with the environment
    # RIME
    unsup_steps=9000 # 2000 for cheetah_run, otherwise 9000
    # MRN
    meta_steps=1000
    # RUNE
    rho=0.001
    # SURF
    tau=0.99
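    # environment-specific settings: keep exactly one of the envname blocks below uncommented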
    envname="walker_walk"
    sac_lr=0.0005
    num_interact=20000
    feedback=1000
    reward_batch=100
    # envname="cheetah_run"
    # sac_lr=0.0005
    # num_interact=20000
    # feedback=1000
    # reward_batch=100
    # envname="quadruped_walk"
    # sac_lr=0.0001
    # num_interact=30000
    # feedback=4000
    # reward_batch=400
    # envname="metaworld_button-press-v2"
    # sac_lr=0.0003
    # num_interact=5000
    # feedback=20000
    # reward_batch=100
    # envname="metaworld_sweep-into-v2"
    # sac_lr=0.0003
    # num_interact=5000
    # feedback=20000
    # reward_batch=100
    # envname="metaworld_hammer-v2"
    # sac_lr=0.0003
    # num_interact=5000
    # feedback=80000
    # reward_batch=400
    seed=$1
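    # pick the GPU assigned to this seed (gpu_ids is declared below, before run_command is invoked)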
    device="cuda:${gpu_ids[$seed]}"
    # SAC
    if [ "${algorithm,,}" == "sac" ]; then
        case "$envname" in
            *metaworld*)
                python train_SAC.py --device=$device --env="$envname" --seed=$seed --actor_lr=$sac_lr --critic_lr=$sac_lr --steps=1000000 --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 > "./SAC_env_"$envname"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_SAC.py --device=$device --env="$envname" --seed=$seed --actor_lr=$sac_lr --critic_lr=$sac_lr --steps=1000000 > "./SAC_env_"$envname"_seed_"$seed".log" 2>&1
                ;;
        esac
    # MRN
    elif [ "${algorithm,,}" == "mrn" ]; then
        case "$envname" in
            *metaworld*)
                python train_MRN.py --env="$envname" --seed=$seed --actor_lr=$sac_lr --critic_lr=$sac_lr --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 --unsup_steps=9000 --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch=$reward_batch --reward_update=10 --feed_type=1 --meta_steps=$meta_steps --device=$device --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./MRN_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_MRN.py --env="$envname" --seed=$seed --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch=$reward_batch --reward_update=50 --feed_type=1 --meta_steps=$meta_steps --device=$device --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./MRN_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
        esac
    # PEBBLE
    elif [ "${algorithm,,}" == "pebble" ]; then
        case "$envname" in
            *metaworld*)
                python train_PEBBLE.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=10 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./PEBBLE_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_PEBBLE.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=50 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./PEBBLE_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
        esac
    # RIME
    elif [ "${algorithm,,}" == "rime" ]; then
        case "$envname" in
            *metaworld*)
                python train_RIME.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=10 --feed_type=1 --eps_mistake="$eps_mistake" --least_reward_update=5 --threshold_variance='kl' --threshold_alpha=0.5 --threshold_beta_init=3.0 --threshold_beta_min=1.0 --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./RIME_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_RIME.py --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=$unsup_steps --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch=$reward_batch --reward_update=50 --feed_type=1 --device="$device" --eps_mistake="$eps_mistake" --least_reward_update=15 --threshold_variance='kl' --threshold_alpha=0.5 --threshold_beta_init=3.0 --threshold_beta_min=1.0 --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" > "./RIME_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
        esac
    # SURF
    elif [ "${algorithm,,}" == "surf" ]; then
        case "$envname" in
            *metaworld*)
                python train_SURF.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=20 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" --inv_label_ratio=10 --threshold_u=$tau --mu=4 > "./SURF_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_SURF.py --device=$device --env="$envname" --seed=$seed --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch=$reward_batch --inv_label_ratio=100 --reward_update=1000 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" --threshold_u=$tau --mu=4 > "./SURF_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
        esac
    # RUNE
    else
        case "$envname" in
            *metaworld*)
                python train_RUNE.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --batch_size=512 --critic_hidden_dim=256 --critic_hidden_depth=3 --actor_hidden_dim=256 --actor_hidden_depth=3 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=10 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" --rho=$rho > "./RUNE_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
            *)
                python train_RUNE.py --device=$device --env="$envname" --seed="$seed" --actor_lr=$sac_lr --critic_lr=$sac_lr --unsup_steps=9000 --steps=1000000 --num_interact=$num_interact --max_feedback="$feedback" --reward_batch="$reward_batch" --reward_update=50 --feed_type=1 --eps_mistake="$eps_mistake" --eps_skip="$eps_skip" --eps_equal="$eps_equal" --teacher_gamma="$teacher_gamma" --rho=$rho > "./RUNE_env_"$envname"_mistake_"$eps_mistake"_seed_"$seed".log" 2>&1
                ;;
        esac
    fi
}
# Determine which GPU each seed is assigned to
declare -A gpu_ids=(
    [12345]=0
    [23451]=1
    [34512]=2
    [45123]=3
    [51234]=4
    [67890]=5
    [68790]=6
    [78906]=7
    [89067]=0
    [90678]=1
)
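# one background training run per seed; wait until all of them have finished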
seeds=(12345 23451 34512 45123 51234 67890 68790 78906 89067 90678)
for seed in "${seeds[@]}"; do
    run_command "$seed" &
done
wait