Skip to content

Commit

Permalink
Remove V <: Real type restriction
Browse files Browse the repository at this point in the history
Instead, use an extensible function that the constructor uses to
check whether a type is valid to be used as `Dual`'s scalar type.

Fixes JuliaDiff#216
  • Loading branch information
Keno committed Oct 15, 2018
1 parent e1a129b commit 13d231a
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@
# Dual #
########

struct Dual{T,V<:Real,N} <: Real
"""
ForwardDiff.can_dual(V::Type)
Determines whether the type V is allowed as the scalar type in a
Dual. By default, only `<:Real` types are allowed.
"""
can_dual(::Type{<:Real}) = true
can_dual(::Type) = false

struct Dual{T,V,N} <: Real
value::V
partials::Partials{N,V}
function Dual{T, V, N}(value::V, partials::Partials{N, V}) where {T, V, N}
can_dual(V) || throw_cannot_dual(V)
new{T, V, N}(value, partials)
end
end

##############
Expand All @@ -19,6 +32,11 @@ end
Base.showerror(io::IO, e::DualMismatchError{A,B}) where {A,B} =
print(io, "Cannot determine ordering of Dual tags $(e.a) and $(e.b)")

@noinline function throw_cannot_dual(V::Type)
throw(ArgumentError("Cannot create a dual over scalar type $V." *
" If the type behaves as a scalar, define FowardDiff.can_dual."))
end

"""
ForwardDiff.≺(a, b)::Bool
Expand All @@ -41,38 +59,42 @@ tag can be extracted, so it should be used in the _innermost_ function.
return Dual{T}(convert(C, value), convert(Partials{N,C}, partials))
end

@inline Dual{T}(value::Real, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
@inline Dual{T}(value::Real, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
@inline Dual{T}(value::Real, partials::Real...) where {T} = Dual{T}(value, partials)
@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V<:Real,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
@inline Dual{T}(value, partials::Tuple) where {T} = Dual{T}(value, Partials(partials))
@inline Dual{T}(value, partials::Tuple{}) where {T} = Dual{T}(value, Partials{0,typeof(value)}(partials))
@inline Dual{T}(value) where {T} = Dual{T}(value, ())
@inline Dual{T}(x::Dual{T}) where {T} = Dual{T}(x, ())
@inline Dual{T}(value, partial1, partials...) where {T} = Dual{T}(value, tuple(partial1, partials...))
@inline Dual{T}(value::V, ::Chunk{N}, p::Val{i}) where {T,V,N,i} = Dual{T}(value, single_seed(Partials{N,V}, p))
@inline Dual(args...) = Dual{Nothing}(args...)

# we define these special cases so that the "constructor <--> convert" pun holds for `Dual`
@inline Dual{T,V,N}(x::Real) where {T,V,N} = convert(Dual{T,V,N}, x)
@inline Dual{T,V}(x::Real) where {T,V} = convert(Dual{T,V}, x)
@inline Dual{T,V,N}(x::Dual{T,V,N}) where {T,V,N} = x
@inline Dual{T,V,N}(x) where {T,V,N} = convert(Dual{T,V,N}, x)
@inline Dual{T,V,N}(x::Number) where {T,V,N} = convert(Dual{T,V,N}, x)
@inline Dual{T,V}(x) where {T,V} = convert(Dual{T,V}, x)

##############################
# Utility/Accessor Functions #
##############################

@inline value(x::Real) = x
@inline value(x) = x
@inline value(d::Dual) = d.value

@inline value(::Type{T}, x::Real) where T = x
@inline value(::Type{T}, x) where T = x
@inline value(::Type{T}, d::Dual{T}) where T = value(d)
function value(::Type{T}, d::Dual{S}) where {T,S}
# TODO: in the case of nested Duals, it may be possible to "transpose" the Dual objects
throw(DualMismatchError(T,S))
end

@inline partials(x::Real) = Partials{0,typeof(x)}(tuple())
@inline partials(x) = Partials{0,typeof(x)}(tuple())
@inline partials(d::Dual) = d.partials
@inline partials(x::Real, i...) = zero(x)
@inline partials(x, i...) = zero(x)
@inline Base.@propagate_inbounds partials(d::Dual, i) = d.partials[i]
@inline Base.@propagate_inbounds partials(d::Dual, i, j) = partials(d, i).partials[j]
@inline Base.@propagate_inbounds partials(d::Dual, i, j, k...) = partials(partials(d, i, j), k...)

@inline Base.@propagate_inbounds partials(::Type{T}, x::Real, i...) where T = partials(x, i...)
@inline Base.@propagate_inbounds partials(::Type{T}, x, i...) where T = partials(x, i...)
@inline Base.@propagate_inbounds partials(::Type{T}, d::Dual{T}, i...) where T = partials(d, i...)
partials(::Type{T}, d::Dual{S}, i...) where {T,S} = throw(DualMismatchError(T,S))

Expand Down Expand Up @@ -289,7 +311,7 @@ end
########################

Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
::Type{Dual{T2,V2,N2}}) where {T1,V1<:Real,N1,T2,V2<:Real,N2}
::Type{Dual{T2,V2,N2}}) where {T1,V1,N1,T2,V2,N2}
# V1 and V2 might themselves be Dual types
if T2 T1
Dual{T1,promote_type(V1,Dual{T2,V2,N2}),N1}
Expand All @@ -299,26 +321,27 @@ Base.@pure function Base.promote_rule(::Type{Dual{T1,V1,N1}},
end

function Base.promote_rule(::Type{Dual{T,A,N}},
::Type{Dual{T,B,N}}) where {T,A<:Real,B<:Real,N}
::Type{Dual{T,B,N}}) where {T,A,B,N}
return Dual{T,promote_type(A, B),N}
end

for R in (Irrational, Real, BigFloat, Bool)
if isconcretetype(R) # issue #322
@eval begin
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V<:Real,N} = Dual{T,promote_type($R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V<:Real,N} = Dual{T,promote_type(V, $R),N}
Base.promote_rule(::Type{$R}, ::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,promote_type($R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{$R}) where {T,V,N} = Dual{T,promote_type(V, $R),N}
end
else
@eval begin
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V<:Real,N} = Dual{T,promote_type(R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V<:Real,N,R<:$R} = Dual{T,promote_type(V, R),N}
Base.promote_rule(::Type{R}, ::Type{Dual{T,V,N}}) where {R<:$R,T,V,N} = Dual{T,promote_type(R, V),N}
Base.promote_rule(::Type{Dual{T,V,N}}, ::Type{R}) where {T,V,N,R<:$R} = Dual{T,promote_type(V, R),N}
end
end
end

Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V<:Real,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d)))
Base.convert(::Type{Dual{T,V,N}}, x::Real) where {T,V<:Real,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Base.convert(::Type{Dual{T,V,N}}, d::Dual{T}) where {T,V,N} = Dual{T}(convert(V, value(d)), convert(Partials{N,V}, partials(d)))
Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
Base.convert(::Type{D}, d::D) where {D<:Dual} = d

Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
Expand Down Expand Up @@ -468,9 +491,9 @@ end
# fma #
#-----#

@generated function calc_fma_xyz(x::Dual{T,<:Real,N},
y::Dual{T,<:Real,N},
z::Dual{T,<:Real,N}) where {T,N}
@generated function calc_fma_xyz(x::Dual{T,<:Any,N},
y::Dual{T,<:Any,N},
z::Dual{T,<:Any,N}) where {T,N}
ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
return quote
$(Expr(:meta, :inline))
Expand All @@ -485,9 +508,9 @@ end
return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
end

@generated function calc_fma_xz(x::Dual{T,<:Real,N},
@generated function calc_fma_xz(x::Dual{T,<:Any,N},
y::Real,
z::Dual{T,<:Real,N}) where {T,N}
z::Dual{T,<:Any,N}) where {T,N}
ex = Expr(:tuple, [:(fma(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
return quote
$(Expr(:meta, :inline))
Expand All @@ -510,9 +533,9 @@ end
# muladd #
#--------#

@generated function calc_muladd_xyz(x::Dual{T,<:Real,N},
y::Dual{T,<:Real,N},
z::Dual{T,<:Real,N}) where {T,N}
@generated function calc_muladd_xyz(x::Dual{T,<:Any,N},
y::Dual{T,<:Any,N},
z::Dual{T,<:Any,N}) where {T,N}
ex = Expr(:tuple, [:(muladd(value(x), partials(y)[$i], muladd(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
return quote
$(Expr(:meta, :inline))
Expand All @@ -527,9 +550,9 @@ end
return Dual{T}(result, _mul_partials(partials(x), partials(y), vy, vx))
end

@generated function calc_muladd_xz(x::Dual{T,<:Real,N},
@generated function calc_muladd_xz(x::Dual{T,<:Any,N},
y::Real,
z::Dual{T,<:Real,N}) where {T,N}
z::Dual{T,<:Any,N}) where {T,N}
ex = Expr(:tuple, [:(muladd(partials(x)[$i], y, partials(z)[$i])) for i in 1:N]...)
return quote
$(Expr(:meta, :inline))
Expand Down

0 comments on commit 13d231a

Please sign in to comment.