From d017644a097848d5402554398a0b344bc1ff26c4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 4 Nov 2024 16:29:40 +0530 Subject: [PATCH] Merge identical methods for Symmetric/Hermitian and SymTridiagonal --- stdlib/LinearAlgebra/src/symmetric.jl | 149 +++++++++++--------------- 1 file changed, 64 insertions(+), 85 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 265995d9e7806c..35817d2610af8b 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -219,10 +219,16 @@ convert(::Type{T}, m::Union{Symmetric,Hermitian}) where {T<:Hermitian} = m isa T const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}} const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}} +const SymSymTri{T} = Union{Symmetric{T}, SymTridiagonal{T}} +const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}} const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}} const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}} +const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}} const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}} +wrapperop(::Union{Symmetric, SymTridiagonal}) = Symmetric +wrapperop(::Hermitian) = Hermitian + size(A::HermOrSym) = size(A.data) axes(A::HermOrSym) = axes(A.data) @inline function Base.isassigned(A::HermOrSym, i::Int, j::Int) @@ -814,15 +820,15 @@ end ^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p) ^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p) ^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p) +function sympow(A::SymSymTri, p::Integer) + if p < 0 + return Symmetric(Base.power_by_squaring(inv(A), -p)) + else + return Symmetric(Base.power_by_squaring(A, p)) + end +end for hermtype in (:Symmetric, :SymTridiagonal) @eval begin - function sympow(A::$hermtype, p::Integer) - if p < 0 - return Symmetric(Base.power_by_squaring(inv(A), -p)) - else - return Symmetric(Base.power_by_squaring(A, p)) - end - end function ^(A::$hermtype{<:Real}, p::Real) isinteger(p) && return integerpow(A, p) F = eigen(A) @@ -844,8 +850,8 @@ function ^(A::Hermitian, p::Integer) else retmat = Base.power_by_squaring(A, p) end - for i = 1:size(A,1) - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) end @@ -857,8 +863,8 @@ function ^(A::Hermitian{T}, p::Real) where T if T <: Real return Hermitian(retmat) else - for i = 1:size(A,1) - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) end @@ -873,34 +879,25 @@ function ^(A::Hermitian{T}, p::Real) where T end for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt) - for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] - @eval begin - function ($func)(A::$hermtype{<:Real}) - F = eigen(A) - return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors') - end - end - end @eval begin + function ($func)(A::RealHermSymSymTri) + F = eigen(A) + return wrapperop(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + end function ($func)(A::Hermitian{<:Complex}) - n = checksquare(A) F = eigen(A) retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' - for i = 1:n - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) end end end -for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal) - @eval begin - function cis(A::$wrapper{<:Real}) - F = eigen(A) - return Symmetric(F.vectors .* cis.(F.values') * F.vectors') - end - end +function cis(A::RealHermSymSymTri) + F = eigen(A) + return Symmetric(F.vectors .* cis.(F.values') * F.vectors') end function cis(A::Hermitian{<:Complex}) F = eigen(A) @@ -909,26 +906,21 @@ end for func in (:acos, :asin) - for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] - @eval begin - function ($func)(A::$hermtype{<:Real}) - F = eigen(A) - if all(λ -> -1 ≤ λ ≤ 1, F.values) - return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') - end + @eval begin + function ($func)(A::RealHermSymSymTri) + F = eigen(A) + if all(λ -> -1 ≤ λ ≤ 1, F.values) + return wrapperop(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') end end - end - @eval begin function ($func)(A::Hermitian{<:Complex}) - n = checksquare(A) F = eigen(A) if all(λ -> -1 ≤ λ ≤ 1, F.values) retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' - for i = 1:n - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) else @@ -938,25 +930,20 @@ for func in (:acos, :asin) end end -for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] - @eval begin - function acosh(A::$hermtype{<:Real}) - F = eigen(A) - if all(λ -> λ ≥ 1, F.values) - return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') - end - end +function acosh(A::RealHermSymSymTri) + F = eigen(A) + if all(λ -> λ ≥ 1, F.values) + return wrapperop(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') end end function acosh(A::Hermitian{<:Complex}) - n = checksquare(A) F = eigen(A) if all(λ -> λ ≥ 1, F.values) retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors' - for i = 1:n - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) else @@ -964,30 +951,26 @@ function acosh(A::Hermitian{<:Complex}) end end -for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] - @eval begin - function sincos(A::$hermtype{<:Real}) - n = checksquare(A) - F = eigen(A) - S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,))) - for i in 1:n - S.diag[i], C.diag[i] = sincos(F.values[i]) - end - return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors') - end +function sincos(A::RealHermSymSymTri) + n = checksquare(A) + F = eigen(A) + S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,))) + for i in eachindex(S.diag, C.diag, F.values) + S.diag[i], C.diag[i] = sincos(F.values[i]) end + return wrapperop(A)((F.vectors * S) * F.vectors'), wrapperop(A)((F.vectors * C) * F.vectors') end function sincos(A::Hermitian{<:Complex}) n = checksquare(A) F = eigen(A) S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,))) - for i in 1:n + for i in eachindex(S.diag, C.diag, F.values) S.diag[i], C.diag[i] = sincos(F.values[i]) end retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors' - for i = 1:n - retmatS[i,i] = real(retmatS[i,i]) - retmatC[i,i] = real(retmatC[i,i]) + for i in diagind(retmatS, IndexStyle(retmatS)) + retmatS[i] = real(retmatS[i]) + retmatC[i] = real(retmatC[i]) end return Hermitian(retmatS), Hermitian(retmatC) end @@ -997,28 +980,24 @@ for func in (:log, :sqrt) # sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[] rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0 - for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] - @eval begin - function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real} - F = eigen(A) - λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff - if all(λ -> λ ≥ λ₀, F.values) - return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') - end + @eval begin + function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real} + F = eigen(A) + λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff + if all(λ -> λ ≥ λ₀, F.values) + return wrapperop(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') end end - end - @eval begin function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex} n = checksquare(A) F = eigen(A) λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff if all(λ -> λ ≥ λ₀, F.values) retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors' - for i = 1:n - retmat[i,i] = real(retmat[i,i]) + for i in diagind(retmat, IndexStyle(retmat)) + retmat[i] = real(retmat[i]) end return Hermitian(retmat) else