From b6a2cc1a8e759f3fc8a7222bdd28351416a4805d Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 10 Nov 2024 17:09:56 +0530 Subject: [PATCH] Check `isdiag` in dense trig functions (#56483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This improves performance for dense diagonal matrices, as we may apply the function only to the diagonal elements. ```julia julia> A = diagm(0=>rand(100)); julia> @btime cos($A); 349.211 μs (22 allocations: 401.58 KiB) # nightly v"1.12.0-DEV.1571" 16.215 μs (7 allocations: 80.02 KiB) # this PR ``` --------- Co-authored-by: Daniel Karrasch --- stdlib/LinearAlgebra/src/dense.jl | 72 +++++++++++++++++++++++------- stdlib/LinearAlgebra/test/dense.jl | 37 +++++++-------- 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 2711bba5cd3ac..d975df1cc0fb7 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -683,7 +683,12 @@ Base.:^(::Irrational{:ℯ}, A::AbstractMatrix) = exp(A) ## "Functions of Matrices: Theory and Computation", SIAM function exp!(A::StridedMatrix{T}) where T<:BlasFloat n = checksquare(A) - if ishermitian(A) + if isdiag(A) + for i in diagind(A, IndexStyle(A)) + A[i] = exp(A[i]) + end + return A + elseif ishermitian(A) return copytri!(parent(exp(Hermitian(A))), 'U', true) end ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A @@ -1014,9 +1019,16 @@ end cbrt(A::AdjointAbsMat) = adjoint(cbrt(parent(A))) cbrt(A::TransposeAbsMat) = transpose(cbrt(parent(A))) +function applydiagonal(f, A) + dinv = f(Diagonal(A)) + copyto!(similar(A, eltype(dinv)), dinv) +end + function inv(A::StridedMatrix{T}) where T checksquare(A) - if istriu(A) + if isdiag(A) + Ai = applydiagonal(inv, A) + elseif istriu(A) Ai = triu!(parent(inv(UpperTriangular(A)))) elseif istril(A) Ai = tril!(parent(inv(LowerTriangular(A)))) @@ -1044,14 +1056,18 @@ julia> cos(fill(1.0, (2,2))) ``` """ function cos(A::AbstractMatrix{<:Real}) - if issymmetric(A) + if isdiag(A) + return applydiagonal(cos, A) + elseif issymmetric(A) return copytri!(parent(cos(Symmetric(A))), 'U') end T = complex(float(eltype(A))) return real(exp!(T.(im .* A))) end function cos(A::AbstractMatrix{<:Complex}) - if ishermitian(A) + if isdiag(A) + return applydiagonal(cos, A) + elseif ishermitian(A) return copytri!(parent(cos(Hermitian(A))), 'U', true) end T = complex(float(eltype(A))) @@ -1077,14 +1093,18 @@ julia> sin(fill(1.0, (2,2))) ``` """ function sin(A::AbstractMatrix{<:Real}) - if issymmetric(A) + if isdiag(A) + return applydiagonal(sin, A) + elseif issymmetric(A) return copytri!(parent(sin(Symmetric(A))), 'U') end T = complex(float(eltype(A))) return imag(exp!(T.(im .* A))) end function sin(A::AbstractMatrix{<:Complex}) - if ishermitian(A) + if isdiag(A) + return applydiagonal(sin, A) + elseif ishermitian(A) return copytri!(parent(sin(Hermitian(A))), 'U', true) end T = complex(float(eltype(A))) @@ -1163,7 +1183,9 @@ julia> tan(fill(1.0, (2,2))) ``` """ function tan(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(tan, A) + elseif ishermitian(A) return copytri!(parent(tan(Hermitian(A))), 'U', true) end S, C = sincos(A) @@ -1177,7 +1199,9 @@ end Compute the matrix hyperbolic cosine of a square matrix `A`. """ function cosh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(cosh, A) + elseif ishermitian(A) return copytri!(parent(cosh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1191,7 +1215,9 @@ end Compute the matrix hyperbolic sine of a square matrix `A`. """ function sinh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(sinh, A) + elseif ishermitian(A) return copytri!(parent(sinh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1205,7 +1231,9 @@ end Compute the matrix hyperbolic tangent of a square matrix `A`. """ function tanh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(tanh, A) + elseif ishermitian(A) return copytri!(parent(tanh(Hermitian(A))), 'U', true) end X = exp(A) @@ -1240,7 +1268,9 @@ julia> acos(cos([0.5 0.1; -0.2 0.3])) ``` """ function acos(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(acos, A) + elseif ishermitian(A) acosHermA = acos(Hermitian(A)) return isa(acosHermA, Hermitian) ? copytri!(parent(acosHermA), 'U', true) : parent(acosHermA) end @@ -1271,7 +1301,9 @@ julia> asin(sin([0.5 0.1; -0.2 0.3])) ``` """ function asin(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(asin, A) + elseif ishermitian(A) asinHermA = asin(Hermitian(A)) return isa(asinHermA, Hermitian) ? copytri!(parent(asinHermA), 'U', true) : parent(asinHermA) end @@ -1302,7 +1334,9 @@ julia> atan(tan([0.5 0.1; -0.2 0.3])) ``` """ function atan(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(atan, A) + elseif ishermitian(A) return copytri!(parent(atan(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) @@ -1320,7 +1354,9 @@ logarithmic formulas used to compute this function, see [^AH16_4]. [^AH16_4]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function acosh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(acosh, A) + elseif ishermitian(A) acoshHermA = acosh(Hermitian(A)) return isa(acoshHermA, Hermitian) ? copytri!(parent(acoshHermA), 'U', true) : parent(acoshHermA) end @@ -1339,7 +1375,9 @@ logarithmic formulas used to compute this function, see [^AH16_5]. [^AH16_5]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function asinh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(asinh, A) + elseif ishermitian(A) return copytri!(parent(asinh(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) @@ -1357,7 +1395,9 @@ logarithmic formulas used to compute this function, see [^AH16_6]. [^AH16_6]: Mary Aprahamian and Nicholas J. Higham, "Matrix Inverse Trigonometric and Inverse Hyperbolic Functions: Theory and Algorithms", MIMS EPrint: 2016.4. [https://doi.org/10.1137/16M1057577](https://doi.org/10.1137/16M1057577) """ function atanh(A::AbstractMatrix) - if ishermitian(A) + if isdiag(A) + return applydiagonal(atanh, A) + elseif ishermitian(A) return copytri!(parent(atanh(Hermitian(A))), 'U', true) end SchurF = Schur{Complex}(schur(A)) diff --git a/stdlib/LinearAlgebra/test/dense.jl b/stdlib/LinearAlgebra/test/dense.jl index 1d43d76899392..10f50a80ab7fd 100644 --- a/stdlib/LinearAlgebra/test/dense.jl +++ b/stdlib/LinearAlgebra/test/dense.jl @@ -607,6 +607,7 @@ end -0.4579038628067864 1.7361475641080275 6.478801851038108]) A3 = convert(Matrix{elty}, [0.25 0.25; 0 0]) A4 = convert(Matrix{elty}, [0 0.02; 0 0]) + A5 = convert(Matrix{elty}, [2.0 0; 0 3.0]) cosA1 = convert(Matrix{elty},[-0.18287716254368605 -0.29517205254584633 0.761711400552759; 0.23326967400345625 0.19797853773269333 -0.14758602627292305; @@ -614,8 +615,8 @@ end sinA1 = convert(Matrix{elty}, [0.2865568596627417 -1.107751980582015 -0.13772915374386513; -0.6227405671629401 0.2176922827908092 -0.5538759902910078; -0.6227405671629398 -0.6916051440348725 0.3554214365346742]) - @test cos(A1) ≈ cosA1 - @test sin(A1) ≈ sinA1 + @test @inferred(cos(A1)) ≈ cosA1 + @test @inferred(sin(A1)) ≈ sinA1 cosA2 = convert(Matrix{elty}, [-0.6331745163802187 0.12878366262380136 -0.17304181968301532; 0.12878366262380136 -0.5596234510748788 0.5210483146041339; @@ -637,36 +638,36 @@ end @test sin(A4) ≈ sinA4 # Identities - for (i, A) in enumerate((A1, A2, A3, A4)) - @test sincos(A) == (sin(A), cos(A)) + for (i, A) in enumerate((A1, A2, A3, A4, A5)) + @test @inferred(sincos(A)) == (sin(A), cos(A)) @test cos(A)^2 + sin(A)^2 ≈ Matrix(I, size(A)) @test cos(A) ≈ cos(-A) @test sin(A) ≈ -sin(-A) - @test tan(A) ≈ sin(A) / cos(A) + @test @inferred(tan(A)) ≈ sin(A) / cos(A) @test cos(A) ≈ real(exp(im*A)) @test sin(A) ≈ imag(exp(im*A)) @test cos(A) ≈ real(cis(A)) @test sin(A) ≈ imag(cis(A)) - @test cis(A) ≈ cos(A) + im * sin(A) + @test @inferred(cis(A)) ≈ cos(A) + im * sin(A) - @test cosh(A) ≈ 0.5 * (exp(A) + exp(-A)) - @test sinh(A) ≈ 0.5 * (exp(A) - exp(-A)) - @test cosh(A) ≈ cosh(-A) - @test sinh(A) ≈ -sinh(-A) + @test @inferred(cosh(A)) ≈ 0.5 * (exp(A) + exp(-A)) + @test @inferred(sinh(A)) ≈ 0.5 * (exp(A) - exp(-A)) + @test @inferred(cosh(A)) ≈ cosh(-A) + @test @inferred(sinh(A)) ≈ -sinh(-A) # Some of the following identities fail for A3, A4 because the matrices are singular - if i in (1, 2) - @test sec(A) ≈ inv(cos(A)) - @test csc(A) ≈ inv(sin(A)) - @test cot(A) ≈ inv(tan(A)) - @test sech(A) ≈ inv(cosh(A)) - @test csch(A) ≈ inv(sinh(A)) - @test coth(A) ≈ inv(tanh(A)) + if i in (1, 2, 5) + @test @inferred(sec(A)) ≈ inv(cos(A)) + @test @inferred(csc(A)) ≈ inv(sin(A)) + @test @inferred(cot(A)) ≈ inv(tan(A)) + @test @inferred(sech(A)) ≈ inv(cosh(A)) + @test @inferred(csch(A)) ≈ inv(sinh(A)) + @test @inferred(coth(A)) ≈ inv(@inferred tanh(A)) end # The following identities fail for A1, A2 due to rounding errors; # probably needs better algorithm for the general case - if i in (3, 4) + if i in (3, 4, 5) @test cosh(A)^2 - sinh(A)^2 ≈ Matrix(I, size(A)) @test tanh(A) ≈ sinh(A) / cosh(A) end