diff --git a/src/grad.jl b/src/grad.jl
index 887b33ae..6216cba4 100644
--- a/src/grad.jl
+++ b/src/grad.jl
@@ -35,15 +35,70 @@ end
 jacobian(fdm, f, x::Vector{<:Real}) = jacobian(fdm, f, x, length(f(x)))
 
 """
-    jvp(fdm, f, x::AbstractVector{<:Real}, ẋ::AbstractVector{<:Real})
+    _jvp(fdm, f, x::Vector{<:Real}, ẋ::AbstractVector{<:Real})
 
 Convenience function to compute `jacobian(f, x) * ẋ`.
 """
-jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = jacobian(fdm, f, x) * ẋ
+_jvp(fdm, f, x::Vector{<:Real}, ẋ::AV{<:Real}) = jacobian(fdm, f, x) * ẋ
 
 """
-    j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::AbstractVector{<:Real})
+    _j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::Vector{<:Real})
 
 Convenience function to compute `jacobian(f, x)' * ȳ`.
 """
-j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' * ȳ
+_j′vp(fdm, f, ȳ::AV{<:Real}, x::Vector{<:Real}) = jacobian(fdm, f, x, length(ȳ))' * ȳ
+
+"""
+    jvp(fdm, f, xẋs::Tuple{Any, Any}...)
+
+Compute a Jacobian-vector product of `f` from `(x, ẋ)` pairs of any `to_vec`-able types.
+"""
+function jvp(fdm, f, (x, ẋ)::Tuple{Any, Any})
+    x_vec, vec_to_x = to_vec(x)
+    _, vec_to_y = to_vec(f(x))
+    return vec_to_y(_jvp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], x_vec, to_vec(ẋ)[1]))
+end
+function jvp(fdm, f, xẋs::Tuple{Any, Any}...)
+    x, ẋ = map(first, xẋs), map(last, xẋs)
+    return jvp(fdm, xs->f(xs...), (x, ẋ))
+end
+
+"""
+    j′vp(fdm, f, ȳ, x...)
+
+Compute an adjoint with any types of arguments for which `to_vec` is defined.
+"""
+function j′vp(fdm, f, ȳ, x)
+    x_vec, vec_to_x = to_vec(x)
+    ȳ_vec, _ = to_vec(ȳ)
+    return vec_to_x(_j′vp(fdm, x_vec->to_vec(f(vec_to_x(x_vec)))[1], ȳ_vec, x_vec))
+end
+j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)
+
+"""
+    to_vec(x)
+
+Transform `x` into a `Vector`, and return a closure which inverts the transformation.
+"""
+to_vec(x::Real) = ([x], first)
+
+# Arrays.
+to_vec(x::Vector{<:Real}) = (x, identity)
+to_vec(x::Array) = vec(x), x_vec->reshape(x_vec, size(x))
+
+# AbstractArrays.
+function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
+    x_vec, back = to_vec(Matrix(x))
+    return x_vec, x_vec->T(reshape(back(x_vec), size(x)))
+end
+to_vec(x::Symmetric) = vec(Matrix(x)), x_vec->Symmetric(reshape(x_vec, size(x)))
+to_vec(X::Diagonal) = vec(Matrix(X)), x_vec->Diagonal(reshape(x_vec, size(X)...))
+
+# Non-array data structures. Slice by flattened lengths; `length(x[n])` may be smaller.
+function to_vec(x::Tuple)
+    x_vecs, x_backs = zip(map(to_vec, x)...)
+    sz = cumsum([map(length, x_vecs)...])
+    return vcat(x_vecs...), function(v)
+        return ntuple(n->x_backs[n](v[sz[n]-length(x_vecs[n])+1:sz[n]]), length(x))
+    end
+end
diff --git a/test/grad.jl b/test/grad.jl
index 8dffd8b8..ae7c8d73 100644
--- a/test/grad.jl
+++ b/test/grad.jl
@@ -1,29 +1,84 @@
-using FDM: grad, jacobian, jvp, j′vp
+using FDM: grad, jacobian, _jvp, _j′vp, jvp, j′vp, to_vec
 
 @testset "grad" begin
-    rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
-    x = randn(rng, 2)
-    xc = copy(x)
-    @test grad(fdm, x->sin(x[1]) + cos(x[2]), x) ≈ [cos(x[1]), -sin(x[2])]
-    @test xc == x
-end
 
-function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
-    xc = copy(x)
-    @test jacobian(fdm, f, x, length(ȳ)) ≈ J_exact
-    @test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
-    @test jvp(fdm, f, x, ẋ) ≈ J_exact * ẋ
-    @test j′vp(fdm, f, ȳ, x) ≈ J_exact' * ȳ
-    @test xc == x
-end
+    @testset "grad" begin
+        rng, fdm = MersenneTwister(123456), central_fdm(5, 1)
+        x = randn(rng, 2)
+        xc = copy(x)
+        @test grad(fdm, x->sin(x[1]) + cos(x[2]), x) ≈ [cos(x[1]), -sin(x[2])]
+        @test xc == x
+    end
+
+    function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
+        xc = copy(x)
+        @test jacobian(fdm, f, x, length(ȳ)) ≈ J_exact
+        @test jacobian(fdm, f, x) == jacobian(fdm, f, x, length(ȳ))
+        @test _jvp(fdm, f, x, ẋ) ≈ J_exact * ẋ
+        @test _j′vp(fdm, f, ȳ, x) ≈ J_exact' * ȳ
+        @test xc == x
+    end
+
+    @testset "jacobian / _jvp / _j′vp" begin
+        rng, P, Q, fdm = MersenneTwister(123456), 3, 2, central_fdm(5, 1)
+        ȳ, A, x, ẋ = randn(rng, P), randn(rng, P, Q), randn(rng, Q), randn(rng, Q)
+        Ac = copy(A)
+
+        check_jac_and_jvp_and_j′vp(fdm, x->A * x, ȳ, x, ẋ, A)
+        @test Ac == A
+        check_jac_and_jvp_and_j′vp(fdm, x->sin.(A * x), ȳ, x, ẋ, cos.(A * x) .* A)
+        @test Ac == A
+    end
+
+    function test_to_vec(x)
+        x_vec, back = to_vec(x)
+        @test x_vec isa Vector
+        @test x == back(x_vec)
+        return nothing
+    end
+
+    @testset "to_vec" begin
+        test_to_vec(1.0)
+        test_to_vec(1)
+        test_to_vec(randn(3))
+        test_to_vec(randn(5, 11))
+        test_to_vec(randn(13, 17, 19))
+        test_to_vec(randn(13, 0, 19))
+        test_to_vec(UpperTriangular(randn(13, 13)))
+        test_to_vec(Symmetric(randn(11, 11)))
+        test_to_vec(Diagonal(randn(7)))
+
+        @testset "Tuples" begin
+            test_to_vec((5, 4))
+            test_to_vec((5, randn(5)))
+            test_to_vec((randn(4), randn(4, 3, 2), 1))
+            test_to_vec((5, randn(4, 3, 2), UpperTriangular(randn(4, 4)), 2.5))
+            test_to_vec(((6, 5), 3, randn(3, 2, 0, 1)))
+            test_to_vec(((6, randn(3)), 3, randn(2, 2)))
+        end
+    end
 
-@testset "jacobian / jvp / j′vp" begin
-    rng, P, Q, fdm = MersenneTwister(123456), 3, 2, central_fdm(5, 1)
-    ȳ, A, x, ẋ = randn(rng, P), randn(rng, P, Q), randn(rng, Q), randn(rng, Q)
-    Ac = copy(A)
+    @testset "jvp" begin
+        rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
+        x, y = randn(rng, N), randn(rng, M)
+        ẋ, ẏ = randn(rng, N), randn(rng, M)
+        xy, ẋẏ = vcat(x, y), vcat(ẋ, ẏ)
+        ż_manual = _jvp(fdm, (xy)->sum(sin, xy), xy, ẋẏ)[1]
+        ż_auto = jvp(fdm, x->sum(sin, x[1]) + sum(sin, x[2]), ((x, y), (ẋ, ẏ)))
+        ż_multi = jvp(fdm, (x, y)->sum(sin, x) + sum(sin, y), (x, ẋ), (y, ẏ))
+        @test ż_manual ≈ ż_auto
+        @test ż_manual ≈ ż_multi
+    end
 
-    check_jac_and_jvp_and_j′vp(fdm, x->A * x, ȳ, x, ẋ, A)
-    @test Ac == A
-    check_jac_and_jvp_and_j′vp(fdm, x->sin.(A * x), ȳ, x, ẋ, cos.(A * x) .* A)
-    @test Ac == A
+    @testset "j′vp" begin
+        rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1)
+        x, y = randn(rng, N), randn(rng, M)
+        z̄ = randn(rng, N + M)
+        xy = vcat(x, y)
+        x̄ȳ_manual = _j′vp(fdm, xy->sin.(xy), z̄, xy)
+        x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))
+        x̄ȳ_multi = j′vp(fdm, (x, y)->sin.(vcat(x, y)), z̄, x, y)
+        @test x̄ȳ_manual ≈ vcat(x̄ȳ_auto...)
+        @test x̄ȳ_manual ≈ vcat(x̄ȳ_multi...)
+    end
 end