Skip to content

Commit

Permalink
Merge identical methods for Symmetric/Hermitian and SymTridiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Nov 4, 2024
1 parent 50713ee commit d017644
Showing 1 changed file with 64 additions and 85 deletions.
149 changes: 64 additions & 85 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,16 @@ convert(::Type{T}, m::Union{Symmetric,Hermitian}) where {T<:Hermitian} = m isa T

const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const SymSymTri{T} = Union{Symmetric{T}, SymTridiagonal{T}}
const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}

wrapperop(::Union{Symmetric, SymTridiagonal}) = Symmetric
wrapperop(::Hermitian) = Hermitian

size(A::HermOrSym) = size(A.data)
axes(A::HermOrSym) = axes(A.data)
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
Expand Down Expand Up @@ -814,15 +820,15 @@ end
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
function sympow(A::SymSymTri, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
for hermtype in (:Symmetric, :SymTridiagonal)
@eval begin
function sympow(A::$hermtype, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^(A::$hermtype{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
Expand All @@ -844,8 +850,8 @@ function ^(A::Hermitian, p::Integer)
else
retmat = Base.power_by_squaring(A, p)
end
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
Expand All @@ -857,8 +863,8 @@ function ^(A::Hermitian{T}, p::Real) where T
if T <: Real
return Hermitian(retmat)
else
for i = 1:size(A,1)
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
Expand All @@ -873,34 +879,25 @@ function ^(A::Hermitian{T}, p::Real) where T
end

for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
return wrapperop(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
end
end
end

for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
@eval begin
function cis(A::$wrapper{<:Real})
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
end
end
function cis(A::RealHermSymSymTri)
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
end
function cis(A::Hermitian{<:Complex})
F = eigen(A)
Expand All @@ -909,26 +906,21 @@ end


for func in (:acos, :asin)
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
if all-> -1 λ 1, F.values)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
if all-> -1 λ 1, F.values)
return wrapperop(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
if all-> -1 λ 1, F.values)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
Expand All @@ -938,56 +930,47 @@ for func in (:acos, :asin)
end
end

for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function acosh(A::$hermtype{<:Real})
F = eigen(A)
if all-> λ 1, F.values)
return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
function acosh(A::RealHermSymSymTri)
F = eigen(A)
if all-> λ 1, F.values)
return wrapperop(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
function acosh(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
if all-> λ 1, F.values)
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
end
end

for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function sincos(A::$hermtype{<:Real})
n = checksquare(A)
F = eigen(A)
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
for i in 1:n
S.diag[i], C.diag[i] = sincos(F.values[i])
end
return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
end
function sincos(A::RealHermSymSymTri)
n = checksquare(A)
F = eigen(A)
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
for i in eachindex(S.diag, C.diag, F.values)
S.diag[i], C.diag[i] = sincos(F.values[i])
end
return wrapperop(A)((F.vectors * S) * F.vectors'), wrapperop(A)((F.vectors * C) * F.vectors')
end
function sincos(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
for i in 1:n
for i in eachindex(S.diag, C.diag, F.values)
S.diag[i], C.diag[i] = sincos(F.values[i])
end
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
for i = 1:n
retmatS[i,i] = real(retmatS[i,i])
retmatC[i,i] = real(retmatC[i,i])
for i in diagind(retmatS, IndexStyle(retmatS))
retmatS[i] = real(retmatS[i])
retmatC[i] = real(retmatC[i])
end
return Hermitian(retmatS), Hermitian(retmatC)
end
Expand All @@ -997,28 +980,24 @@ for func in (:log, :sqrt)
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all-> λ λ₀, F.values)
return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
@eval begin
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all-> λ λ₀, F.values)
return wrapperop(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
n = checksquare(A)
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all-> λ λ₀, F.values)
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
for i = 1:n
retmat[i,i] = real(retmat[i,i])
for i in diagind(retmat, IndexStyle(retmat))
retmat[i] = real(retmat[i])
end
return Hermitian(retmat)
else
Expand Down

0 comments on commit d017644

Please sign in to comment.