#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# =====================================
# @Time    : 2020/6/10
# @Author  : Yang Guan (Tsinghua Univ.)
# @FileName: worker.py
# =====================================

import logging

import gym
import numpy as np

from preprocessor import Preprocessor
from utils.dummy_vec_env import DummyVecEnv
from utils.misc import judge_is_nan

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# logger.setLevel(logging.INFO)


class OffPolicyWorker(object):
    """Sampling-only worker: rolls out the current policy and returns batches of transitions."""

    # TensorFlow is imported and configured at class level so that each worker
    # hides the GPU and limits TensorFlow to a single thread per op.
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], 'GPU')
    tf.config.threading.set_inter_op_parallelism_threads(1)
    tf.config.threading.set_intra_op_parallelism_threads(1)

    def __init__(self, policy_cls, env_id, args, worker_id):
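        """Build one sampling worker: the environment, the policy network and the data preprocessor."""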
        logging.getLogger("tensorflow").setLevel(logging.ERROR)
        self.worker_id = worker_id
        self.args = args
        self.num_agent = self.args.num_agent
        if self.args.env_id == 'PathTracking-v0':
            self.env = gym.make(self.args.env_id, num_agent=self.num_agent,
                                num_future_data=self.args.num_future_data)
        else:
            env = gym.make(self.args.env_id)
            self.env = DummyVecEnv(env)
        self.policy_with_value = policy_cls(**vars(self.args))
        self.batch_size = self.args.batch_size
        self.obs = self.env.reset()
        self.done = False
        self.preprocessor = Preprocessor(**vars(self.args))
        self.explore_sigma = self.args.explore_sigma
        self.iteration = 0
        self.num_sample = 0
        self.sample_times = 0
        self.stats = {}
        logger.info('Worker initialized')

    def get_stats(self):
        self.stats.update(dict(worker_id=self.worker_id,
                               num_sample=self.num_sample,
                               # ppc_params=self.get_ppc_params()
                               )
                          )
        return self.stats
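
    # Helpers for saving, loading and exchanging the policy weights and gradients.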
    def save_weights(self, save_dir, iteration):
        self.policy_with_value.save_weights(save_dir, iteration)

    def load_weights(self, load_dir, iteration):
        self.policy_with_value.load_weights(load_dir, iteration)

    def get_weights(self):
        return self.policy_with_value.get_weights()

    def set_weights(self, weights):
        return self.policy_with_value.set_weights(weights)

    def apply_gradients(self, iteration, grads):
        self.iteration = iteration
        self.policy_with_value.apply_gradients(self.tf.constant(iteration, dtype=self.tf.int32), grads)
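
    # "ppc" refers to the preprocessor; these accessors expose its parameters
    # so they can be shared, saved and restored.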
    def get_ppc_params(self):
        return self.preprocessor.get_params()

    def set_ppc_params(self, params):
        self.preprocessor.set_params(params)

    def save_ppc_params(self, save_dir):
        self.preprocessor.save_params(save_dir)

    def load_ppc_params(self, load_dir):
        self.preprocessor.load_params(load_dir)

    def sample(self):
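        """Roll out the current policy for about batch_size steps and return the collected
        (obs, action, reward, next_obs, done) transitions, one tuple per agent per step."""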
        batch_data = []
        for _ in range(int(self.batch_size / self.num_agent)):
            processed_obs = self.preprocessor.process_obs(self.obs)
            judge_is_nan([processed_obs])
            action, logp = self.policy_with_value.compute_action(self.tf.constant(processed_obs))
            if self.explore_sigma is not None:
                # Add Gaussian exploration noise to the action.
                action += np.random.normal(0, self.explore_sigma, np.shape(action))
            try:
                judge_is_nan([action])
            except ValueError:
                # The action contains NaNs: dump the inputs and weights for debugging, then re-raise.
                print('processed_obs', processed_obs)
                print('preprocessor_params', self.preprocessor.get_params())
                print('policy_weights', self.policy_with_value.policy.trainable_weights)
                action, logp = self.policy_with_value.compute_action(processed_obs)
                judge_is_nan([action])
                raise ValueError
            obs_tp1, reward, self.done, info = self.env.step(action.numpy())
            processed_rew = self.preprocessor.process_rew(reward, self.done)
            for i in range(self.num_agent):
                batch_data.append((self.obs[i].copy(), action[i].numpy(), reward[i],
                                   obs_tp1[i].copy(), self.done[i]))
            self.obs = self.env.reset()

        if self.worker_id == 1 and self.sample_times % self.args.worker_log_interval == 0:
            logger.info('Worker_info: {}'.format(self.get_stats()))
        self.num_sample += len(batch_data)
        self.sample_times += 1
        return batch_data

    def sample_with_count(self):
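        """Sample a batch and also return the number of collected transitions."""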
        batch_data = self.sample()
        return batch_data, len(batch_data)