''' https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
https://gist.github.com/Pocuston/13f1a7786648e1e2ff95bfad02a51521
'''
######################################################################
# Replay Memory
# -------------
#
# We'll be using experience replay memory for training our DQN. It stores
# the transitions that the agent observes, allowing us to reuse this data
# later. By sampling from it randomly, the transitions that build up a
# batch are decorrelated. It has been shown that this greatly stabilizes
# and improves the DQN training procedure.
#
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
#   our environment. It essentially maps (state, action) pairs
#   to their (reward, next_state) result, together with a ``done``
#   flag marking the end of an episode.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
# transitions observed recently. It also implements a ``.sample()``
# method for selecting a random batch of transitions for training.
#
import random
from collections import namedtuple

# A single experience tuple collected from the environment.
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
class ReplayMemory(object):
    """Cyclic buffer of bounded size holding recently observed transitions."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    '''
    # Circular-buffer variant from the PyTorch tutorial, kept for reference:
    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity
    '''

    def push(self, batch):
        """Saves a transition, evicting the oldest once capacity is exceeded."""
        self.memory.append(batch)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        """Selects a random batch of transitions for training."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
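######################################################################
# Usage sketch
# ------------
#
# A minimal sketch of how this buffer is typically driven from a training
# loop. The states, actions, and rewards below are made-up placeholders,
# not part of the original code; in practice they come from an environment.
if __name__ == '__main__':
    memory = ReplayMemory(capacity=1000)
    # Push a few dummy transitions (state, action, reward, next_state, done).
    for step in range(100):
        state = [float(step)]            # placeholder state
        action = step % 2                # placeholder action
        reward = 1.0                     # placeholder reward
        next_state = [float(step + 1)]   # placeholder next state
        done = (step % 10 == 9)          # placeholder episode-end flag
        memory.push(Transition(state, action, reward, next_state, done))
    # Once enough transitions have accumulated, sample a decorrelated batch.
    if len(memory) >= 32:
        batch = memory.sample(32)
        # Unzip the list of Transitions into per-field tuples for training.
        states, actions, rewards, next_states, dones = zip(*batch)
        print(len(batch), states[0], dones[0])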