1
1
# Replay buffer that stores full episodes
2
2
3
- mutable struct EpisodeReplayBuffer{N<: Integer , T<: AbstractFloat , CI, Q}
3
+ mutable struct EpisodeReplayBuffer{N<: Integer , T<: Real , CI, Q<: AbstractArray{T} ,A <: AbstractArray{T} }
4
4
max_size:: Int64
5
5
batch_size:: Int64
6
6
trace_length:: Int64
@@ -9,43 +9,43 @@ mutable struct EpisodeReplayBuffer{N<:Integer, T<:AbstractFloat, CI, Q}
9
9
_idx:: Int64
10
10
_experience:: Vector{Vector{DQExperience{N,T,Q}}}
11
11
12
- _s_batch:: Vector{Array{T} }
12
+ _s_batch:: Vector{A }
13
13
_a_batch:: Vector{Vector{CI}}
14
- _r_batch:: Vector{Array {T}}
15
- _sp_batch:: Vector{Array{T} }
16
- _done_batch:: Vector{Array {T}}
17
- _trace_mask:: Vector{Array {N}}
14
+ _r_batch:: Vector{Vector {T}}
15
+ _sp_batch:: Vector{A }
16
+ _done_batch:: Vector{Vector {T}}
17
+ _trace_mask:: Vector{Vector {N}}
18
18
_episode:: Vector{DQExperience{N,T,Q}}
19
19
end
20
20
21
- function EpisodeReplayBuffer (env:: AbstractEnvironment ,
21
+ function EpisodeReplayBuffer (env:: AbstractEnvironment{OV} ,
22
22
max_size:: Int64 ,
23
23
batch_size:: Int64 ,
24
24
trace_length:: Int64 ,
25
- rng:: AbstractRNG = MersenneTwister (0 ))
25
+ rng:: AbstractRNG = MersenneTwister (0 )) where {OV}
26
26
s_dim = obs_dimensions (env)
27
27
Q = length (s_dim)
28
- experience = Vector {Vector{DQExperience{Int32, Float32, Q }}} (undef, max_size)
28
+ experience = Vector {Vector{DQExperience{Int32, Float32, OV }}} (undef, max_size)
29
29
_s_batch = [zeros (Float32, s_dim... , batch_size) for i= 1 : trace_length]
30
30
_a_batch = [[CartesianIndex (1 ,1 ) for i= 1 : batch_size] for i= 1 : trace_length]
31
31
_r_batch = [zeros (Float32, batch_size) for i= 1 : trace_length]
32
32
_sp_batch = [zeros (Float32, s_dim... , batch_size) for i= 1 : trace_length]
33
33
_done_batch = [zeros (Float32, batch_size) for i= 1 : trace_length]
34
34
_trace_mask = [zeros (Int32, batch_size) for i= 1 : trace_length]
35
- _episode = Vector {DQExperience{Int32, Float32, Q }} ()
36
- return EpisodeReplayBuffer {Int32, Float32, CartesianIndex{2}, Q} (max_size, batch_size, trace_length, rng, 0 , 1 , experience,
35
+ _episode = Vector {DQExperience{Int32, Float32, OV }} ()
36
+ return EpisodeReplayBuffer (max_size, batch_size, trace_length, rng, 0 , 1 , experience,
37
37
_s_batch, _a_batch, _r_batch, _sp_batch, _done_batch, _trace_mask, _episode)
38
38
end
39
39
40
40
is_full (r:: EpisodeReplayBuffer ) = r. _curr_size == r. max_size
41
41
42
42
max_size (r:: EpisodeReplayBuffer ) = r. max_size
43
43
44
- function add_exp! (r:: EpisodeReplayBuffer{N, T } , exp:: DQExperience ) where {N, T }
44
+ function add_exp! (r:: EpisodeReplayBuffer{N,T,CI,Q } , exp:: DQExperience ) where {N,T,CI,Q }
45
45
push! (r. _episode, exp)
46
46
if exp. done
47
47
add_episode! (r, r. _episode)
48
- r. _episode = Vector {DQExperience{N, T }} ()
48
+ r. _episode = Vector {DQExperience{N,T,Q }} ()
49
49
end
50
50
end
51
51
@@ -73,30 +73,22 @@ function StatsBase.sample(r::EpisodeReplayBuffer)
73
73
sample_indices = sample (r. rng, 1 : r. _curr_size, r. batch_size, replace= false )
74
74
@assert length (sample_indices) == size (r. _s_batch[1 ])[end ]
75
75
s_batch_size = size (first (r. _s_batch))
76
- for t= 1 : r. trace_length
77
- r. _s_batch[t] = reshape (r. _s_batch[t], (:, r. batch_size))
78
- r. _sp_batch[t] = reshape (r. _sp_batch[t], (:, r. batch_size))
79
- end
80
76
for (i, idx) in enumerate (sample_indices)
81
77
ep = r. _experience[idx]
82
78
# randomized start TODO add as an option of the buffer
83
79
ep_start = rand (r. rng, 1 : length (ep))
84
80
t = 1
85
81
for j= ep_start: min (length (ep), r. trace_length)
86
82
expe = ep[t]
87
- r. _s_batch[t][: , i] = vec (expe. s)
83
+ r. _s_batch[t][.. , i] = vec (expe. s)
88
84
r. _a_batch[t][i] = CartesianIndex (expe. a, i)
89
85
r. _r_batch[t][i] = expe. r
90
- r. _sp_batch[t][: , i] = vec (expe. sp)
86
+ r. _sp_batch[t][.. , i] = vec (expe. sp)
91
87
r. _done_batch[t][i] = expe. done
92
88
r. _trace_mask[t][i] = 1
93
89
t += 1
94
90
end
95
91
end
96
- for t= 1 : r. trace_length
97
- r. _s_batch[t] = reshape (r. _s_batch[t], s_batch_size)
98
- r. _sp_batch[t] = reshape (r. _sp_batch[t], s_batch_size)
99
- end
100
92
return r. _s_batch, r. _a_batch, r. _r_batch, r. _sp_batch, r. _done_batch, r. _trace_mask
101
93
end
102
94
@@ -112,9 +104,9 @@ function populate_replay_buffer!(r::EpisodeReplayBuffer,
112
104
@assert r. _curr_size >= r. batch_size
113
105
end
114
106
115
- function generate_episode (env:: AbstractEnvironment , action_indices; max_steps:: Int64 = 100 )
107
+ function generate_episode (env:: AbstractEnvironment{OV} , action_indices; max_steps:: Int64 = 100 ) where OV
116
108
s_dim = obs_dimensions (env)
117
- episode = DQExperience{Int32, Float32, length (s_dim) }[]
109
+ episode = DQExperience{Int32, Float32, OV }[]
118
110
sizehint! (episode, max_steps)
119
111
# start simulation
120
112
o = reset! (env)
0 commit comments