Commit
Merge pull request #13 from JuliaDiffEq/reverse_mode
Reverse-Mode Neural ODE
ChrisRackauckas authored Jan 25, 2019
2 parents 530d7c8 + 0106d76 commit 4cc55da
Showing 7 changed files with 89 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/DiffEqFlux.jl
@@ -6,5 +6,7 @@ include("Flux/layers.jl")
include("Flux/neural_de.jl")
include("Flux/utils.jl")

-export diffeq_fd, diffeq_rd, diffeq_adjoint, neural_ode, neural_msde
+export diffeq_fd, diffeq_rd, diffeq_adjoint
+export neural_ode, neural_ode_rd
+export neural_dmsde
end
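For orientation, a minimal sketch (not part of the commit) of how the reorganized exports are meant to be called, assuming a Flux Chain model and the Tsit5 solver from OrdinaryDiffEq; neural_ode takes the adjoint path while the new neural_ode_rd differentiates with Tracker's reverse mode:

using Flux, DiffEqFlux, OrdinaryDiffEq

dudt  = Chain(Dense(2,50,tanh),Dense(50,2))
x     = Float32[2.0, 0.0]
tspan = (0.0f0, 1.0f0)

sol_adj = neural_ode(dudt,x,tspan,Tsit5(),saveat=0.1)     # adjoint sensitivities
sol_rd  = neural_ode_rd(dudt,x,tspan,Tsit5(),saveat=0.1)  # Tracker reverse mode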
9 changes: 8 additions & 1 deletion src/Flux/layers.jl
@@ -3,7 +3,13 @@ using Flux.Tracker: @grad
## Reverse-Mode via Flux.jl

function diffeq_rd(p,prob,args...;u0=prob.u0,kwargs...)
-  _prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
+  if DiffEqBase.isinplace(prob)
+    # use Array{TrackedReal} for mutation to work
+    _prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
+  else
+    # use TrackedArray for efficiency of the tape
+    _prob = remake(prob,u0=convert(typeof(p),u0),p=p)
+  end
  solve(_prob,args...;kwargs...)
end
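A hedged sketch (not in the commit) of why the two branches convert u0 differently; the right-hand sides here are illustrative:

using Flux

p  = param([1.5, 1.0])             # tracked parameter vector (TrackedArray)
u0 = [1.0, 1.0]

# In-place RHS mutates du element-by-element, so u0 must be a plain Array
# whose elements are TrackedReal:
f_ip(du,u,p,t) = (du[1] = p[1]*u[1]; du[2] = -p[2]*u[2])
u0_ip = convert.(eltype(p), u0)    # Array of TrackedReal

# Out-of-place RHS returns a fresh array each call, so keeping u0 as a single
# TrackedArray records whole-array operations and keeps the tape small:
f_oop(u,p,t) = [p[1]*u[1], -p[2]*u[2]]
u0_oop = convert(typeof(p), u0)    # TrackedArray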

@@ -13,6 +19,7 @@ function diffeq_fd(p,f,n,prob,args...;u0=prob.u0,kwargs...)
  _prob = remake(prob,u0=convert.(eltype(p),u0),p=p)
  f(solve(_prob,args...;kwargs...))
end

diffeq_fd(p::TrackedVector,args...;kwargs...) = Flux.Tracker.track(diffeq_fd, p, args...; kwargs...)
Flux.Tracker.@grad function diffeq_fd(p::TrackedVector,f,n,prob,args...;u0=prob.u0,kwargs...)
  _f = function (p)
...
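As a usage note (a sketch, not from the commit): diffeq_fd(p,f,n,prob,...) forward-differentiates through the solve, where f reduces the solution and n is the length of f's output; prob and p here are stand-ins for the definitions in the test files below:

timeseries = diffeq_fd(p,sol->sol[1,:],101,prob,Tsit5(),saveat=0.1)
loss = sum(abs2,x-1 for x in timeseries)
Flux.back!(loss)    # gradients accumulate into p via Tracker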
33 changes: 18 additions & 15 deletions src/Flux/neural_de.jl
@@ -8,23 +8,26 @@ function neural_ode(model,x,tspan,
  return diffeq_adjoint(p,prob,args...;kwargs...)
end

+function neural_ode_rd(model,x,tspan,
+                       args...;kwargs...)
+  Tracker.istracked(x) && error("u0 is not currently differentiable.")
+  p = destructure(model)
+  dudt_(u::TrackedArray,p,t) = restructure(model,p)(u)
+  dudt_(u::AbstractArray,p,t) = Flux.data(restructure(model,p)(u))
+  prob = ODEProblem(dudt_,x,tspan,p)
+  return Flux.Tracker.collect(diffeq_rd(p,prob,args...;kwargs...))
+end

-neural_msde(x,model,mp,tspan,args...;kwargs...) = neural_msde(x,model,mp,tspan,
-                                                              diffeq_fd,
-                                                              args...;kwargs...)
-function neural_msde(model,x,mp,tspan,
-                     ad_func::Function,
-                     args...;kwargs...)
-  p = Flux.data(destructure(model))
-  dudt_(du,u::TrackedArray,p,t) = du .= restructure(model,p)(u)
-  dudt_(du,u::AbstractArray,p,t) = du .= Flux.data(restructure(model,p)(u))
-  g(du,u,p,t) = du .= mp.*u
+function neural_dmsde(model,x,mp,tspan,
+                      args...;kwargs...)
+  Tracker.istracked(x) && error("u0 is not currently differentiable.")
+  p = destructure(model)
+  dudt_(u::TrackedArray,p,t) = restructure(model,p)(u)
+  dudt_(u::AbstractArray,p,t) = Flux.data(restructure(model,p)(u))
+  g(u,p,t) = mp.*u
  prob = SDEProblem(dudt_,g,x,tspan,p)

-  if ad_func === diffeq_adjoint
-    return ad_func(p,prob,args...;kwargs...)
-  elseif ad_func === diffeq_fd
-    return ad_func(p,Array,length(p),prob,args...;kwargs...)
-  else
-    return ad_func(p,Array,prob,args...;kwargs...)
-  end
+  return Flux.Tracker.collect(diffeq_rd(p,prob,args...;kwargs...))
end
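To make the renamed constructor concrete, a hedged sketch mirroring the updated tests below (mp holds the diagonal multiplicative-noise magnitudes; the values are illustrative):

using Flux, DiffEqFlux, StochasticDiffEq

dudt  = Chain(Dense(2,50,tanh),Dense(50,2))
x     = Float32[2.0, 0.0]
mp    = Float32[0.1, 0.1]   # diagonal multiplicative noise magnitudes
tspan = (0.0f0, 1.0f0)

# Drift from the network, noise g(u,p,t) = mp.*u, differentiated in
# reverse mode through diffeq_rd:
sol = neural_dmsde(dudt,x,mp,tspan,SOSRI(),saveat=0.1)
Flux.back!(sum(sol))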
1 change: 1 addition & 0 deletions test/REQUIRE
@@ -1,2 +1,3 @@
OrdinaryDiffEq
StochasticDiffEq
+DelayDiffEq
25 changes: 25 additions & 0 deletions test/layers_dde.jl
@@ -0,0 +1,25 @@
using Flux, DiffEqFlux, DelayDiffEq, Plots
using Test    # needed for @test_broken below

## Setup DDE to optimize
function delay_lotka_volterra(du,u,h,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = (α - β*y)*h(p,t-0.1)[1]
  du[2] = dy = (δ*x - γ)*y
end
h(p,t) = ones(eltype(p),2)
prob = DDEProblem(delay_lotka_volterra,[1.0,1.0],h,(0.0,10.0),constant_lags=[0.1])
p = param([2.2, 1.0, 2.0, 0.4])
function predict_fd_dde()
  diffeq_fd(p,sol->sol[1,:],101,prob,MethodOfSteps(Tsit5()),saveat=0.1)
end
loss_fd_dde() = sum(abs2,x-1 for x in predict_fd_dde())
@test_broken loss_fd_dde()
@test_broken Flux.back!(loss_fd_dde())

function predict_rd_dde()
  diffeq_rd(p,prob,MethodOfSteps(Tsit5()),saveat=0.1)[1,:]
end
loss_rd_dde() = sum(abs2,x-1 for x in predict_rd_dde())
loss_rd_dde()
Flux.back!(loss_rd_dde())
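After these smoke tests, a hedged sketch of actually fitting p with plain gradient descent on the reverse-mode loss (step size and iteration count are illustrative):

for i in 1:100
  Flux.back!(loss_rd_dde())
  p.data .-= 0.01 .* Flux.Tracker.grad(p)   # gradient-descent step
  Flux.Tracker.grad(p) .= 0                 # reset the accumulated gradient
end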
27 changes: 27 additions & 0 deletions test/layers_sde.jl
@@ -0,0 +1,27 @@
using Flux, DiffEqFlux, StochasticDiffEq, Plots
using Test    # needed for @test_broken below

function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end
function lotka_volterra_noise(du,u,p,t)
  du[1] = 0.1u[1]
  du[2] = 0.1u[2]
end
prob = SDEProblem(lotka_volterra,lotka_volterra_noise,[1.0,1.0],(0.0,10.0))
p = param([2.2, 1.0, 2.0, 0.4])
function predict_fd_sde()
  diffeq_fd(p,sol->sol[1,:],101,prob,SOSRI(),saveat=0.1)
end
loss_fd_sde() = sum(abs2,x-1 for x in predict_fd_sde())
loss_fd_sde()
Flux.back!(loss_fd_sde())

function predict_rd_sde()
  Array(diffeq_rd(p,prob,SOSRI(),saveat=0.1))
end
loss_rd_sde() = sum(abs2,x-1 for x in predict_rd_sde())
loss_rd_sde()
@test_broken Flux.back!(loss_rd_sde())
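For completeness, a sketch of reading the gradient back out after Flux.back! (Tracker accumulates into p's grad field):

Flux.back!(loss_fd_sde())
@show Flux.Tracker.grad(p)   # d(loss)/dp accumulated by back!
Flux.Tracker.grad(p) .= 0    # reset before the next backward pass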
8 changes: 7 additions & 1 deletion test/neural_de.jl
@@ -6,5 +6,11 @@ dudt = Chain(Dense(2,50,tanh),Dense(50,2))

neural_ode(dudt,x,tspan,Tsit5(),save_everystep=false,save_start=false)
neural_ode(dudt,x,tspan,Tsit5(),saveat=0.1)
+neural_ode_rd(dudt,x,tspan,Tsit5(),saveat=0.1)

-neural_msde(dudt,x,[0.1,0.1],tspan,SOSRI(),saveat=0.1)
+Flux.back!(sum(neural_ode(dudt,x,tspan,Tsit5(),saveat=0.0:0.1:10.0)))
+Flux.back!(sum(neural_ode_rd(dudt,x,tspan,Tsit5(),saveat=0.1)))

+mp = Float32[0.1,0.1]
+neural_dmsde(dudt,x,mp,tspan,SOSRI(),saveat=0.1)
+Flux.back!(sum(neural_dmsde(dudt,x,mp,tspan,SOSRI(),saveat=0.1)))
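Beyond these smoke tests, a hedged sketch of training the neural ODE end-to-end with the same Tracker API; the objective and hyperparameters are illustrative, not from the commit:

predict() = neural_ode(dudt,x,tspan,Tsit5(),saveat=0.1)
loss() = sum(abs2,predict())   # illustrative objective

ps = Flux.params(dudt)
for i in 1:100
  Flux.back!(loss())
  for w in ps
    w.data .-= 0.01 .* Flux.Tracker.grad(w)   # gradient-descent step
    Flux.Tracker.grad(w) .= 0
  end
end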
