Skip to content

Commit

Permalink
Use mapreduce in triu/tril for strided matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Nov 9, 2024
1 parent 01289e6 commit 3b50b5b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
14 changes: 10 additions & 4 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,10 +1390,16 @@ function istriu(A::AbstractMatrix, k::Integer = 0)
end
istriu(x::Number) = true

_alliszero(V) = all(iszero, V)
# parallel Base.FastContiguousSubArray for a StridedArray
FastContiguousSubArrayStrided{T,N,P<:StridedArray,I<:Tuple{AbstractUnitRange, Vararg{Any}}} = Base.SubArray{T,N,P,I,true}
# using mapreduce instead of all permits vectorization
_alliszero(V::FastContiguousSubArrayStrided) = mapreduce(iszero, &, V, init=true)

@inline function _istriu(A::AbstractMatrix, k)
m, n = size(A)
for j in 1:min(n, m + k - 1)
all(iszero, view(A, max(1, j - k + 1):m, j)) || return false
_alliszero(view(A, max(1, j - k + 1):m, j)) || return false
end
return true
end
Expand Down Expand Up @@ -1438,7 +1444,7 @@ istril(x::Number) = true
@inline function _istril(A::AbstractMatrix, k)
m, n = size(A)
for j in max(1, k + 2):n
all(iszero, view(A, 1:min(j - k - 1, m), j)) || return false
_alliszero(view(A, 1:min(j - k - 1, m), j)) || return false
end
return true
end
Expand Down Expand Up @@ -1485,9 +1491,9 @@ function _isbanded(A::StridedMatrix, kl::Integer, ku::Integer)
Base.require_one_based_indexing(A)
for col in axes(A,2)
toprows = @view A[begin:min(col-ku-1, end), col]
mapreduce(iszero, &, toprows, init=true) || return false
_alliszero(toprows) || return false
bottomrows = @view A[max(begin, col-kl+1):end, col]
mapreduce(iszero, &, bottomrows, init=true) || return false
_alliszero(bottomrows) || return false
end
return true
end
Expand Down
17 changes: 17 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -530,24 +530,37 @@ end
@test !istriu(pentadiag)
@test istriu(pentadiag, -2)
@test !istriu(tridiag)
@test istriu(tridiag) == istriu(tridiagG) == istriu(Tridiag)
@test istriu(tridiag, -1)
@test istriu(tridiag, -1) == istriu(tridiagG, -1) == istriu(Tridiag, -1)
@test istriu(ubidiag)
@test istriu(ubidiag) == istriu(ubidiagG) == istriu(uBidiag)
@test !istriu(ubidiag, 1)
@test istriu(ubidiag, 1) == istriu(ubidiagG, 1) == istriu(uBidiag, 1)
@test !istriu(lbidiag)
@test istriu(lbidiag) == istriu(lbidiagG) == istriu(lBidiag)
@test istriu(lbidiag, -1)
@test istriu(lbidiag, -1) == istriu(lbidiagG, -1) == istriu(lBidiag, -1)
@test istriu(adiag)
@test istriu(adiag) == istriu(adiagG) == istriu(aDiag)
end
@testset "istril" begin
@test !istril(pentadiag)
@test istril(pentadiag, 2)
@test !istril(tridiag)
@test istril(tridiag) == istril(tridiagG) == istril(Tridiag)
@test istril(tridiag, 1)
@test istril(tridiag, 1) == istril(tridiagG, 1) == istril(Tridiag, 1)
@test !istril(ubidiag)
@test istril(ubidiag) == istril(ubidiagG) == istril(ubidiagG)
@test istril(ubidiag, 1)
@test istril(ubidiag, 1) == istril(ubidiagG, 1) == istril(uBidiag, 1)
@test istril(lbidiag)
@test istril(lbidiag) == istril(lbidiagG) == istril(lBidiag)
@test !istril(lbidiag, -1)
@test istril(lbidiag, -1) == istril(lbidiagG, -1) == istril(lBidiag, -1)
@test istril(adiag)
@test istril(adiag) == istril(adiagG) == istril(aDiag)
end
@testset "isbanded" begin
@test isbanded(pentadiag, -2, 2)
Expand Down Expand Up @@ -580,9 +593,13 @@ end
end
@testset "isdiag" begin
@test !isdiag(tridiag)
@test isdiag(tridiag) == isdiag(tridiagG) == isdiag(Tridiag)
@test !isdiag(ubidiag)
@test isdiag(ubidiag) == isdiag(ubidiagG) == isdiag(uBidiag)
@test !isdiag(lbidiag)
@test isdiag(lbidiag) == isdiag(lbidiagG) == isdiag(lBidiag)
@test isdiag(adiag)
@test isdiag(adiag) ==isdiag(adiagG) == isdiag(aDiag)
end
end

Expand Down

0 comments on commit 3b50b5b

Please sign in to comment.