This repository has been archived by the owner on Sep 21, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rollout.py
110 lines (91 loc) · 3.48 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import os
from signal import signal, SIGINT
from sys import exit
import imageio
import numpy as np
import robosuite as suite
import torch
from robosuite.controllers import ALL_CONTROLLERS, load_controller_config
from robosuite.wrappers import GymWrapper
from util.arguments import add_rollout_args, parser
from util.rlkit_utils import simulate_policy
# NOTE(review): presumably set to work around Intel OpenMP aborting when two
# copies of its runtime are loaded in one process (e.g. via torch and numpy)
# -- confirm before removing.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")

# Register the rollout-specific command-line flags on the shared parser,
# then parse the command line once at import time.
add_rollout_args()
args = parser.parse_args()
# Define callbacks

# Module-level handle to the optional video writer. It stays None unless the
# __main__ section opens one (when args.record_video is set), so the SIGINT
# handler below must tolerate its absence.
video_writer = None


def handler(signal_received, frame):
    """Close the video writer, if one is open, and exit on SIGINT.

    Args:
        signal_received: Signal number delivered to the process (unused).
        frame: Stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print('SIGINT or CTRL-C detected. Closing video writer and exiting gracefully')
    # No writer exists unless a video is being recorded; calling .close() on
    # None would raise AttributeError and turn Ctrl-C into a traceback.
    if video_writer is not None:
        video_writer.close()
    exit(0)


# Tell Python to run the handler() function when SIGINT is received
signal(SIGINT, handler)
def _load_variant(load_dir):
    """Return the parsed ``variant.json`` saved in *load_dir*.

    The rollout reads ``eval_environment_kwargs`` from this file to rebuild
    the evaluation environment.

    Args:
        load_dir: Directory of a saved training run.

    Returns:
        The decoded JSON content of ``<load_dir>/variant.json``.

    Raises:
        SystemExit: With status 1 when the file does not exist, since the
            rollout cannot continue without its configuration.
    """
    kwargs_fpath = os.path.join(load_dir, "variant.json")
    try:
        with open(kwargs_fpath) as f:
            return json.load(f)
    except FileNotFoundError:
        print("Error opening variant file at: {}. "
              "Please check filepath and try again.".format(kwargs_fpath))
        exit(1)


if __name__ == "__main__":
    # Set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load the configuration the policy was trained with
    kwargs = _load_variant(args.load_dir)

    # Grab / modify env args
    env_args = kwargs["eval_environment_kwargs"]
    if args.horizon is not None:
        env_args["horizon"] = args.horizon
    env_args["render_camera"] = args.camera
    env_args["hard_reset"] = True
    env_args["ignore_done"] = True

    if args.record_video:
        # Specify the camera (and its resolution) to record from
        env_args["camera_names"] = args.camera
        env_args["camera_heights"] = 512
        env_args["camera_widths"] = 512

        # Name the video after this env / robot(s) / controller combination
        video_name = "{}-{}-{}".format(
            env_args["env_name"], "".join(env_args["robots"]), env_args["controller"]
        ).replace("_", "-")
        # Video fps matches the control frequency (presumably one frame is
        # recorded per control step inside simulate_policy -- confirm)
        fps = int(env_args["control_freq"])
        # Rebinds the module-level writer so the SIGINT handler can close it
        video_writer = imageio.get_writer("{}.mp4".format(video_name), fps=fps)

    # suite.make receives the controller through controller_configs, so drop
    # the raw "controller" entry from the kwargs that get unpacked into it.
    controller = env_args.pop("controller")
    if controller in ALL_CONTROLLERS:
        controller_config = load_controller_config(default_controller=controller)
    else:
        # Not a built-in controller name: treat it as a path to a custom config
        controller_config = load_controller_config(custom_fpath=controller)

    # Create env; render on screen when not recording, offscreen when recording
    env_suite = suite.make(**env_args,
                           controller_configs=controller_config,
                           has_renderer=not args.record_video,
                           has_offscreen_renderer=args.record_video,
                           use_object_obs=True,
                           use_camera_obs=args.record_video,
                           reward_shaping=True
                           )

    # Make sure we only pass in the proprio and object obs (no images)
    keys = ["object-state"] + [
        f"robot{idx}_proprio-state" for idx in range(len(env_suite.robots))
    ]

    # Wrap environment so it's compatible with Gym API
    env = GymWrapper(env_suite, keys=keys)

    # Run rollout
    try:
        simulate_policy(
            env=env,
            model_path=os.path.join(args.load_dir, "params.pkl"),
            horizon=env_args["horizon"],
            render=not args.record_video,
            video_writer=video_writer,
            num_episodes=args.num_episodes,
            printout=True,
            use_gpu=args.gpu,
            noise_power=args.noise_power
        )
    finally:
        # Finalize the video file even if the rollout raises. A second close
        # (e.g. after the SIGINT handler or simulate_policy already closed it)
        # is assumed to be a no-op for imageio writers.
        # NOTE(review): confirm for the ffmpeg backend.
        if video_writer is not None:
            video_writer.close()