-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain-procgen.py
110 lines (93 loc) · 3.39 KB
/
train-procgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import tensorflow as tf
from baselines.ppo2 import ppo2
from baselines.common.models import build_impala_cnn
from baselines.common.mpi_util import setup_mpi_gpus
from procgen import ProcgenEnv
from baselines.common.vec_env import (
VecExtractDictObs,
VecMonitor,
VecFrameStack,
VecNormalize
)
from baselines import logger
from mpi4py import MPI
import argparse
def train_fn(env_name, num_envs, distribution_mode, num_levels, start_level, timesteps_per_proc, is_test_worker=False, log_dir='/tmp/procgen', comm=None):
    """Train a PPO2 agent with an IMPALA-style CNN on a Procgen environment.

    PPO hyperparameters are fixed inside this function. Only the environment
    configuration, the training budget, and the MPI/logging setup come from
    the caller.

    Args:
        env_name: Procgen game name passed to ``ProcgenEnv`` (e.g. ``'coinrun'``).
        num_envs: Number of environments in this process's vectorized env.
        distribution_mode: Procgen distribution mode passed to ``ProcgenEnv``.
        num_levels: Number of distinct levels to train on. Ignored (forced to 0)
            when ``is_test_worker`` is True.
        start_level: First level seed passed to ``ProcgenEnv``.
        timesteps_per_proc: Environment-step budget for this process,
            forwarded to ``ppo2.learn`` as ``total_timesteps``.
        is_test_worker: If True, this rank is given ``mpi_rank_weight=0`` and
            ``num_levels=0`` instead of training on the requested level set.
        log_dir: Directory for ``logger.configure``. If None, the logger is
            not configured here.
        comm: MPI communicator shared by all ranks. Defaults to
            ``MPI.COMM_WORLD`` when None.
    """
    if comm is None:
        # comm.Split below (and ppo2.learn) need a real communicator.
        # Without this fallback, the default comm=None crashes with
        # AttributeError whenever log_dir is set.
        comm = MPI.COMM_WORLD

    # Fixed PPO hyperparameters.
    learning_rate = 5e-4
    ent_coef = .01
    gamma = .999
    lam = .95
    nsteps = 256
    nminibatches = 8
    ppo_epochs = 3
    clip_range = .2
    use_vf_clipping = True

    # A test worker gets weight 0 in ppo2.learn. Presumably this excludes its
    # gradients from the shared update; confirm against baselines' MPI
    # optimizer. It also runs with num_levels=0, which Procgen documents as
    # "unlimited levels", i.e. the full level distribution.
    mpi_rank_weight = 0 if is_test_worker else 1
    num_levels = 0 if is_test_worker else num_levels

    if log_dir is not None:
        # Split ranks into two sub-communicators (color 0 = training,
        # color 1 = test). Within each group, only its rank 0 emits
        # csv/stdout output.
        log_comm = comm.Split(1 if is_test_worker else 0, 0)
        format_strs = ['csv', 'stdout'] if log_comm.Get_rank() == 0 else []
        logger.configure(comm=log_comm, dir=log_dir, format_strs=format_strs)

    logger.info("creating environment")
    venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels, start_level=start_level, distribution_mode=distribution_mode)
    # Keep only the "rgb" entry of the dict observation.
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    # ob=False: observations are not normalized; other VecNormalize defaults apply.
    venv = VecNormalize(venv=venv, ob=False)

    try:
        logger.info("creating tf session")
        setup_mpi_gpus()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # pylint: disable=E1101

        def conv_fn(x):
            """Build the IMPALA CNN feature extractor used as the PPO network."""
            return build_impala_cnn(x, depths=[16, 32, 32], emb_size=256)

        # Entering the session makes it the default session while ppo2.learn
        # runs. Leaving the block closes it, so it does not stay open for the
        # rest of the process.
        with tf.Session(config=config):
            logger.info("training")
            ppo2.learn(
                env=venv,
                network=conv_fn,
                total_timesteps=timesteps_per_proc,
                save_interval=0,
                nsteps=nsteps,
                nminibatches=nminibatches,
                lam=lam,
                gamma=gamma,
                noptepochs=ppo_epochs,
                log_interval=1,
                ent_coef=ent_coef,
                mpi_rank_weight=mpi_rank_weight,
                clip_vf=use_vf_clipping,
                comm=comm,
                lr=learning_rate,
                cliprange=clip_range,
                update_fn=None,
                init_fn=None,
                vf_coef=0.5,
                max_grad_norm=0.5,
            )
    finally:
        # Release the environments even if session setup or training fails.
        venv.close()
def main(argv=None):
    """Parse command-line options and run ``train_fn`` on this MPI rank.

    Every rank executes this function. With ``--test_worker_interval N`` where
    N > 0, each rank whose index satisfies ``rank % N == N - 1`` runs as a test
    worker. Other ranks, and all ranks when N <= 0, train normally.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
            None (the default) keeps normal command-line behavior.
    """
    parser = argparse.ArgumentParser(description='Process procgen training arguments.')
    parser.add_argument('--env_name', type=str, default='coinrun')
    parser.add_argument('--num_envs', type=int, default=64)
    parser.add_argument('--distribution_mode', type=str, default='hard', choices=["easy", "hard", "exploration", "memory", "extreme"])
    parser.add_argument('--num_levels', type=int, default=0)
    parser.add_argument('--start_level', type=int, default=0)
    parser.add_argument('--test_worker_interval', type=int, default=0)
    parser.add_argument('--timesteps_per_proc', type=int, default=50_000_000)
    # Same default as train_fn's log_dir, so omitting the flag keeps the
    # previous behavior.
    parser.add_argument('--log_dir', type=str, default='/tmp/procgen',
                        help='directory for logger output')
    args = parser.parse_args(argv)

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    is_test_worker = False
    test_worker_interval = args.test_worker_interval
    if test_worker_interval > 0:
        is_test_worker = rank % test_worker_interval == (test_worker_interval - 1)

    train_fn(args.env_name,
             args.num_envs,
             args.distribution_mode,
             args.num_levels,
             args.start_level,
             args.timesteps_per_proc,
             is_test_worker=is_test_worker,
             log_dir=args.log_dir,
             comm=comm)
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()