-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathminigrid_traj_collection_script.py
103 lines (75 loc) · 2.41 KB
/
minigrid_traj_collection_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# import os
import numpy as np
import time
import argparse
import matplotlib.pyplot as plt
import gym
import gym_minigrid
from gym_minigrid import wrappers
from stable_baselines3 import PPO
from imitation.data.types import Trajectory
import pickle5 as pickle
from utils.env_utils import minigrid_get_env, minigrid_render
# Command-line interface for trajectory collection.
parser = argparse.ArgumentParser()

# Environment / run identification.
parser.add_argument(
    "--env",
    "-e",
    default="MiniGrid-LavaCrossingS9N1-v0",
    help="minigrid gym environment to train on",
)
parser.add_argument("--run", "-r", default="testing", help="Run name")
parser.add_argument("--save-name", "-s", default="saved_testing", help="Save name")
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    help="random seed to generate the environment with",
)

# Collection limits.
parser.add_argument(
    "--max-timesteps",
    "-t",
    type=int,
    default=50,
    help="cut traj at max timestep",
)
parser.add_argument(
    "--ntraj",
    type=int,
    default=10,
    help="number of trajectories to collect",
)

# Boolean switches.
parser.add_argument(
    "--flat",
    "-f",
    action="store_true",
    default=False,
    help="Partially Observable FlatObs or Fully Observable Image ",
)
parser.add_argument(
    "--render",
    action="store_true",
    default=False,
    help="Render",
)
# NOTE(review): store_false means passing --best sets args.best to False,
# which is the opposite of what the flag name suggests, and args.best is
# never read elsewhere in this script -- confirm intended semantics.
parser.add_argument(
    "--best",
    action="store_false",
    default=True,
    help="Use best model",
)

args = parser.parse_args()
# Build a single (n_envs=1) MiniGrid environment; args.flat selects the
# flat partially-observable wrapper vs. the full image observation.
env = minigrid_get_env(args.env, 1, args.flat)

best_model_path = "./logs/" + args.env + "/ppo/" + args.run + "/best_model/best_model.zip"
pkl_save_path = "./traj_datasets/" + args.save_name + ".pkl"

# NOTE(review): args.best is parsed but never consulted -- the best-model
# checkpoint is always loaded here. Confirm whether a non-best checkpoint
# path was intended when --best is passed.
model = PPO.load(best_model_path)

traj_dataset = []
for _ in range(args.ntraj):
    obs_list = []
    action_list = []

    # Vectorized env: reset()/step() return batched arrays; index [0]
    # extracts the single environment's observation/action.
    obs = env.reset()
    obs_list.append(obs[0])
    if args.render:
        minigrid_render(obs)

    for _ in range(args.max_timesteps):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        action_list.append(action[0])
        if done:
            # SB3 VecEnvs auto-reset on episode end, so `obs` at this point
            # is the FIRST observation of the next episode, not the terminal
            # one. The true terminal observation is stashed in the info dict;
            # fall back to the auto-reset obs (the original behavior) if the
            # wrapper does not provide it.
            obs_list.append(info[0].get("terminal_observation", obs[0]))
            if args.render:
                minigrid_render(obs)
            break
        obs_list.append(obs[0])
        if args.render:
            minigrid_render(obs)

    # imitation expects len(obs) == len(acts) + 1 per trajectory.
    traj_dataset.append(
        Trajectory(
            obs=np.array(obs_list),
            acts=np.array(action_list),
            infos=np.array([{} for _ in action_list]),
        )
    )

# Persist all trajectories with the highest pickle protocol (pickle5
# backport gives protocol 5 on older interpreters).
with open(pkl_save_path, "wb") as handle:
    pickle.dump(traj_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)

print(f"{len(traj_dataset)} trajectories saved at {pkl_save_path}")