From 449a3e5966691feb67456a7d919620cac911f75b Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 22 Dec 2022 16:19:38 +0100 Subject: [PATCH] Handle more multiplications with `AbstractQ`s (#117) --- Project.toml | 2 +- src/ArrayLayouts.jl | 4 ++-- src/mul.jl | 23 +++++++++++++++-------- test/test_muladd.jl | 18 ++++++++++++++++++ 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 60cf398..577f5e6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "0.8.17" +version = "0.8.18" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/ArrayLayouts.jl b/src/ArrayLayouts.jl index bf737f3..a8cba5e 100644 --- a/src/ArrayLayouts.jl +++ b/src/ArrayLayouts.jl @@ -15,8 +15,8 @@ using Base.Broadcast: Broadcasted import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize! -using LinearAlgebra: AbstractTriangular, AbstractQ, QRCompactWYQ, QRPackedQ, checksquare, - pinv, tilebufsize, cholcopy, +using LinearAlgebra: AbstractQ, QRCompactWYQ, QRPackedQ, HessenbergQ, + AbstractTriangular, checksquare, pinv, tilebufsize, cholcopy, norm2, norm1, normInf, normMinusInf, AdjOrTrans, HermOrSym, RealHermSymComplexHerm, AdjointAbsVec, TransposeAbsVec, checknonsingular, _apply_ipiv_rows!, ipiv2perm, chkfullrank diff --git a/src/mul.jl b/src/mul.jl index d774016..107982d 100644 --- a/src/mul.jl +++ b/src/mul.jl @@ -88,19 +88,26 @@ check_mul_axes(A) = nothing _check_mul_axes(::Number, ::Number) = nothing _check_mul_axes(::Number, _) = nothing _check_mul_axes(_, ::Number) = nothing -_check_mul_axes(A, B) = axes(A,2) == axes(B,1) || throw(DimensionMismatch("Second axis of A, $(axes(A,2)), and first axis of B, $(axes(B,1)) must match")) +_check_mul_axes(A, B) = axes(A, 2) == axes(B, 1) || throw(DimensionMismatch("Second axis of A, $(axes(A,2)), and first axis of B, $(axes(B,1)) must match")) +# we need to special case AbstractQ as it allows non-compatiple multiplication +const FlexibleLeftQs = Union{QRCompactWYQ,QRPackedQ,HessenbergQ} +_check_mul_axes(::FlexibleLeftQs, ::Number) = nothing +_check_mul_axes(Q::FlexibleLeftQs, B) = + axes(Q.factors, 1) == axes(B, 1) || axes(Q.factors, 2) == axes(B, 1) || + throw(DimensionMismatch("First axis of B, $(axes(B,1)) must match either axes of A, $(axes(Q.factors))")) +_check_mul_axes(::Number, ::AdjointQtype{<:Any,<:FlexibleLeftQs}) = nothing +function _check_mul_axes(A, adjQ::AdjointQtype{<:Any,<:FlexibleLeftQs}) + Q = parent(adjQ) + axes(A, 2) == axes(Q.factors, 1) || axes(A, 2) == axes(Q.factors, 2) || + throw(DimensionMismatch("Second axis of A, $(axes(A,2)) must match either axes of B, $(axes(Q.factors))")) +end +_check_mul_axes(Q::FlexibleLeftQs, adjQ::AdjointQtype{<:Any,<:FlexibleLeftQs}) = + invoke(_check_mul_axes, Tuple{Any,Any}, Q, adjQ) function check_mul_axes(A, B, C...) _check_mul_axes(A, B) check_mul_axes(B, C...) end -# we need to special case AbstractQ as it allows non-compatiple multiplication -function check_mul_axes(A::Union{QRCompactWYQ,QRPackedQ}, B, C...) - axes(A.factors, 1) == axes(B, 1) || axes(A.factors, 2) == axes(B, 1) || - throw(DimensionMismatch("First axis of B, $(axes(B,1)) must match either axes of A, $(axes(A))")) - check_mul_axes(B, C...) -end - function instantiate(M::Mul) @boundscheck check_mul_axes(M.A, M.B) M diff --git a/test/test_muladd.jl b/test/test_muladd.jl index 731a990..62c8078 100644 --- a/test/test_muladd.jl +++ b/test/test_muladd.jl @@ -615,9 +615,12 @@ Random.seed!(0) Q = qr(randn(5,5)).Q b = randn(5) B = randn(5,5) + @test Q*1.0 == ArrayLayouts.lmul!(Q, Matrix{Float64}(I, 5, 5)) @test Q*b == ArrayLayouts.lmul!(Q, copy(b)) == mul(Q,b) @test Q*B == ArrayLayouts.lmul!(Q, copy(B)) == mul(Q,B) @test B*Q == ArrayLayouts.rmul!(copy(B), Q) == mul(B,Q) + @test 1.0*Q ≈ ArrayLayouts.rmul!(Matrix{Float64}(I, 5, 5), Q) + @test 1.0*Q' ≈ ArrayLayouts.rmul!(Matrix{Float64}(I, 5, 5), Q') @test Q*Q ≈ mul(Q,Q) @test Q'*b == ArrayLayouts.lmul!(Q', copy(b)) == mul(Q',b) @test Q'*B == ArrayLayouts.lmul!(Q', copy(B)) == mul(Q',B) @@ -627,6 +630,21 @@ Random.seed!(0) @test Q'*Q ≈ mul(Q',Q) @test Q*UpperTriangular(B) ≈ mul(Q, UpperTriangular(B)) @test UpperTriangular(B)*Q ≈ mul(UpperTriangular(B), Q) + + Q = qr(randn(7,5)).Q + b = randn(5) + B = randn(5,5) + @test Q*1.0 == ArrayLayouts.lmul!(Q, Matrix{Float64}(I, 7, 7)) + @test Q*b == mul(Q,b) + @test Q*B == mul(Q,B) + @test 1.0*Q ≈ ArrayLayouts.rmul!(Matrix{Float64}(I, 7, 7), Q) + @test Q*Q ≈ mul(Q,Q) + @test B*Q' == mul(B,Q') + @test Q*Q' ≈ mul(Q,Q') + @test Q'*Q' ≈ mul(Q',Q') + @test Q'*Q ≈ mul(Q',Q) + VERSION >= v"1.8-" && @test Q*UpperTriangular(B) ≈ mul(Q, UpperTriangular(B)) + @test UpperTriangular(B)*Q' ≈ mul(UpperTriangular(B), Q') end @testset "Mul" begin