Skip to content

Commit

Permalink
Check isdiag in dense trig functions (#56483)
Browse files Browse the repository at this point in the history
This improves performance for dense diagonal matrices, as we may apply
the function only to the diagonal elements.
```julia
julia> A = diagm(0=>rand(100));

julia> @Btime cos($A);
  349.211 μs (22 allocations: 401.58 KiB) # nightly v"1.12.0-DEV.1571"
  16.215 μs (7 allocations: 80.02 KiB) # this PR
```

---------

Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
  • Loading branch information
jishnub and dkarrasch authored Nov 10, 2024
1 parent 88201cf commit b6a2cc1
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 34 deletions.
72 changes: 56 additions & 16 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,12 @@ Base.:^(::Irrational{:ℯ}, A::AbstractMatrix) = exp(A)
## "Functions of Matrices: Theory and Computation", SIAM
function exp!(A::StridedMatrix{T}) where T<:BlasFloat
n = checksquare(A)
if ishermitian(A)
if isdiag(A)
for i in diagind(A, IndexStyle(A))
A[i] = exp(A[i])
end
return A
elseif ishermitian(A)
return copytri!(parent(exp(Hermitian(A))), 'U', true)
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
Expand Down Expand Up @@ -1014,9 +1019,16 @@ end
cbrt(A::AdjointAbsMat) = adjoint(cbrt(parent(A)))
cbrt(A::TransposeAbsMat) = transpose(cbrt(parent(A)))

function applydiagonal(f, A)
dinv = f(Diagonal(A))
copyto!(similar(A, eltype(dinv)), dinv)
end

function inv(A::StridedMatrix{T}) where T
checksquare(A)
if istriu(A)
if isdiag(A)
Ai = applydiagonal(inv, A)
elseif istriu(A)
Ai = triu!(parent(inv(UpperTriangular(A))))
elseif istril(A)
Ai = tril!(parent(inv(LowerTriangular(A))))
Expand Down Expand Up @@ -1044,14 +1056,18 @@ julia> cos(fill(1.0, (2,2)))
```
"""
function cos(A::AbstractMatrix{<:Real})
if issymmetric(A)
if isdiag(A)
return applydiagonal(cos, A)
elseif issymmetric(A)
return copytri!(parent(cos(Symmetric(A))), 'U')
end
T = complex(float(eltype(A)))
return real(exp!(T.(im .* A)))
end
function cos(A::AbstractMatrix{<:Complex})
if ishermitian(A)
if isdiag(A)
return applydiagonal(cos, A)
elseif ishermitian(A)
return copytri!(parent(cos(Hermitian(A))), 'U', true)
end
T = complex(float(eltype(A)))
Expand All @@ -1077,14 +1093,18 @@ julia> sin(fill(1.0, (2,2)))
```
"""
function sin(A::AbstractMatrix{<:Real})
if issymmetric(A)
if isdiag(A)
return applydiagonal(sin, A)
elseif issymmetric(A)
return copytri!(parent(sin(Symmetric(A))), 'U')
end
T = complex(float(eltype(A)))
return imag(exp!(T.(im .* A)))
end
function sin(A::AbstractMatrix{<:Complex})
if ishermitian(A)
if isdiag(A)
return applydiagonal(sin, A)
elseif ishermitian(A)
return copytri!(parent(sin(Hermitian(A))), 'U', true)
end
T = complex(float(eltype(A)))
Expand Down Expand Up @@ -1163,7 +1183,9 @@ julia> tan(fill(1.0, (2,2)))
```
"""
function tan(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(tan, A)
elseif ishermitian(A)
return copytri!(parent(tan(Hermitian(A))), 'U', true)
end
S, C = sincos(A)
Expand All @@ -1177,7 +1199,9 @@ end
Compute the matrix hyperbolic cosine of a square matrix `A`.
"""
function cosh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(cosh, A)
elseif ishermitian(A)
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
end
X = exp(A)
Expand All @@ -1191,7 +1215,9 @@ end
Compute the matrix hyperbolic sine of a square matrix `A`.
"""
function sinh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(sinh, A)
elseif ishermitian(A)
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
end
X = exp(A)
Expand All @@ -1205,7 +1231,9 @@ end
Compute the matrix hyperbolic tangent of a square matrix `A`.
"""
function tanh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(tanh, A)
elseif ishermitian(A)
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
end
X = exp(A)
Expand Down Expand Up @@ -1240,7 +1268,9 @@ julia> acos(cos([0.5 0.1; -0.2 0.3]))
```
"""
function acos(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(acos, A)
elseif ishermitian(A)
acosHermA = acos(Hermitian(A))
return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA)
end
Expand Down Expand Up @@ -1271,7 +1301,9 @@ julia> asin(sin([0.5 0.1; -0.2 0.3]))
```
"""
function asin(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(asin, A)
elseif ishermitian(A)
asinHermA = asin(Hermitian(A))
return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA)
end
Expand Down Expand Up @@ -1302,7 +1334,9 @@ julia> atan(tan([0.5 0.1; -0.2 0.3]))
```
"""
function atan(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(atan, A)
elseif ishermitian(A)
return copytri!(parent(atan(Hermitian(A))), 'U', true)
end
SchurF = Schur{Complex}(schur(A))
Expand All @@ -1320,7 +1354,9 @@ logarithmic formulas used to compute this function, see [^AH16_4].
[^AH16_4]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
"""
function acosh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(acosh, A)
elseif ishermitian(A)
acoshHermA = acosh(Hermitian(A))
return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA)
end
Expand All @@ -1339,7 +1375,9 @@ logarithmic formulas used to compute this function, see [^AH16_5].
[^AH16_5]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
"""
function asinh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(asinh, A)
elseif ishermitian(A)
return copytri!(parent(asinh(Hermitian(A))), 'U', true)
end
SchurF = Schur{Complex}(schur(A))
Expand All @@ -1357,7 +1395,9 @@ logarithmic formulas used to compute this function, see [^AH16_6].
[^AH16_6]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577)
"""
function atanh(A::AbstractMatrix)
if ishermitian(A)
if isdiag(A)
return applydiagonal(atanh, A)
elseif ishermitian(A)
return copytri!(parent(atanh(Hermitian(A))), 'U', true)
end
SchurF = Schur{Complex}(schur(A))
Expand Down
37 changes: 19 additions & 18 deletions stdlib/LinearAlgebra/test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,15 +607,16 @@ end
-0.4579038628067864 1.7361475641080275 6.478801851038108])
A3 = convert(Matrix{elty}, [0.25 0.25; 0 0])
A4 = convert(Matrix{elty}, [0 0.02; 0 0])
A5 = convert(Matrix{elty}, [2.0 0; 0 3.0])

cosA1 = convert(Matrix{elty},[-0.18287716254368605 -0.29517205254584633 0.761711400552759;
0.23326967400345625 0.19797853773269333 -0.14758602627292305;
0.23326967400345636 0.6141253742798355 -0.5637328628200653])
sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513;
-0.6227405671629401 0.2176922827908092 -0.5538759902910078;
-0.6227405671629398 -0.6916051440348725 0.3554214365346742])
@test cos(A1) cosA1
@test sin(A1) sinA1
@test @inferred(cos(A1)) cosA1
@test @inferred(sin(A1)) sinA1

cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532;
0.12878366262380136 -0.5596234510748788 0.5210483146041339;
Expand All @@ -637,36 +638,36 @@ end
@test sin(A4) sinA4

# Identities
for (i, A) in enumerate((A1, A2, A3, A4))
@test sincos(A) == (sin(A), cos(A))
for (i, A) in enumerate((A1, A2, A3, A4, A5))
@test @inferred(sincos(A)) == (sin(A), cos(A))
@test cos(A)^2 + sin(A)^2 Matrix(I, size(A))
@test cos(A) cos(-A)
@test sin(A) -sin(-A)
@test tan(A) sin(A) / cos(A)
@test @inferred(tan(A)) sin(A) / cos(A)

@test cos(A) real(exp(im*A))
@test sin(A) imag(exp(im*A))
@test cos(A) real(cis(A))
@test sin(A) imag(cis(A))
@test cis(A) cos(A) + im * sin(A)
@test @inferred(cis(A)) cos(A) + im * sin(A)

@test cosh(A) 0.5 * (exp(A) + exp(-A))
@test sinh(A) 0.5 * (exp(A) - exp(-A))
@test cosh(A) cosh(-A)
@test sinh(A) -sinh(-A)
@test @inferred(cosh(A)) 0.5 * (exp(A) + exp(-A))
@test @inferred(sinh(A)) 0.5 * (exp(A) - exp(-A))
@test @inferred(cosh(A)) cosh(-A)
@test @inferred(sinh(A)) -sinh(-A)

# Some of the following identities fail for A3, A4 because the matrices are singular
if i in (1, 2)
@test sec(A) inv(cos(A))
@test csc(A) inv(sin(A))
@test cot(A) inv(tan(A))
@test sech(A) inv(cosh(A))
@test csch(A) inv(sinh(A))
@test coth(A) inv(tanh(A))
if i in (1, 2, 5)
@test @inferred(sec(A)) inv(cos(A))
@test @inferred(csc(A)) inv(sin(A))
@test @inferred(cot(A)) inv(tan(A))
@test @inferred(sech(A)) inv(cosh(A))
@test @inferred(csch(A)) inv(sinh(A))
@test @inferred(coth(A)) inv(@inferred tanh(A))
end
# The following identities fail for A1, A2 due to rounding errors;
# probably needs better algorithm for the general case
if i in (3, 4)
if i in (3, 4, 5)
@test cosh(A)^2 - sinh(A)^2 Matrix(I, size(A))
@test tanh(A) sinh(A) / cosh(A)
end
Expand Down

0 comments on commit b6a2cc1

Please sign in to comment.