# worker.py
import os

import numpy as np
import torch

from env import Env
from agent import Agent
from model import PolicyNet
from utils import *  # expected to provide gifs_path, MAX_EPISODE_STEP, NODE_INPUT_DIM, EMBEDDING_DIM, make_gif

if not os.path.exists(gifs_path):
    os.makedirs(gifs_path)

class Worker:
    def __init__(self, meta_agent_id, policy_net, global_step, device='cpu', save_image=False):
        self.meta_agent_id = meta_agent_id
        self.global_step = global_step
        self.save_image = save_image
        self.device = device

        self.env = Env(global_step, plot=self.save_image)
        self.robot = Agent(policy_net, self.device, self.save_image)

        self.episode_buffer = []
        self.perf_metrics = dict()
        for i in range(15):
            self.episode_buffer.append([])
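        # Buffer layout, as filled by the save_* methods below (the indices
        # are positional, not named in the original code):
        #   0-5  : observation (node_inputs, node_padding_mask, edge_mask,
        #          current_index, current_edge, edge_padding_mask)
        #   6    : action_index
        #   7    : reward
        #   8    : done flag
        #   9-14 : next observation (same six fields as 0-5)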

    def run_episode(self):
        done = False
        self.robot.update_planning_state(self.env.belief_info, self.env.robot_location)
        observation = self.robot.get_observation()

        if self.save_image:
            self.robot.plot_env()
            self.env.plot_env(0)

        for i in range(MAX_EPISODE_STEP):
            self.save_observation(observation)
            next_location, action_index = self.robot.select_next_waypoint(observation)
            self.save_action(action_index)
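            # Sanity checks: the chosen waypoint must be a neighbor of the
            # current graph node and must differ from the current location.
            # Packing each (x, y) point into the complex number x + y*1j lets
            # 2-D membership be tested with a single `in` over a 1-D array.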
            node = self.robot.node_manager.nodes_dict.find((self.robot.location[0], self.robot.location[1]))
            check = np.array(list(node.data.neighbor_set)).reshape(-1, 2)
            assert next_location[0] + next_location[1] * 1j in check[:, 0] + check[:, 1] * 1j, \
                f'{next_location}, {self.robot.location}, {node.data.neighbor_set}'
            assert next_location[0] != self.robot.location[0] or next_location[1] != self.robot.location[1]

            reward = self.env.step(next_location)

            self.robot.update_planning_state(self.env.belief_info, self.env.robot_location)
            if self.robot.utility.sum() == 0:
                # No remaining utility anywhere: exploration is complete, so
                # the episode ends with a terminal bonus.
                done = True
                reward += 20
            self.save_reward_done(reward, done)

            observation = self.robot.get_observation()
            self.save_next_observations(observation)

            if self.save_image:
                self.robot.plot_env()
                self.env.plot_env(i + 1)

            if done:
                break

        # save metrics
        self.perf_metrics['travel_dist'] = self.env.travel_dist
        self.perf_metrics['explored_rate'] = self.env.explored_rate
        self.perf_metrics['success_rate'] = done

        # save gif
        if self.save_image:
            make_gif(gifs_path, self.global_step, self.env.frame_files, self.env.explored_rate)

    def save_observation(self, observation):
        node_inputs, node_padding_mask, edge_mask, current_index, current_edge, edge_padding_mask = observation
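        # Note: `list += tensor` extends the Python list with the tensor's
        # slices along dim 0, so each buffer field accumulates one entry per
        # step (the observation tensors carry a leading batch dim of 1, as
        # the reshape(1, 1, 1) calls below confirm for the scalar fields).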
        self.episode_buffer[0] += node_inputs
        self.episode_buffer[1] += node_padding_mask.bool()
        self.episode_buffer[2] += edge_mask.bool()
        self.episode_buffer[3] += current_index
        self.episode_buffer[4] += current_edge
        self.episode_buffer[5] += edge_padding_mask.bool()

    def save_action(self, action_index):
        self.episode_buffer[6] += action_index.reshape(1, 1, 1)

    def save_reward_done(self, reward, done):
        self.episode_buffer[7] += torch.FloatTensor([reward]).reshape(1, 1, 1).to(self.device)
        self.episode_buffer[8] += torch.tensor([int(done)]).reshape(1, 1, 1).to(self.device)

    def save_next_observations(self, observation):
        node_inputs, node_padding_mask, edge_mask, current_index, current_edge, edge_padding_mask = observation
        self.episode_buffer[9] += node_inputs
        self.episode_buffer[10] += node_padding_mask.bool()
        self.episode_buffer[11] += edge_mask.bool()
        self.episode_buffer[12] += current_index
        self.episode_buffer[13] += current_edge
        self.episode_buffer[14] += edge_padding_mask.bool()

if __name__ == "__main__":
    torch.manual_seed(4777)
    np.random.seed(4777)
    model = PolicyNet(NODE_INPUT_DIM, EMBEDDING_DIM)
    # checkpoint = torch.load(model_path + '/checkpoint.pth', map_location='cpu')
    # model.load_state_dict(checkpoint['policy_model'])
    worker = Worker(0, model, 77, save_image=True)
    worker.run_episode()
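    # A minimal sketch, not part of the original script, of how the filled
    # buffer could be collated afterwards: each of the 15 fields is a list of
    # per-step tensors, so stacking them yields batched tensors (this assumes
    # every step produces equally shaped, padded observations):
    # batch = [torch.stack(field) for field in worker.episode_buffer]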