Skip to content

Commit

Permalink
Merge pull request #11 from invenia/wct/extend-jvp
Browse files Browse the repository at this point in the history
Implements extensions to jvp and adjoint
  • Loading branch information
willtebbutt authored Apr 15, 2019
2 parents 8d757c6 + 6cb9da3 commit 8cf50b7
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 27 deletions.
63 changes: 59 additions & 4 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,70 @@ end
jacobian(fdm, f, x::Vector{<:Real}) = jacobian(fdm, f, x, length(f(x)))

"""
jvp(fdm, f, x::AbstractVector{<:Real}, ẋ::AbstractVector{<:Real})
_jvp(fdm, f, x::Vector{<:Real}, ẋ::AbstractVector{<:Real})
Convenience function to compute `jacobian(f, x) * ẋ`.
"""
jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = jacobian(fdm, f, x) *
_jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = jacobian(fdm, f, x) *

"""
j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::AbstractVector{<:Real})
_j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::Vector{<:Real})
Convenience function to compute `jacobian(f, x)' * ȳ`.
"""
j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' *
_j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' *

"""
jvp(fdm, f, x, ẋ)
Compute a Jacobian-vector product with any types of arguments for which `to_vec` is defined.
"""
function jvp(fdm, f, (x, ẋ)::Tuple{Any, Any})
x_vec, vec_to_x = to_vec(x)
_, vec_to_y = to_vec(f(x))
return vec_to_y(_jvp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], x_vec, to_vec(ẋ)[1]))
end
function jvp(fdm, f, xẋs::Tuple{Any, Any}...)
x, ẋ = collect(zip(xẋs...))
return jvp(fdm, xs->f(xs...), (x, ẋ))
end

"""
j′vp(fdm, f, ȳ, x...)
Compute an adjoint with any types of arguments for which `to_vec` is defined.
"""
function j′vp(fdm, f, ȳ, x)
x_vec, vec_to_x = to_vec(x)
ȳ_vec, _ = to_vec(ȳ)
return vec_to_x(_j′vp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], ȳ_vec, x_vec))
end
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)

"""
to_vec(x)
Transform `x` into a `Vector`, and return a closure which inverts the transformation.
"""
to_vec(x::Real) = ([x], first)

# Arrays.
to_vec(x::Vector{<:Real}) = (x, identity)
to_vec(x::Array) = vec(x), x_vec->reshape(x_vec, size(x))

# AbstractArrays.
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
x_vec, back = to_vec(Matrix(x))
return x_vec, x_vec->T(reshape(back(x_vec), size(x)))
end
to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x)))
to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...))

# Non-array data structures.
function to_vec(x::Tuple)
x_vecs, x_backs = zip(map(to_vec, x)...)
sz = cumsum([map(length, x_vecs)...])
return vcat(x_vecs...), function(v)
return ntuple(n->x_backs[n](v[sz[n]-length(x[n])+1:sz[n]]), length(x))
end
end
100 changes: 77 additions & 23 deletions test/grad.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,83 @@
using FDM: grad, jacobian, jvp, j′vp
using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec

@testset "grad" begin
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
x = randn(rng, 2)
xc = copy(x)
@test grad(fdm, x->sin(x[1]) + cos(x[2]), x) [cos(x[1]), -sin(x[2])]
@test xc == x
end

function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
xc = copy(x)
@test jacobian(fdm, f, x, length(ȳ)) J_exact
@test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
@test jvp(fdm, f, x, ẋ) J_exact *
@test j′vp(fdm, f, ȳ, x) J_exact' *
@test xc == x
end
@testset "grad" begin
rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
x = randn(rng, 2)
xc = copy(x)
@test grad(fdm, x->sin(x[1]) + cos(x[2]), x) [cos(x[1]), -sin(x[2])]
@test xc == x
end

function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
xc = copy(x)
@test jacobian(fdm, f, x, length(ȳ)) J_exact
@test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
@test _jvp(fdm, f, x, ẋ) J_exact *
@test _j′vp(fdm, f, ȳ, x) J_exact' *
@test xc == x
end

@testset "jacobian / _jvp / _j′vp" begin
rng, P, Q, fdm = MersenneTwister(123456), 3, 2, central_fdm(5, 1)
ȳ, A, x, ẋ = randn(rng, P), randn(rng, P, Q), randn(rng, Q), randn(rng, Q)
Ac = copy(A)

check_jac_and_jvp_and_j′vp(fdm, x->A * x, ȳ, x, ẋ, A)
@test Ac == A
check_jac_and_jvp_and_j′vp(fdm, x->sin.(A * x), ȳ, x, ẋ, cos.(A * x) .* A)
@test Ac == A
end

function test_to_vec(x)
x_vec, back = to_vec(x)
@test x_vec isa Vector
@test x == back(x_vec)
return nothing
end

@testset "to_vec" begin
test_to_vec(1.0)
test_to_vec(1)
test_to_vec(randn(3))
test_to_vec(randn(5, 11))
test_to_vec(randn(13, 17, 19))
test_to_vec(randn(13, 0, 19))
test_to_vec(UpperTriangular(randn(13, 13)))
test_to_vec(Symmetric(randn(11, 11)))
test_to_vec(Diagonal(randn(7)))

@testset "Tuples" begin
test_to_vec((5, 4))
test_to_vec((5, randn(5)))
test_to_vec((randn(4), randn(4, 3, 2), 1))
test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5))
test_to_vec(((6, 5), 3, randn(3, 2, 0, 1)))
end
end

@testset "jacobian / jvp / j′vp" begin
rng, P, Q, fdm = MersenneTwister(123456), 3, 2, central_fdm(5, 1)
ȳ, A, x, ẋ = randn(rng, P), randn(rng, P, Q), randn(rng, Q), randn(rng, Q)
Ac = copy(A)
@testset "jvp" begin
rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
x, y = randn(rng, N), randn(rng, M)
ẋ, ẏ = randn(rng, N), randn(rng, M)
xy, ẋẏ = vcat(x, y), vcat(ẋ, ẏ)
ż_manual = _jvp(fdm, (xy)->sum(sin, xy), xy, ẋẏ)[1]
ż_auto = jvp(fdm, x->sum(sin, x[1]) + sum(sin, x[2]), ((x, y), (ẋ, ẏ)))
ż_multi = jvp(fdm, (x, y)->sum(sin, x) + sum(sin, y), (x, ẋ), (y, ẏ))
@test ż_manual ż_auto
@test ż_manual ż_multi
end

check_jac_and_jvp_and_j′vp(fdm, x->A * x, ȳ, x, ẋ, A)
@test Ac == A
check_jac_and_jvp_and_j′vp(fdm, x->sin.(A * x), ȳ, x, ẋ, cos.(A * x) .* A)
@test Ac == A
@testset "j′vp" begin
rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
x, y = randn(rng, N), randn(rng, M)
= randn(rng, N + M)
xy = vcat(x, y)
x̄ȳ_manual = j′vp(fdm, xy->sin.(xy), z̄, xy)
x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))
x̄ȳ_multi = j′vp(fdm, (x, y)->sin.(vcat(x, y)), z̄, x, y)
@test x̄ȳ_manual vcat(x̄ȳ_auto...)
@test x̄ȳ_manual vcat(x̄ȳ_multi...)
end
end

0 comments on commit 8cf50b7

Please sign in to comment.