diff --git a/src/dual.jl b/src/dual.jl index 7e8ec110..180dc1b7 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -88,6 +88,7 @@ Dual{T,V,N}(x::Base.TwicePrecision) where {T,V,N} = @inline value(x) = x @inline value(d::Dual) = d.value +@inline value(d::Complex{<:Dual}) = complex(value(real(d)), value(imag(d))) @inline value(::Type{T}, x) where T = x @inline value(::Type{T}, d::Dual{T}) where T = value(d) @@ -101,6 +102,8 @@ end @inline partials(x) = Partials{0,typeof(x)}(tuple()) @inline partials(d::Dual) = d.partials +@inline partials(d::Complex{<:Dual}, i) = complex(partials(real(d), i), partials(imag(d), i)) +@inline partials(d::Complex{<:Dual}) = Partials(complex.(partials(real(d)).values, partials(imag(d)).values)) @inline partials(x, i...) = zero(x) @inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i] @inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j] @@ -119,6 +122,8 @@ end @inline npartials(::Dual{T,V,N}) where {T,V,N} = N @inline npartials(::Type{Dual{T,V,N}}) where {T,V,N} = N +@inline npartials(::Complex{<:Dual{T,V,N}}) where {T,V,N} = N +@inline npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N @inline order(::Type{V}) where {V} = 0 @inline order(::Type{Dual{T,V,N}}) where {T,V,N} = 1 + order(V) diff --git a/test/DualTest.jl b/test/DualTest.jl index 6d0bf85f..528d60b4 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -4,7 +4,7 @@ using Test using Printf using Random using ForwardDiff -using ForwardDiff: Partials, Dual, value, partials +using ForwardDiff: Partials, Dual, value, partials, npartials using NaNMath, SpecialFunctions, LogExpFunctions using DiffRules @@ -673,4 +673,14 @@ end @test ForwardDiff.derivative(x -> sum(1 .+ x .* (0:0.1:1)), 1) == 5.5 end +@testset "Complex value/partials" begin + x = Dual(1,2,3) + im*Dual(4,5,6) + @test value(x) == 1+im*4 + @test partials(x,1) == 2 + 5im + @test partials(x,2) == 3 + 6im + @test partials(x) == [2+5im,3+6im] + @test partials(x) isa Partials + @test npartials(x) == npartials(typeof(x)) == 2 +end + end # module