# sample_trajectory.py
#
# Roll out a restored policy on CartPole-v0 and append the sampled
# observations and actions to CSV files under trajectory/.
import argparse
import gym
import numpy as np
from network_models.policy_net import Policy_net
import tensorflow as tf
# noinspection PyTypeChecker
def open_file_and_save(file_path, data):
    """
    Append ``data`` to the file at ``file_path`` using ``np.savetxt``.

    Opening in ``'ab'`` mode already creates a missing file, so the only way
    the open can raise ``FileNotFoundError`` is a missing parent directory.
    Retrying in ``'wb'`` mode would fail for exactly the same reason, so the
    parent directory is created up front instead.

    :param file_path: type==string, destination file (created if absent)
    :param data: array-like accepted by ``np.savetxt``; written with fmt='%s'
    """
    import os  # function-scope import keeps this block self-contained

    parent_dir = os.path.dirname(file_path)
    if parent_dir:  # '' means the file lives in the current directory
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, 'ab') as f_handle:
        np.savetxt(f_handle, data, fmt='%s')
def argparser():
    """
    Parse the command-line options of the sampling script.

    Options:
      --model      checkpoint filename to load (string)
      --iteration  integer option, defaults to 10

    :return: the ``argparse.Namespace`` produced from ``sys.argv``
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ('--model', {'help': 'filename of model to test',
                     'default': 'trained_models/ppo/model.ckpt'}),
        ('--iteration', {'default': 10, 'type': int}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def main(args):
    """
    Sample trajectories from a restored policy on CartPole-v0.

    For each of ``args.iteration`` episodes the policy is run stochastically
    until the environment reports ``done``. The episode length is printed,
    and the visited observations and chosen actions are passed to
    ``open_file_and_save`` for 'trajectory/observations.csv' and
    'trajectory/actions.csv' respectively.

    :param args: parsed command-line namespace providing ``model`` (checkpoint
        path handed to ``saver.restore``) and ``iteration`` (episode count)
    """
    env = gym.make('CartPole-v0')
    env.seed(0)
    ob_space = env.observation_space
    Policy = Policy_net('policy', env)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, args.model)
        obs = env.reset()

        for _episode in range(args.iteration):  # episode
            observations = []
            actions = []
            run_steps = 0
            while True:
                run_steps += 1
                # prepare to feed placeholder Policy.obs
                obs = np.stack([obs]).astype(dtype=np.float32)

                act, _ = Policy.act(obs=obs, stochastic=True)
                # np.asscalar was deprecated in NumPy 1.16 and removed in
                # 1.23; ndarray.item() is its documented replacement.
                act = np.asarray(act).item()

                observations.append(obs)
                actions.append(act)

                next_obs, _reward, done, _info = env.step(act)

                if done:
                    print(run_steps)
                    obs = env.reset()
                    break
                obs = next_obs

            # Flatten the per-step (1, *ob_shape) arrays into one row per step.
            observations = np.reshape(observations, newshape=[-1] + list(ob_space.shape))
            actions = np.array(actions).astype(dtype=np.int32)

            open_file_and_save('trajectory/observations.csv', observations)
            open_file_and_save('trajectory/actions.csv', actions)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the sampler with them.
    main(argparser())