-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_expert.py
executable file
·80 lines (69 loc) · 2.5 KB
/
run_expert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
"""
Code to load an expert policy and generate roll-out data for behavioral cloning.
Example usage:
python run_expert.py experts/Humanoid-v1.pkl Humanoid-v1 --render \
--num_rollouts 20
Author of this script and included expert policies: Jonathan Ho (hoj@openai.com)
"""
import json
import pickle
import tensorflow as tf
import numpy as np
import tf_util
import gym
import load_policy
def _run_rollout(env, policy_fn, max_steps, render=False):
    """Run the expert policy in ``env`` for a single episode.

    Args:
        env: environment object providing ``reset()``, ``step(action)`` and
            ``render()``.
        policy_fn: callable that receives the current observation with a
            leading axis added (``obs[None, :]``) and returns an action.
        max_steps: the episode is cut off once this many steps were taken,
            even if the environment has not reported ``done``.
        render: when true, ``env.render()`` is called after every step.

    Returns:
        A tuple ``(observations, actions, total_reward)`` where the first two
        are lists with one entry per step and ``total_reward`` is the sum of
        the per-step rewards.
    """
    observations = []
    actions = []
    total_reward = 0.
    steps = 0

    obs = env.reset()
    done = False
    while not done:
        action = policy_fn(obs[None, :])
        # Record the observation the action was computed from, not the one
        # returned by the step below.
        observations.append(obs)
        actions.append(action)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1
        if render:
            env.render()
        if steps % 100 == 0:
            print("%i/%i" % (steps, max_steps))
        if steps >= max_steps:
            break
    return observations, actions, total_reward


def main():
    """Load an expert policy, roll it out and optionally save the data.

    Command-line arguments:
        expert_policy_file: path handed to ``load_policy.load_policy``.
        envname: environment id handed to ``gym.make``.
        --render: render the environment after every step.
        --max_timesteps: per-episode step limit; falls back to
            ``env.spec.timestep_limit`` when omitted (or 0).
        --num_rollouts: number of episodes to run (default 20).
        --save: pickle the collected rollouts to ``rollouts/<envname>.pkl``.

    The saved object is a list with one dict per rollout, holding the
    ``'observations'``, ``'actions'`` and ``'returns'`` arrays of that rollout.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('expert_policy_file', type=str)
    parser.add_argument('envname', type=str)
    parser.add_argument('--render', action='store_true')
    parser.add_argument("--max_timesteps", type=int)
    parser.add_argument('--num_rollouts', type=int, default=20,
                        help='Number of expert roll outs')
    parser.add_argument('--save', action='store_true',
                        help='Save the rollouts to file')
    args = parser.parse_args()

    print('loading and building expert policy')
    policy_fn = load_policy.load_policy(args.expert_policy_file)
    print('loaded and built')

    with tf.Session():
        tf_util.initialize()

        env = gym.make(args.envname)
        max_steps = args.max_timesteps or env.spec.timestep_limit

        expert_data = []
        # Returns of *all* rollouts, kept separately from the per-rollout
        # records so the summary statistics below cover every episode
        # instead of a single one (which would always give std == 0).
        all_returns = []
        for i in range(args.num_rollouts):
            print('iter', i)
            observations, actions, totalr = _run_rollout(
                env, policy_fn, max_steps, render=args.render)
            all_returns.append(totalr)
            expert_data.append({'observations': np.array(observations),
                                'actions': np.array(actions),
                                'returns': np.array([totalr])})

        print('returns', all_returns)
        print('mean return', np.mean(all_returns))
        print('std of return', np.std(all_returns))

        if args.save:
            # Create the output directory up front so a fresh checkout does
            # not lose the finished rollouts to a FileNotFoundError.
            os.makedirs('rollouts', exist_ok=True)
            out_path = os.path.join('rollouts', args.envname + '.pkl')
            with open(out_path, 'wb') as f:
                pickle.dump(expert_data, f)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()