diff --git a/src/dual.jl b/src/dual.jl index 2fb3e62b..bb4df73f 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -2,9 +2,22 @@ # Dual # ######## -struct Dual{T,V<:Real,N} <: Real +""" + ForwardDiff.can_dual(V::Type) + +Determines whether the type V is allowed as the scalar type in a +Dual. By default, only `<:Real` types are allowed. +""" +can_dual(::Type{<:Real}) = true +can_dual(::Type) = false + +struct Dual{T,V,N} <: Real value::V partials::Partials{N,V} + function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N} + can_dual(V) || throw_cannot_dual(V) + new{T, V, N}(value, partials) + end end ############## @@ -19,6 +32,11 @@ end Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} = print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)") +@noinline function throw_cannot_dual(V::Type) + throw(ArgumentError("Cannot create a dual over scalar type $V." * + " If the type behaves as a scalar, define FowardDiff.can_dual.")) +end + """ ForwardDiff.≺(a, b)::Bool @@ -41,38 +59,42 @@ tag can be extracted, so it should be used in the _innermost_ function. return Dual{T}(convert(C, value), convert(Partials{N,C}, partials)) end -@inline Dual{T}(value::Real, partials::Tuple) where {T} = Dual{T}(value, Partials(partials)) -@inline Dual{T}(value::Real, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials)) -@inline Dual{T}(value::Real, partials::Real...) where {T} = Dual{T}(value, partials) -@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V<:Real,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p)) +@inline Dual{T}(value, partials::Tuple) where {T} = Dual{T}(value, Partials(partials)) +@inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials)) +@inline Dual{T}(value) where {T} = Dual{T}(value, ()) +@inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ()) +@inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...)) +@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p)) @inline Dual(args...) = Dual{Nothing}(args...) # we define these special cases so that the "constructor <--> convert" pun holds for `Dual` -@inline Dual{T,V,N}(x::Real) where {T,V,N} = convert(Dual{T,V,N}, x) -@inline Dual{T,V}(x::Real) where {T,V} = convert(Dual{T,V}, x) +@inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x +@inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x) +@inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x) +@inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x) ############################## # Utility/Accessor Functions # ############################## -@inline value(x::Real) = x +@inline value(x) = x @inline value(d::Dual) = d.value -@inline value(::Type{T}, x::Real) where T = x +@inline value(::Type{T}, x) where T = x @inline value(::Type{T}, d::Dual{T}) where T = value(d) function value(::Type{T}, d::Dual{S}) where {T,S} # TODO: in the case of nested Duals, it may be possible to "transpose" the Dual objects throw(DualMismatchError(T,S)) end -@inline partials(x::Real) = Partials{0,typeof(x)}(tuple()) +@inline partials(x) = Partials{0,typeof(x)}(tuple()) @inline partials(d::Dual) = d.partials -@inline partials(x::Real, i...) = zero(x) +@inline partials(x, i...) = zero(x) @inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i] @inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j] @inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...) -@inline Base.@propagate_inbounds partials(::Type{T}, x::Real, i...) where T = partials(x, i...) +@inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...) @inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...) partials(::Type{T}, d::Dual{S}, i...) where {T,S} = throw(DualMismatchError(T,S)) @@ -289,7 +311,7 @@ end ######################## Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}}, - ::Type{Dual{T2,V2,N2}}) where {T1,V1<:Real,N1,T2,V2<:Real,N2} + ::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2} # V1 and V2 might themselves be Dual types if T2 ≺ T1 Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1} @@ -299,26 +321,27 @@ Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}}, end function Base.promote_rule(::Type{Dual{T,A,N}}, - ::Type{Dual{T,B,N}}) where {T,A<:Real,B<:Real,N} + ::Type{Dual{T,B,N}}) where {T,A,B,N} return Dual{T,promote_type(A, B),N} end for R in (Irrational, Real, BigFloat, Bool) if isconcretetype(R) # issue #322 @eval begin - Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V<:Real,N} = Dual{T,promote_type($R, V),N} - Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V<:Real,N} = Dual{T,promote_type(V, $R),N} + Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N} + Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N} end else @eval begin - Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V<:Real,N} = Dual{T,promote_type(R, V),N} - Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V<:Real,N,R<:$R} = Dual{T,promote_type(V, R),N} + Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N} + Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N} end end end -Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V<:Real,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d))) -Base.convert(::Type{Dual{T,V,N}}, x::Real) where {T,V<:Real,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) +Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d))) +Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) +Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V})) Base.convert(::Type{D}, d::D) where {D<:Dual} = d Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d) @@ -468,9 +491,9 @@ end # fma # #-----# -@generated function calc_fma_xyz(x::Dual{T,<:Real,N}, - y::Dual{T,<:Real,N}, - z::Dual{T,<:Real,N}) where {T,N} +@generated function calc_fma_xyz(x::Dual{T,<:Any,N}, + y::Dual{T,<:Any,N}, + z::Dual{T,<:Any,N}) where {T,N} ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) return quote $(Expr(:meta, :inline)) @@ -485,9 +508,9 @@ end return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx)) end -@generated function calc_fma_xz(x::Dual{T,<:Real,N}, +@generated function calc_fma_xz(x::Dual{T,<:Any,N}, y::Real, - z::Dual{T,<:Real,N}) where {T,N} + z::Dual{T,<:Any,N}) where {T,N} ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) return quote $(Expr(:meta, :inline)) @@ -510,9 +533,9 @@ end # muladd # #--------# -@generated function calc_muladd_xyz(x::Dual{T,<:Real,N}, - y::Dual{T,<:Real,N}, - z::Dual{T,<:Real,N}) where {T,N} +@generated function calc_muladd_xyz(x::Dual{T,<:Any,N}, + y::Dual{T,<:Any,N}, + z::Dual{T,<:Any,N}) where {T,N} ex = Expr(:tuple, [:(muladd(value(x), partials(y)[$i], muladd(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...) return quote $(Expr(:meta, :inline)) @@ -527,9 +550,9 @@ end return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx)) end -@generated function calc_muladd_xz(x::Dual{T,<:Real,N}, +@generated function calc_muladd_xz(x::Dual{T,<:Any,N}, y::Real, - z::Dual{T,<:Real,N}) where {T,N} + z::Dual{T,<:Any,N}) where {T,N} ex = Expr(:tuple, [:(muladd(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...) return quote $(Expr(:meta, :inline))