diff --git a/Project.toml b/Project.toml index 1f12c7ba..436885a7 100644 --- a/Project.toml +++ b/Project.toml @@ -14,12 +14,15 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] ForwardDiffStaticArraysExt = "StaticArrays" +ForwardDiffUnitfulExt = "Unitful" [compat] Calculus = "0.5" @@ -32,6 +35,7 @@ NaNMath = "1" Preferences = "1" SpecialFunctions = "1, 2" StaticArrays = "1.5" +Unitful = "1" julia = "1.6" [extras] @@ -41,6 +45,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Calculus", "DiffTests", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "SparseArrays", "StaticArrays", "Test", "Unitful", "InteractiveUtils"] diff --git a/ext/ForwardDiffUnitfulExt.jl b/ext/ForwardDiffUnitfulExt.jl new file mode 100644 index 00000000..8c0cdf16 --- /dev/null +++ b/ext/ForwardDiffUnitfulExt.jl @@ -0,0 +1,21 @@ +module ForwardDiffUnitfulExt + +import ForwardDiff: value, extract_derivative, derivative +using ForwardDiff: Dual, Tag +using Unitful: ustrip, unit, Quantity + +@inline function value(::Type{T}, d::Quantity{TD}) where {T, TD <: Dual} + value(T, ustrip(d)) * unit(d) +end + +@inline function extract_derivative(::Type{T}, d::Quantity{TD}) where {T, TD <: Dual} + extract_derivative(T, ustrip(d)) * unit(d) +end + +@inline function derivative(f::F, x::Quantity{R}) where {F,R<:Real} + T = typeof(Tag(f, R)) + ydual = f(Dual{T}(ustrip(x), one(x)) * unit(x)) + return extract_derivative(T, ydual) / unit(x) +end + +end diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index fdfcd560..890a46de 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -24,6 +24,7 @@ include("hessian.jl") if !isdefined(Base, :get_extension) include("../ext/ForwardDiffStaticArraysExt.jl") + include("../ext/ForwardDiffUnitfulExt.jl") end export DiffResults diff --git a/test/DerivativeTest.jl b/test/DerivativeTest.jl index dfdd8ed2..29edd7f8 100644 --- a/test/DerivativeTest.jl +++ b/test/DerivativeTest.jl @@ -6,6 +6,7 @@ using Test using Random using ForwardDiff using DiffTests +using Unitful include(joinpath(dirname(@__FILE__), "utils.jl")) @@ -104,4 +105,10 @@ end @test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im) end +@testset "Unitful" begin + for x in [42, 42u"m"] + @test isapprox(ForwardDiff.derivative(x -> 3.14u"m"*x, x), 3.14u"m") + end +end + end # module