From fb5e96acd533fc0619e91c474ce7b74baf04ede0 Mon Sep 17 00:00:00 2001
From: Jishnu Bhattacharya <jishnub.github@gmail.com>
Date: Sat, 9 Nov 2024 16:22:55 +0530
Subject: [PATCH] Merge identical methods for Symmetric/Hermitian and
 SymTridiagonal (#56434)

Since the methods do identical things, we may define each method once
for a union of types instead of defining methods for each type.
---
 stdlib/LinearAlgebra/src/symmetric.jl | 151 +++++++++++---------------
 1 file changed, 65 insertions(+), 86 deletions(-)

diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl
index f8cbac2490794..b059f31737b55 100644
--- a/stdlib/LinearAlgebra/src/symmetric.jl
+++ b/stdlib/LinearAlgebra/src/symmetric.jl
@@ -219,10 +219,16 @@ convert(::Type{T}, m::Union{Symmetric,Hermitian}) where {T<:Hermitian} = m isa T
 
 const HermOrSym{T,        S} = Union{Hermitian{T,S}, Symmetric{T,S}}
 const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
+const SymSymTri{T} = Union{Symmetric{T}, SymTridiagonal{T}}
+const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
 const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
 const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
+const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
 const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}
 
+wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
+wrappertype(::Hermitian) = Hermitian
+
 size(A::HermOrSym) = size(A.data)
 axes(A::HermOrSym) = axes(A.data)
 @inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
@@ -814,15 +820,15 @@ end
 ^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
 ^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
 ^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
+function sympow(A::SymSymTri, p::Integer)
+    if p < 0
+        return Symmetric(Base.power_by_squaring(inv(A), -p))
+    else
+        return Symmetric(Base.power_by_squaring(A, p))
+    end
+end
 for hermtype in (:Symmetric, :SymTridiagonal)
     @eval begin
-        function sympow(A::$hermtype, p::Integer)
-            if p < 0
-                return Symmetric(Base.power_by_squaring(inv(A), -p))
-            else
-                return Symmetric(Base.power_by_squaring(A, p))
-            end
-        end
         function ^(A::$hermtype{<:Real}, p::Real)
             isinteger(p) && return integerpow(A, p)
             F = eigen(A)
@@ -844,8 +850,8 @@ function ^(A::Hermitian, p::Integer)
     else
         retmat = Base.power_by_squaring(A, p)
     end
-    for i = 1:size(A,1)
-        retmat[i,i] = real(retmat[i,i])
+    for i in diagind(retmat, IndexStyle(retmat))
+        retmat[i] = real(retmat[i])
     end
     return Hermitian(retmat)
 end
@@ -857,8 +863,8 @@ function ^(A::Hermitian{T}, p::Real) where T
         if T <: Real
             return Hermitian(retmat)
         else
-            for i = 1:size(A,1)
-                retmat[i,i] = real(retmat[i,i])
+            for i in diagind(retmat, IndexStyle(retmat))
+                retmat[i] = real(retmat[i])
             end
             return Hermitian(retmat)
         end
@@ -873,34 +879,25 @@ function ^(A::Hermitian{T}, p::Real) where T
 end
 
 for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
-    for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
-        @eval begin
-            function ($func)(A::$hermtype{<:Real})
-                F = eigen(A)
-                return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
-            end
-        end
-    end
     @eval begin
+        function ($func)(A::RealHermSymSymTri)
+            F = eigen(A)
+            return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
+        end
         function ($func)(A::Hermitian{<:Complex})
-            n = checksquare(A)
             F = eigen(A)
             retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
-            for i = 1:n
-                retmat[i,i] = real(retmat[i,i])
+            for i in diagind(retmat, IndexStyle(retmat))
+                retmat[i] = real(retmat[i])
             end
             return Hermitian(retmat)
         end
     end
 end
 
-for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
-    @eval begin
-        function cis(A::$wrapper{<:Real})
-            F = eigen(A)
-            return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
-        end
-    end
+function cis(A::RealHermSymSymTri)
+    F = eigen(A)
+    return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
 end
 function cis(A::Hermitian{<:Complex})
     F = eigen(A)
@@ -909,26 +906,21 @@ end
 
 
 for func in (:acos, :asin)
-    for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
-        @eval begin
-            function ($func)(A::$hermtype{<:Real})
-                F = eigen(A)
-                if all(λ -> -1 ≤ λ ≤ 1, F.values)
-                    return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
-                else
-                    return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
-                end
+    @eval begin
+        function ($func)(A::RealHermSymSymTri)
+            F = eigen(A)
+            if all(λ -> -1 ≤ λ ≤ 1, F.values)
+                return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
+            else
+                return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
             end
         end
-    end
-    @eval begin
         function ($func)(A::Hermitian{<:Complex})
-            n = checksquare(A)
             F = eigen(A)
             if all(λ -> -1 ≤ λ ≤ 1, F.values)
                 retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
-                for i = 1:n
-                    retmat[i,i] = real(retmat[i,i])
+                for i in diagind(retmat, IndexStyle(retmat))
+                    retmat[i] = real(retmat[i])
                 end
                 return Hermitian(retmat)
             else
@@ -938,25 +930,20 @@ for func in (:acos, :asin)
     end
 end
 
-for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
-    @eval begin
-        function acosh(A::$hermtype{<:Real})
-            F = eigen(A)
-            if all(λ -> λ ≥ 1, F.values)
-                return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
-            else
-                return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
-            end
-        end
+function acosh(A::RealHermSymSymTri)
+    F = eigen(A)
+    if all(λ -> λ ≥ 1, F.values)
+        return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
+    else
+        return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
     end
 end
 function acosh(A::Hermitian{<:Complex})
-    n = checksquare(A)
     F = eigen(A)
     if all(λ -> λ ≥ 1, F.values)
         retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
-        for i = 1:n
-            retmat[i,i] = real(retmat[i,i])
+        for i in diagind(retmat, IndexStyle(retmat))
+            retmat[i] = real(retmat[i])
         end
         return Hermitian(retmat)
     else
@@ -964,32 +951,28 @@ function acosh(A::Hermitian{<:Complex})
     end
 end
 
-for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
-    @eval begin
-        function sincos(A::$hermtype{<:Real})
-            n = checksquare(A)
-            F = eigen(A)
-            T = float(eltype(F.values))
-            S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
-            for i in 1:n
-                S.diag[i], C.diag[i] = sincos(F.values[i])
-            end
-            return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
-        end
+function sincos(A::RealHermSymSymTri)
+    n = checksquare(A)
+    F = eigen(A)
+    T = float(eltype(F.values))
+    S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
+    for i in eachindex(S.diag, C.diag, F.values)
+        S.diag[i], C.diag[i] = sincos(F.values[i])
     end
+    return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
 end
 function sincos(A::Hermitian{<:Complex})
     n = checksquare(A)
     F = eigen(A)
     T = float(eltype(F.values))
     S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
-    for i in 1:n
+    for i in eachindex(S.diag, C.diag, F.values)
         S.diag[i], C.diag[i] = sincos(F.values[i])
     end
     retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
-    for i = 1:n
-        retmatS[i,i] = real(retmatS[i,i])
-        retmatC[i,i] = real(retmatC[i,i])
+    for i in diagind(retmatS, IndexStyle(retmatS))
+        retmatS[i] = real(retmatS[i])
+        retmatC[i] = real(retmatC[i])
     end
     return Hermitian(retmatS), Hermitian(retmatC)
 end
@@ -999,28 +982,24 @@ for func in (:log, :sqrt)
     # sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
     rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
     rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
-    for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
-        @eval begin
-            function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
-                F = eigen(A)
-                λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
-                if all(λ -> λ ≥ λ₀, F.values)
-                    return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
-                else
-                    return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
-                end
+    @eval begin
+        function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
+            F = eigen(A)
+            λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
+            if all(λ -> λ ≥ λ₀, F.values)
+                return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
+            else
+                return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
             end
         end
-    end
-    @eval begin
         function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
             n = checksquare(A)
             F = eigen(A)
             λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
             if all(λ -> λ ≥ λ₀, F.values)
                 retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
-                for i = 1:n
-                    retmat[i,i] = real(retmat[i,i])
+                for i in diagind(retmat, IndexStyle(retmat))
+                    retmat[i] = real(retmat[i])
                 end
                 return Hermitian(retmat)
             else