
Commit d4cf43e

Merge pull request #33 from JuliaPOMDP/abstractarrays
Support AbstractArray in DQExperience
2 parents 1935049 + a5b0acc commit d4cf43e

6 files changed (+72, -50 lines)

Project.toml

+5 -3

@@ -1,10 +1,11 @@
 name = "DeepQLearning"
 uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3"
 repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl"
-version = "0.4.5"
+version = "0.4.6"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"

@@ -21,13 +22,14 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 Flux = "0.10"
 POMDPPolicies = "0.2.1"
 POMDPs = "0.7.3, 0.8"
-RLInterface = "0.3.2"
+RLInterface = "0.3.6"
 julia = "1"

 [extras]
 POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
 POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["POMDPModels", "POMDPSimulators", "Test"]
+test = ["POMDPModels", "POMDPSimulators", "StaticArrays", "Test"]

src/DeepQLearning.jl

+1

@@ -12,6 +12,7 @@ using POMDPPolicies
 using RLInterface
 using LinearAlgebra
 using TensorBoardLogger: TBLogger, log_value
+using EllipsisNotation

 export DeepQLearningSolver,
        AbstractNNPolicy,
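The `using EllipsisNotation` import above brings in the `..` index used throughout this PR to address all leading dimensions of a batch array at once, whatever the observation rank. A minimal sketch of the idiom (array sizes here are illustrative, not taken from the package):

using EllipsisNotation

batch = zeros(Float32, 5, 5, 32)   # (observation dims..., batch_size)
obs   = rand(Float32, 5, 5)
batch[.., 1] = vec(obs)            # fills batch[:, :, 1] with no explicit reshape
@assert batch[:, :, 1] == obs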

src/episode_replay.jl

+17 -25

@@ -1,6 +1,6 @@
 # Replay buffer that store full episodes

-mutable struct EpisodeReplayBuffer{N<:Integer, T<:AbstractFloat, CI, Q}
+mutable struct EpisodeReplayBuffer{N<:Integer, T<:Real, CI, Q<:AbstractArray{T},A<:AbstractArray{T}}
     max_size::Int64
     batch_size::Int64
     trace_length::Int64

@@ -9,43 +9,43 @@ mutable struct EpisodeReplayBuffer{N<:Integer, T<:AbstractFloat, CI, Q}
     _idx::Int64
     _experience::Vector{Vector{DQExperience{N,T,Q}}}

-    _s_batch::Vector{Array{T}}
+    _s_batch::Vector{A}
     _a_batch::Vector{Vector{CI}}
-    _r_batch::Vector{Array{T}}
-    _sp_batch::Vector{Array{T}}
-    _done_batch::Vector{Array{T}}
-    _trace_mask::Vector{Array{N}}
+    _r_batch::Vector{Vector{T}}
+    _sp_batch::Vector{A}
+    _done_batch::Vector{Vector{T}}
+    _trace_mask::Vector{Vector{N}}
     _episode::Vector{DQExperience{N,T,Q}}
 end

-function EpisodeReplayBuffer(env::AbstractEnvironment,
+function EpisodeReplayBuffer(env::AbstractEnvironment{OV},
                     max_size::Int64,
                     batch_size::Int64,
                     trace_length::Int64,
-                    rng::AbstractRNG = MersenneTwister(0))
+                    rng::AbstractRNG = MersenneTwister(0)) where {OV}
     s_dim = obs_dimensions(env)
     Q = length(s_dim)
-    experience = Vector{Vector{DQExperience{Int32, Float32, Q}}}(undef, max_size)
+    experience = Vector{Vector{DQExperience{Int32, Float32, OV}}}(undef, max_size)
     _s_batch = [zeros(Float32, s_dim..., batch_size) for i=1:trace_length]
     _a_batch = [[CartesianIndex(1,1) for i=1:batch_size] for i=1:trace_length]
     _r_batch = [zeros(Float32, batch_size) for i=1:trace_length]
     _sp_batch = [zeros(Float32, s_dim..., batch_size) for i=1:trace_length]
     _done_batch = [zeros(Float32, batch_size) for i=1:trace_length]
     _trace_mask = [zeros(Int32, batch_size) for i=1:trace_length]
-    _episode = Vector{DQExperience{Int32, Float32, Q}}()
-    return EpisodeReplayBuffer{Int32, Float32, CartesianIndex{2}, Q}(max_size, batch_size, trace_length, rng, 0, 1, experience,
+    _episode = Vector{DQExperience{Int32, Float32, OV}}()
+    return EpisodeReplayBuffer(max_size, batch_size, trace_length, rng, 0, 1, experience,
                 _s_batch, _a_batch, _r_batch, _sp_batch, _done_batch, _trace_mask, _episode)
 end

 is_full(r::EpisodeReplayBuffer) = r._curr_size == r.max_size

 max_size(r::EpisodeReplayBuffer) = r.max_size

-function add_exp!(r::EpisodeReplayBuffer{N, T}, exp::DQExperience) where {N, T}
+function add_exp!(r::EpisodeReplayBuffer{N,T,CI,Q}, exp::DQExperience) where {N,T,CI,Q}
     push!(r._episode, exp)
     if exp.done
         add_episode!(r, r._episode)
-        r._episode = Vector{DQExperience{N, T}}()
+        r._episode = Vector{DQExperience{N,T,Q}}()
     end
 end

@@ -73,30 +73,22 @@ function StatsBase.sample(r::EpisodeReplayBuffer)
     sample_indices = sample(r.rng, 1:r._curr_size, r.batch_size, replace=false)
     @assert length(sample_indices) == size(r._s_batch[1])[end]
     s_batch_size = size(first(r._s_batch))
-    for t=1:r.trace_length
-        r._s_batch[t] = reshape(r._s_batch[t], (:, r.batch_size))
-        r._sp_batch[t] = reshape(r._sp_batch[t], (:, r.batch_size))
-    end
     for (i, idx) in enumerate(sample_indices)
         ep = r._experience[idx]
         # randomized start TODO add as an option of the buffer
         ep_start = rand(r.rng, 1:length(ep))
         t = 1
         for j=ep_start:min(length(ep), r.trace_length)
             expe = ep[t]
-            r._s_batch[t][:, i] = vec(expe.s)
+            r._s_batch[t][.., i] = vec(expe.s)
             r._a_batch[t][i] = CartesianIndex(expe.a, i)
             r._r_batch[t][i] = expe.r
-            r._sp_batch[t][:, i] = vec(expe.sp)
+            r._sp_batch[t][.., i] = vec(expe.sp)
             r._done_batch[t][i] = expe.done
             r._trace_mask[t][i] = 1
             t += 1
         end
     end
-    for t=1:r.trace_length
-        r._s_batch[t] = reshape(r._s_batch[t], s_batch_size)
-        r._sp_batch[t] = reshape(r._sp_batch[t], s_batch_size)
-    end
     return r._s_batch, r._a_batch, r._r_batch, r._sp_batch, r._done_batch, r._trace_mask
 end

@@ -112,9 +104,9 @@ function populate_replay_buffer!(r::EpisodeReplayBuffer,
     @assert r._curr_size >= r.batch_size
 end

-function generate_episode(env::AbstractEnvironment, action_indices; max_steps::Int64 = 100)
+function generate_episode(env::AbstractEnvironment{OV}, action_indices; max_steps::Int64 = 100) where OV
     s_dim = obs_dimensions(env)
-    episode = DQExperience{Int32, Float32, length(s_dim)}[]
+    episode = DQExperience{Int32, Float32, OV}[]
     sizehint!(episode, max_steps)
     # start simulation
     o = reset!(env)
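With the reshape round-trips removed, `StatsBase.sample` writes each experience directly into the per-timestep batch slice via `r._s_batch[t][.., i] = vec(expe.s)`. A hedged sketch of that write pattern outside the buffer (variable names and sizes are illustrative), showing that the same line works for a matrix-shaped observation and for a StaticArrays `SVector`:

using EllipsisNotation, StaticArrays

trace_length, batch_size = 6, 32
# one batch array per timestep; observation dims lead, the batch index is last
s_batch_img = [zeros(Float32, 5, 5, batch_size) for _ in 1:trace_length]   # (5,5) observations
s_batch_vec = [zeros(Float32, 1, batch_size) for _ in 1:trace_length]      # SVector{1} observations

s_batch_img[3][.., 7] = vec(rand(Float32, 5, 5))   # timestep 3, batch column 7
s_batch_vec[3][.., 7] = vec(SVector(1.0f0))        # same code path, different observation rank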

src/prioritized_experience_replay.jl

+16 -21

@@ -1,22 +1,22 @@
 # Naive implementation

-struct DQExperience{N <: Real,T <: Real, Q}
-    s::Array{T, Q}
+struct DQExperience{N <: Real,T <: Real, A<:AbstractArray{T}}
+    s::A
     a::N
     r::T
-    sp::Array{T, Q}
+    sp::A
     done::Bool
 end

 function Base.convert(::Type{DQExperience{Int32, Float32, C}}, x::DQExperience{A, B, C}) where {A, B, C}
-    return DQExperience{Int32, Float32, C}(convert(Array{Float32, C}, x.s),
-                                           convert(Int32, x.a),
-                                           convert(Float32, x.r),
-                                           convert(Array{Float32, C}, x.sp),
-                                           x.done)
+    return DQExperience{Int32, Float32, C}(convert(C, x.s),
+                                           convert(Int32, x.a),
+                                           convert(Float32, x.r),
+                                           convert(C, x.sp),
+                                           x.done)
 end

-mutable struct PrioritizedReplayBuffer{N<:Integer, T<:AbstractFloat,CI, Q}
+mutable struct PrioritizedReplayBuffer{N<:Integer, T<:AbstractFloat,CI,Q,A<:AbstractArray{T}}
     max_size::Int64
     batch_size::Int64
     rng::AbstractRNG

@@ -28,23 +28,23 @@ mutable struct PrioritizedReplayBuffer{N<:Integer, T<:AbstractFloat,CI, Q}
     _priorities::Vector{T}
     _experience::Vector{DQExperience{N,T,Q}}

-    _s_batch::Array{T}
+    _s_batch::A
     _a_batch::Vector{CI}
     _r_batch::Vector{T}
-    _sp_batch::Array{T}
+    _sp_batch::A
     _done_batch::Vector{T}
     _weights_batch::Vector{T}
 end

-function PrioritizedReplayBuffer(env::AbstractEnvironment,
+function PrioritizedReplayBuffer(env::AbstractEnvironment{OV},
                                 max_size::Int64,
                                 batch_size::Int64;
                                 rng::AbstractRNG = MersenneTwister(0),
                                 α::Float32 = 6f-1,
                                 β::Float32 = 4f-1,
-                                ϵ::Float32 = 1f-3)
+                                ϵ::Float32 = 1f-3) where {OV}
     s_dim = obs_dimensions(env)
-    experience = Vector{DQExperience{Int32, Float32, length(s_dim)}}(undef, max_size)
+    experience = Vector{DQExperience{Int32, Float32, OV}}(undef, max_size)
     priorities = Vector{Float32}(undef, max_size)
     _s_batch = zeros(Float32, s_dim..., batch_size)
     _a_batch = [CartesianIndex(0,0) for i=1:batch_size]

@@ -87,21 +87,16 @@ end

 function get_batch(r::PrioritizedReplayBuffer, sample_indices::Vector{Int64})
     @assert length(sample_indices) == size(r._s_batch)[end]
-    s_batch_size = size(r._s_batch)
-    r._s_batch = reshape(r._s_batch, (:, r.batch_size))
-    r._sp_batch = reshape(r._sp_batch, (:, r.batch_size))
     for (i, idx) in enumerate(sample_indices)
         @inbounds begin
-            r._s_batch[:, i] = vec(r._experience[idx].s)
+            r._s_batch[.., i] = vec(r._experience[idx].s)
             r._a_batch[i] = CartesianIndex(r._experience[idx].a, i)
             r._r_batch[i] = r._experience[idx].r
-            r._sp_batch[:, i] = vec(r._experience[idx].sp)
+            r._sp_batch[.., i] = vec(r._experience[idx].sp)
             r._done_batch[i] = r._experience[idx].done
             r._weights_batch[i] = r._priorities[idx]
         end
     end
-    r._s_batch = reshape(r._s_batch, s_batch_size)
-    r._sp_batch = reshape(r._sp_batch, s_batch_size)
     p = r._weights_batch ./ sum(r._priorities[1:r._curr_size])
     weights = (r._curr_size * p).^(-r.β)
     return r._s_batch, r._a_batch, r._r_batch, r._sp_batch, r._done_batch, sample_indices, weights
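A hedged sketch of what the widened `DQExperience` now admits; `DQExperience` is the struct defined in this file, while the import line, values, and sizes below are assumptions for illustration:

using StaticArrays
using DeepQLearning: DQExperience   # explicit import; the name may not be exported

# observations can be any AbstractArray{T}, e.g. a stack-allocated SVector ...
e_static = DQExperience{Int32, Float32, SVector{2, Float32}}(
    SVector(1.0f0, 2.0f0), Int32(1), 0.5f0, SVector(2.0f0, 3.0f0), false)

# ... or a dense Array, as before
e_dense = DQExperience{Int32, Float32, Matrix{Float32}}(
    rand(Float32, 5, 5), Int32(2), 1.0f0, rand(Float32, 5, 5), false)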

test/prototype.jl

+5 -1

@@ -8,7 +8,7 @@ using Flux
 using DeepQLearning
 include("test/test_env.jl")

-# mdp = TestMDP((5,5), 4, 6)
+mdp = TestMDP((5,5), 4, 6)
 # mdp = SimpleGridWorld()
 rng = MersenneTwister(1)
 mdp = TestMDP((5,5), 1, 6)

@@ -20,8 +20,12 @@ solver = DeepQLearningSolver(batch_size = 128, eval_freq = 10_000, save_freq=10_

 @btime policy = solve($solver, $mdp)

+
 policy = solve(solver, mdp)

+env = MDPEnvironment(mdp)
+o = reset!(env)
+
 using RLInterface
 using LinearAlgebra


test/runtests.jl

+28

@@ -5,6 +5,7 @@ using POMDPPolicies
 using Flux
 using Random
 using RLInterface
+using StaticArrays
 using Test
 Random.seed!(7)
 GLOBAL_RNG = MersenneTwister(1) # for test consistency

@@ -110,3 +111,30 @@ end
     policy = solve(solver, pomdp)
     @test size(actionvalues(policy, true)) == (length(actions(pomdp)),)
 end
+
+mutable struct StaticArrayMDP <: MDP{typeof(SVector(1)), Int64}
+    state::typeof(SVector(1))
+end
+POMDPs.discount(::StaticArrayMDP) = 0.95f0
+POMDPs.initialstate(m::StaticArrayMDP, rng::AbstractRNG) = m.state
+
+function POMDPs.gen(m::StaticArrayMDP, s, a, rng::AbstractRNG)
+    return (sp=s + SVector(a), r=m.state[1]^2)
+end
+
+POMDPs.isterminal(::StaticArrayMDP, s) = s[1] >= 3
+POMDPs.actions(::StaticArrayMDP) = [0,1]
+
+
+@testset "Static Array Env" begin
+    mdp = StaticArrayMDP(SVector(1))
+
+    model = Chain(Dense(1, 32), Dense(32, length(actions(mdp))))
+
+    solver = DeepQLearningSolver(qnetwork = model, max_steps=10,
+                                 learning_rate=0.005,log_freq=500,
+                                 recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
+    policy = solve(solver, mdp)
+
+    @test evaluate(mdp, policy, GLOBAL_RNG) > 1.0
+end
