Skip to content

Commit

Permalink
fix NaNMath exponentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Nov 1, 2024
1 parent 7e9d778 commit f473514
Showing 1 changed file with 66 additions and 36 deletions.
102 changes: 66 additions & 36 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,43 +552,73 @@ end
# exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
@eval begin
@define_binary_dual_op(
$f,
begin
vx, vy = value(x), value(y)
expv = ($f)(vx, vy)
powval = vy * ($f)(vx, vy - 1)
if isconstant(y)
logval = one(expv)
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end,
begin
v = value(x)
expv = ($f)(v, y)
if y == zero(y) || iszero(partials(x))
new_partials = zero(partials(x))
else
new_partials = partials(x) * y * ($f)(v, y - 1)
end
return Dual{Tx}(expv, new_partials)
end,
begin
v = value(y)
expv = ($f)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
return Dual{Ty}(expv, deriv * partials(y))
end
)
@define_binary_dual_op(
Base.:^,
begin
vx, vy = value(x), value(y)
expv = (^)(vx, vy)
powval = vy * (^)(vx, vy - 1)
if isconstant(y)
logval = one(expv)
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end,
begin
v = value(x)
expv = (^)(v, y)
if y == zero(y) || iszero(partials(x))
new_partials = zero(partials(x))
else
new_partials = partials(x) * y * (^)(v, y - 1)
end
return Dual{Tx}(expv, new_partials)
end,
begin
v = value(y)
expv = (^)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv * log(x)
return Dual{Ty}(expv, deriv * partials(y))
end
end
)

@define_binary_dual_op(
NaNMath.pow,
begin
vx, vy = value(x), value(y)
expv = NaNMath.pow(vx, vy)
powval = vy * NaNMath.pow(vx, vy - 1)
if isconstant(y)
logval = one(expv)
elseif iszero(vx) && vy > 0
logval = zero(vx)

Check warning on line 598 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L597-L598

Added lines #L597 - L598 were not covered by tests
else
logval = expv * NaNMath.log(vx)

Check warning on line 600 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L600

Added line #L600 was not covered by tests
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end,
begin
v = value(x)
expv = NaNMath.pow(v, y)
if y == zero(y) || iszero(partials(x))
new_partials = zero(partials(x))

Check warning on line 609 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L609

Added line #L609 was not covered by tests
else
new_partials = partials(x) * y * NaNMath.pow(v, y - 1)
end
return Dual{Tx}(expv, new_partials)
end,
begin
v = value(y)
expv = NaNMath.pow(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*NaNMath.log(x)
return Dual{Ty}(expv, deriv * partials(y))
end
)

@inline Base.literal_pow(::typeof(^), x::Dual{T}, ::Val{0}) where {T} =
Dual{T}(one(value(x)), zero(partials(x)))
Expand Down

0 comments on commit f473514

Please sign in to comment.