
Commit 5449310

Updates to fix issues associated with Flux 0.14 (#76)
* updates to fix issues associated with Flux 0.14
* version bump
1 parent bdaa6cb commit 5449310

7 files changed: +73 -24 lines changed


.gitignore (+2 -1)

@@ -6,4 +6,5 @@ log*
 *.bson
 events.out.tfevents*
 .vscode
-Manifest.toml
+Manifest.toml
+.DS_Store

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "DeepQLearning"
 uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3"
 repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl"
-version = "0.7.1"
+version = "0.7.2"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"

README.md (+12 -12)

@@ -27,7 +27,6 @@ using DeepQLearning
 using POMDPs
 using Flux
 using POMDPModels
-using POMDPSimulators
 using POMDPTools

 # load MDP model from POMDPModels or define your own!

@@ -37,7 +36,7 @@ mdp = SimpleGridWorld();
 # the gridworld state is represented by a 2 dimensional vector.
 model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

-exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))
+exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

 solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
                              exploration_policy = exploration,

@@ -99,39 +98,40 @@ mdp = SimpleGridWorld();
 # the model weights will be send to the gpu in the call to solve
 model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

-solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
-                             learning_rate=0.005, log_freq=500,
-                             recurrence=false, double_q=true, dueling=true, prioritized_replay=true)
+exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
+                             exploration_policy=exploration,
+                             learning_rate=0.005, log_freq=500,
+                             recurrence=false, double_q=true, dueling=true, prioritized_replay=true)
 policy = solve(solver, mdp)
 ```

 ## Solver Options

 **Fields of the Q Learning solver:**
 - `qnetwork::Any = nothing` Specify the architecture of the Q network
+- `exploration_policy::ExplorationPolicy` exploration strategy (e.g. EpsGreedyPolicy)
 - `learning_rate::Float64 = 1e-4` learning rate
 - `max_steps::Int64` total number of training steps, default = 1000
-- `target_update_freq::Int64` frequency at which the target network is updated, default = 500
 - `batch_size::Int64` batch size sampled from the replay buffer, default = 32
 - `train_freq::Int64` frequency at which the active network is updated, default = 4
-- `log_freq::Int64` frequency at which to log info, default = 100
 - `eval_freq::Int64` frequency at which to evaluate the network, default = 100
+- `target_update_freq::Int64` frequency at which the target network is updated, default = 500
 - `num_ep_eval::Int64` number of episodes to evaluate the policy, default = 100
-- `eps_fraction::Float64` fraction of the training steps used to explore, default = 0.5
-- `eps_end::Float64` value of epsilon at the end of the exploration phase, default = 0.01
 - `double_q::Bool` double Q-learning update, default = true
 - `dueling::Bool` dueling structure for the Q network, default = true
 - `recurrence::Bool = false` set to true to use DRQN; it will throw an error if you set it to false and pass a recurrent model.
+- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps, the default is a rollout that returns the undiscounted average reward
 - `prioritized_replay::Bool` enable prioritized experience replay, default = true
 - `prioritized_replay_alpha::Float64` default = 0.6
 - `prioritized_replay_epsilon::Float64` default = 1e-6
 - `prioritized_replay_beta::Float64` default = 0.4
 - `buffer_size::Int64` size of the experience replay buffer, default = 1000
 - `max_episode_length::Int64` maximum length of a training episode, default = 100
 - `train_start::Int64` number of steps used to fill in the replay buffer initially, default = 200
-- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
-- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps, the default is a rollout that returns the undiscounted average reward
-- `exploration_policy::Any = linear_epsilon_greedy(max_steps, eps_fraction, eps_end)` exploration strategy (default is epsilon greedy with linear decay)
 - `rng::AbstractRNG` random number generator, default = MersenneTwister(0)
 - `logdir::String = ""` folder in which to save the model
+- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
+- `log_freq::Int64` frequency at which to log info, default = 100
 - `verbose::Bool` default = true

src/dueling.jl (+1 -1)

@@ -10,7 +10,7 @@ function (m::DuelingNetwork)(inpt)
     return m.val(x) .+ m.adv(x) .- mean(m.adv(x), dims=1)
 end

-Flux.@functor DuelingNetwork
+Flux.@layer DuelingNetwork

 function Flux.reset!(m::DuelingNetwork)
     Flux.reset!(m.base)
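Note on the `Flux.@functor` → `Flux.@layer` change: Flux 0.14 recommends `@layer` for registering a custom struct's fields with Flux, so that gradients, `gpu`/`cpu` transfers, and `Flux.state` all see them. A minimal sketch of the pattern, assuming Flux ≥ 0.14; the `TwoHead` struct and its layer sizes are invented for illustration and are not part of this package:

```julia
using Flux
using Statistics: mean

# Hypothetical two-headed module, loosely mirroring the DuelingNetwork idea:
# a shared base, a value head, and an advantage head.
struct TwoHead
    base
    val
    adv
end

# Flux 0.14 style: @layer registers the struct's fields with Flux
# (older code used `Flux.@functor TwoHead` for the same purpose).
Flux.@layer TwoHead

# Forward pass: value plus mean-centered advantage, as in the hunk above.
function (m::TwoHead)(x)
    h = m.base(x)
    return m.val(h) .+ m.adv(h) .- mean(m.adv(h), dims=1)
end

m = TwoHead(Dense(4 => 16, relu), Dense(16 => 1), Dense(16 => 3))
y = m(rand(Float32, 4, 8))   # 3×8 matrix of Q-value-like outputs
```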

src/solver.jl (+10 -9)

@@ -1,5 +1,6 @@
 @with_kw mutable struct DeepQLearningSolver{E<:ExplorationPolicy} <: Solver
     qnetwork::Any = nothing # intended to be a flux model
+    exploration_policy::E # No default since 9ac3ab
     learning_rate::Float32 = 1f-4
     max_steps::Int64 = 1000
     batch_size::Int64 = 32

@@ -11,7 +12,6 @@
     dueling::Bool = true
     recurrence::Bool = false
     evaluation_policy::Any = basic_evaluation
-    exploration_policy::E
     trace_length::Int64 = 40
     prioritized_replay::Bool = true
     prioritized_replay_alpha::Float32 = 0.6f0

@@ -139,9 +139,8 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
         sethiddenstates!(active_q, hs)
     end

-    if t%solver.target_update_freq == 0
-        weights = Flux.params(active_q)
-        Flux.loadparams!(target_q, weights)
+    if t % solver.target_update_freq == 0
+        target_q = deepcopy(active_q)
     end

     if t % solver.eval_freq == 0

@@ -170,9 +169,9 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
     if model_saved
         if solver.verbose
             @printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
-            saved_model = BSON.load(joinpath(solver.logdir, "qnetwork.bson"))[:qnetwork]
-            Flux.loadparams!(getnetwork(policy), saved_model)
         end
+        saved_model_state = BSON.load(joinpath(solver.logdir, "qnetwork_state.bson"))[:qnetwork_state]
+        Flux.loadmodel!(policy.qnetwork, saved_model_state)
     end
     return policy
 end

@@ -289,7 +288,9 @@ end

 function save_model(solver::DeepQLearningSolver, active_q, scores_eval::Float64, saved_mean_reward::Float64, model_saved::Bool)
     if scores_eval >= saved_mean_reward
-        bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
+        copied_model = deepcopy(active_q)
+        Flux.reset!(copied_model)
+        bson(joinpath(solver.logdir, "qnetwork_state.bson"), qnetwork_state=Flux.state(copied_model))
         if solver.verbose
             @printf("Saving new model with eval reward %1.3f \n", scores_eval)
         end

@@ -311,8 +312,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
         active_q = solver.qnetwork
     end
     policy = NNPolicy(env, active_q, collect(actions(env)), length(obs_dimensions(env)))
-    weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
-    Flux.loadparams!(getnetwork(policy), weights)
+    saved_network_state = BSON.load(solver.logdir*"qnetwork_state.bson")[:qnetwork_state]
+    Flux.loadmodel!(getnetwork(policy), saved_network_state)
     Flux.testmode!(getnetwork(policy))
     return policy
 end

test/README_examples.jl (+42 -0)

@@ -0,0 +1,42 @@
+using DeepQLearning
+using POMDPs
+using Flux
+using POMDPModels
+using POMDPTools
+
+@testset "README Example 1" begin
+    # load MDP model from POMDPModels or define your own!
+    mdp = SimpleGridWorld();
+
+    # Define the Q network (see Flux.jl documentation)
+    # the gridworld state is represented by a 2 dimensional vector.
+    model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))
+
+    exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+    solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
+                                 exploration_policy = exploration,
+                                 learning_rate=0.005, log_freq=500,
+                                 recurrence=false, double_q=true, dueling=true, prioritized_replay=true)
+    policy = solve(solver, mdp)
+
+    sim = RolloutSimulator(max_steps=30)
+    r_tot = simulate(sim, mdp, policy)
+    println("Total discounted reward for 1 simulation: $r_tot")
+end
+
+@testset "README Example 2" begin
+    # Without using CuArrays
+    mdp = SimpleGridWorld();
+
+    # the model weights will be send to the gpu in the call to solve
+    model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))
+
+    exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+    solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
+                                 exploration_policy=exploration,
+                                 learning_rate=0.005, log_freq=500,
+                                 recurrence=false, double_q=true, dueling=true, prioritized_replay=true)
+    policy = solve(solver, mdp)
+end

test/runtests.jl (+5 -0)

@@ -232,3 +232,8 @@ end

     @test evaluate(env, policy, GLOBAL_RNG) > 1.0
 end
+
+
+@testset "README Examples" begin
+    include("README_examples.jl")
+end
