From f13847ac26cf8f004d3d82d80987359ff2de3b75 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Thu, 13 Feb 2025 12:44:41 +0800 Subject: [PATCH 1/9] RectDiagonal Multiplication --- src/FillArrays.jl | 1 + src/fillalgebra.jl | 68 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 26 ++++++++++++++++++ 3 files changed, 95 insertions(+) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 66dad480..a2d2b263 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -489,6 +489,7 @@ Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::Abstrac const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}} const RectOrDiagonalFill{T,V<:AbstractFillVector{T},Axes} = RectOrDiagonal{T,V,Axes} const RectDiagonalFill{T,V<:AbstractFillVector{T}} = RectDiagonal{T,V} +const DiagonalFill{T,V<:AbstractFillVector{T}} = Diagonal{T,V} const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}} const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}} diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index f98ae605..4a7f6c23 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -485,6 +485,74 @@ end @inline elconvert(::Type{T}, A::AbstractUnitRange) where T<:Integer = AbstractUnitRange{T}(A) @inline elconvert(::Type{T}, A::AbstractArray) where T = AbstractArray{T}(A) +# RectDiagonal Multiplication +function *(A::RectDiagonal, B::Diagonal) + check_matmul_sizes(A, B) + len = minimum(size(A)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) +end +function *(A::Diagonal, B::RectDiagonal) + check_matmul_sizes(A, B) + len = minimum(size(B)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) +end + +function *(A::RectDiagonal, B::AbstractMatrix) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + diag = A.diag + out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0) + out[axes(diag, 1), :] = diag .* view(B, axes(diag,1), :) + out +end +function *(A::RectDiagonal, x::AbstractVector) + check_matmul_sizes(A, x) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x)) + diag = A.diag + out = fill!(similar(diag, TS, axes(A,1)), 0) + out[axes(diag, 1)] = diag .* view(x, axes(diag,1)) + out +end +function *(A::AbstractMatrix, B::RectDiagonal) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0) + diag = B.diag + out[:, axes(diag, 1)] = view(A, :, axes(diag,1)) .* diag' + out +end +function *(A::RectDiagonal, B::RectDiagonal) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + out = fill!(similar(A.diag, TS, min(size(A, 1), size(B, 2))), 0) + len = min(minimum(size(A)), minimum(size(B))) + out[Base.OneTo(len)] .= view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)) + RectDiagonal(out, (size(A,1), size(B,2))) +end + +# RectDiagonalFill Multiplication +*(a::RectDiagonalFill, b::Number) = RectDiagonal(a.diag * b, a.axes) +*(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes) + +# DiagonalFill Multiplication +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractVector) + @eval begin + function *(A::DiagonalFill, B::$type) + check_matmul_sizes(A, B) + getindex_value(A.diag) * B + end + end +end + +for type in (AbstractMatrix, Diagonal, RectDiagonal, DiagonalFill) + @eval begin + function *(A::$type, B::DiagonalFill) + check_matmul_sizes(A, B) + getindex_value(B.diag) * A + end + end +end + #### # norm #### diff --git a/test/runtests.jl b/test/runtests.jl index 088089d4..b16b3c5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -386,6 +386,32 @@ end @test stringmime("text/plain", D) == "3×2 RectDiagonal{Float64, Vector{Float64}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n 1.0 ⋅ \n ⋅ 2.0\n ⋅ ⋅ " end +@testset "RectDiagonal Multiplication" begin + diag = Diagonal(rand(3)) + diagfill = Diagonal(Fill(1., 3)) + rectdiag = RectDiagonal(rand(3), 3, 3) + rectdiagfill = RectDiagonal(Fill(1., 3)) + mat = rand(3,3) + vec = rand(3) + arr_diagfill = Array(diagfill) + arr_rectdiag = Array(rectdiag) + arr_rectdiagfill = Array(rectdiagfill) + for a in (diag, diagfill, rectdiag, rectdiagfill, mat, vec) + arr_a = Array(a) + @test diagfill * a isa typeof(a) + @test diagfill * a == arr_diagfill * arr_a + @test rectdiag * a == arr_rectdiag * arr_a + @test rectdiagfill * a == arr_rectdiagfill * arr_a + end + for a in (diag, diagfill, rectdiag, rectdiagfill, mat) + arr_a = Array(a) + @test a * diagfill isa typeof(a) + @test a * diagfill == arr_a * arr_diagfill + @test a * rectdiag == arr_a * arr_rectdiag + @test a * rectdiagfill == arr_a * arr_rectdiagfill + end +end + # Check that all pair-wise combinations of + / - elements of As and Bs yield the correct # type, and produce numerically correct results. as_array(x::AbstractArray) = Array(x) From 04a29110a9f4e216bff711a23babd7aa6e1bcd9d Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Thu, 13 Feb 2025 12:57:22 +0800 Subject: [PATCH 2/9] fix multiplication ambiguities involving Diagonal and AbstractZeros/FillMatrix --- src/fillalgebra.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 4a7f6c23..682b8fed 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -535,7 +535,7 @@ end *(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes) # DiagonalFill Multiplication -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractVector) +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AbstractVector, AbstractZerosVector) @eval begin function *(A::DiagonalFill, B::$type) check_matmul_sizes(A, B) @@ -544,7 +544,7 @@ for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractVector) end end -for type in (AbstractMatrix, Diagonal, RectDiagonal, DiagonalFill) +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, DiagonalFill) @eval begin function *(A::$type, B::DiagonalFill) check_matmul_sizes(A, B) From da40b933bee6b2a1f3c0779bf9ef9d4b1ca07afc Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Thu, 13 Feb 2025 15:54:31 +0800 Subject: [PATCH 3/9] fix adjoint and transpose ambiguity --- src/fillalgebra.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 682b8fed..3de3a35e 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -535,7 +535,7 @@ end *(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes) # DiagonalFill Multiplication -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AbstractVector, AbstractZerosVector) +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, AbstractVector, AbstractZerosVector) @eval begin function *(A::DiagonalFill, B::$type) check_matmul_sizes(A, B) @@ -544,7 +544,7 @@ for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, Abstra end end -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, DiagonalFill) +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, DiagonalFill) @eval begin function *(A::$type, B::DiagonalFill) check_matmul_sizes(A, B) From b2978b72ffb36bd479a3841cbafea02c1fa6cad1 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Thu, 13 Feb 2025 16:10:19 +0800 Subject: [PATCH 4/9] use in-place operation && fix DiagonalFill multiplication --- src/fillalgebra.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 3de3a35e..a4213069 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -502,7 +502,7 @@ function *(A::RectDiagonal, B::AbstractMatrix) TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) diag = A.diag out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0) - out[axes(diag, 1), :] = diag .* view(B, axes(diag,1), :) + out[axes(diag, 1), :] .= diag .* view(B, axes(diag,1), :) out end function *(A::RectDiagonal, x::AbstractVector) @@ -510,7 +510,7 @@ function *(A::RectDiagonal, x::AbstractVector) TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x)) diag = A.diag out = fill!(similar(diag, TS, axes(A,1)), 0) - out[axes(diag, 1)] = diag .* view(x, axes(diag,1)) + out[axes(diag, 1)] .= diag .* view(x, axes(diag,1)) out end function *(A::AbstractMatrix, B::RectDiagonal) @@ -518,7 +518,7 @@ function *(A::AbstractMatrix, B::RectDiagonal) TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0) diag = B.diag - out[:, axes(diag, 1)] = view(A, :, axes(diag,1)) .* diag' + out[:, axes(diag, 1)] .= view(A, :, axes(diag,1)) .* diag' out end function *(A::RectDiagonal, B::RectDiagonal) @@ -544,7 +544,7 @@ for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, Abstra end end -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, DiagonalFill) +for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec) @eval begin function *(A::$type, B::DiagonalFill) check_matmul_sizes(A, B) @@ -553,6 +553,23 @@ for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, Abstra end end +function *(A::DiagonalFill, B::DiagonalFill) + check_matmul_sizes(A, B) + Diagonal(A.diag .* B.diag) +end + +function *(A::DiagonalFill, B::RectDiagonalFill) + check_matmul_sizes(A, B) + len = minimum(size(B)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), B.axes) +end + +function *(A::RectDiagonalFill, B::DiagonalFill) + check_matmul_sizes(A, B) + len = minimum(size(A)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), A.axes) +end + #### # norm #### From d0ecdc28e49de8b6f86064d72e82dd049680cf10 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Sun, 16 Feb 2025 18:37:37 +0800 Subject: [PATCH 5/9] fix ambiguities --- src/fillalgebra.jl | 213 ++++++++++++++++++++++---- src/oneelement.jl | 24 +++ test/runtests.jl | 367 ++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 551 insertions(+), 53 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index a4213069..e3d0f248 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -15,6 +15,7 @@ for OP in (:transpose, :adjoint) end $OP(a::AbstractOnesMatrix) = fillsimilar(a, reverse(axes(a))) $OP(a::FillMatrix) = Fill($OP(a.value), reverse(a.axes)) + $OP(a::RectDiagonal) = RectDiagonal(vec($OP(a.diag)), reverse(a.axes)) end end @@ -84,6 +85,11 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) # this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row # special cased, as OnesMatrix * OnesMatrix isn't a Ones *(a::AbstractOnesVector, b::AbstractOnesMatrix) = mult_ones(a, b) +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}) + @eval begin + *(A::AbstractOnesVector, B::$type) = Ones{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + end +end *(a::AbstractZerosMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b) *(a::AbstractZerosMatrix, b::AbstractZerosVector) = mult_zeros(a, b) @@ -486,6 +492,9 @@ end @inline elconvert(::Type{T}, A::AbstractArray) where T = AbstractArray{T}(A) # RectDiagonal Multiplication +const RectDiagonalZeros{T,V<:AbstractZerosVector{T}} = RectDiagonal{T,V} +const RectDiagonalOnes{T,V<:AbstractOnesVector{T}} = RectDiagonal{T,V} + function *(A::RectDiagonal, B::Diagonal) check_matmul_sizes(A, B) len = minimum(size(A)) @@ -497,77 +506,219 @@ function *(A::Diagonal, B::RectDiagonal) RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) end -function *(A::RectDiagonal, B::AbstractMatrix) - check_matmul_sizes(A, B) - TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) - diag = A.diag - out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0) - out[axes(diag, 1), :] .= diag .* view(B, axes(diag,1), :) - out +for type in (AbstractMatrix, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) + @eval begin + function *(A::RectDiagonal, B::$type) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + diag = A.diag + out = fill!(similar(diag, TS, axes(A,1), axes(B,2)), 0) + len = Base.OneTo(minimum(size(A))) + out[len, :] .= view(diag, len) .* view(B, len, :) + out + end + + function *(A::$type, B::RectDiagonal) + check_matmul_sizes(A, B) + TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) + out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0) + len = Base.OneTo(minimum(size(B))) + out[:, len] .= view(A, :, len) .* view(reshape(B.diag, 1, :), Base.OneTo(1), len) + out + end + end end + function *(A::RectDiagonal, x::AbstractVector) check_matmul_sizes(A, x) TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(x)) diag = A.diag out = fill!(similar(diag, TS, axes(A,1)), 0) - out[axes(diag, 1)] .= diag .* view(x, axes(diag,1)) - out -end -function *(A::AbstractMatrix, B::RectDiagonal) - check_matmul_sizes(A, B) - TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) - out = fill!(similar(A, TS, axes(A,1), axes(B, 2)), 0) - diag = B.diag - out[:, axes(diag, 1)] .= view(A, :, axes(diag,1)) .* diag' + len = Base.OneTo(minimum(size(A))) + out[len] .= view(diag, len) .* view(x, len) out end + function *(A::RectDiagonal, B::RectDiagonal) check_matmul_sizes(A, B) TS = Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) out = fill!(similar(A.diag, TS, min(size(A, 1), size(B, 2))), 0) - len = min(minimum(size(A)), minimum(size(B))) - out[Base.OneTo(len)] .= view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)) + len = Base.OneTo(min(minimum(size(A)), minimum(size(B)))) + out[len] .= view(A.diag, len) .* view(B.diag, len) RectDiagonal(out, (size(A,1), size(B,2))) end -# RectDiagonalFill Multiplication +for type in (RectDiagonal, RectDiagonalZeros) + @eval begin + function *(A::$type, B::AbstractZerosMatrix) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2)) + end + + function *(A::$type, B::AbstractZerosVector) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1)) + end + + function *(A::AbstractZerosMatrix, B::$type) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A, 1), size(B, 2)) + end + + *(A::AdjointAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B + *(A::TransposeAbsVec{<:Any,<:AbstractZerosVector}, B::$type) = Zeros(A) * B + *(A::$type, B::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B) + *(A::$type, B::TransposeAbsVec{<:Any,<:AbstractZerosVector}) = A * Zeros(B) + end +end + +for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) + @eval begin + function *(A::$type, B::RectDiagonalZeros) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + end + function *(A::RectDiagonalZeros, B::$type) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + end + end +end +function *(A::RectDiagonalZeros, B::AbstractVector) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1)) +end +function *(A::RectDiagonalZeros, B::RectDiagonalZeros) + check_matmul_sizes(A, B) + Zeros{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) +end + *(a::RectDiagonalFill, b::Number) = RectDiagonal(a.diag * b, a.axes) *(a::Number, b::RectDiagonalFill) = RectDiagonal(a * b.diag, b.axes) # DiagonalFill Multiplication -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, AbstractVector, AbstractZerosVector) +const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V} +const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V} +linearalgebra_types = (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, + AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular, + LowerTriangular, UpperTriangular, LinearAlgebra.AbstractTriangular, Symmetric, Hermitian, + SymTridiagonal, UpperHessenberg, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement) +for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...) @eval begin function *(A::DiagonalFill, B::$type) check_matmul_sizes(A, B) getindex_value(A.diag) * B end + *(A::DiagonalZeros, B::$type) = Zeros(A) * B + function *(A::DiagonalOnes, B::$type) + check_matmul_sizes(A, B) + one(eltype(A)) * B + end end end -for type in (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec) +# TODO: add dim check to all abstract ones multiplication +for type in linearalgebra_types @eval begin function *(A::$type, B::DiagonalFill) check_matmul_sizes(A, B) getindex_value(B.diag) * A end + *(A::$type, B::DiagonalZeros) = A * Zeros(B) + function *(A::$type, B::DiagonalOnes) + check_matmul_sizes(A, B) + one(eltype(B)) * A + end end end -function *(A::DiagonalFill, B::DiagonalFill) - check_matmul_sizes(A, B) - Diagonal(A.diag .* B.diag) +for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}, RectDiagonalZeros) + @eval begin + *(A::$type2, B::$type1) = Zeros(A) * B + *(A::$type1, B::$type2) = A * Zeros(B) + end + end end -function *(A::DiagonalFill, B::RectDiagonalFill) +for type in (DiagonalFill, DiagonalOnes, RectDiagonalFill) + @eval begin + *(A::$type, B::DiagonalZeros) = A * Zeros(B) + *(A::DiagonalZeros, B::$type) = Zeros(A) * B + end +end +function *(A::DiagonalZeros, B::DiagonalZeros) check_matmul_sizes(A, B) - len = minimum(size(B)) - RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), B.axes) + Zeros{promote_type(eltype(A),eltype(B))}(A) +end + +for type1 in (DiagonalFill, DiagonalOnes) + for type2 in (DiagonalFill, DiagonalOnes) + if type1 !== DiagonalOnes || type2 !== DiagonalOnes + @eval begin + function *(A::$type1, B::$type2) + check_matmul_sizes(A, B) + getindex_value(A.diag) * B + end + end + end + end end - -function *(A::RectDiagonalFill, B::DiagonalFill) +function *(A::DiagonalOnes, B::DiagonalOnes) check_matmul_sizes(A, B) - len = minimum(size(A)) - RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), A.axes) + Diagonal(Ones{promote_type(eltype(A), eltype(B))}(size(A, 1))) +end + +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix, AbstractOnesVector) + @eval begin + *(A::DiagonalOnes, B::$type) = Ones{promote_type(eltype(A),eltype(B))}(size(B)) + end +end +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix) + @eval begin + *(A::$type, B::DiagonalOnes) = Ones{promote_type(eltype(A),eltype(B))}(size(A)) + end +end + +for type in (DiagonalFill, DiagonalOnes) + @eval begin + function *(A::$type, B::RectDiagonalFill) + check_matmul_sizes(A, B) + len = minimum(size(B)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(B)) + end + + function *(A::RectDiagonalFill, B::$type) + check_matmul_sizes(A, B) + len = minimum(size(A)) + RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(A)) + end + end +end + +for type1 in (AbstractMatrix, Diagonal) + for type2 in (Diagonal, DiagonalOnes, DiagonalFill) + @eval begin + *(Da::DiagonalZeros, A::$type1, Db::$type2) = Zeros(Da) * A + *(Da::$type2, A::$type1, Db::DiagonalZeros) = A * Zeros(Db) + end + end + + for type2 in (Diagonal, DiagonalFill) + @eval begin + *(Da::DiagonalOnes, A::$type1, Db::$type2) = A * Db + *(Da::$type2, A::$type1, Db::DiagonalOnes) = Da * A + end + end + + @eval begin + *(Da::DiagonalZeros, A::$type1, Db::DiagonalZeros) = Zeros(Da) * A * Zeros(Db) + *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes) = A + + *(Da::DiagonalFill, A::$type1, Db::Diagonal) = getindex_value(Da.diag) * A * Db + *(Da::Diagonal, A::$type1, Db::DiagonalFill) = Da * A * getindex_value(Db.diag) + *(Da::DiagonalFill, A::$type1, Db::DiagonalFill) = getindex_value(Da.diag) * getindex_value(Db.diag) * A + end end #### diff --git a/src/oneelement.jl b/src/oneelement.jl index 9b76d35f..6e6bd200 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -190,6 +190,30 @@ function *(D::Diagonal, A::OneElementMatrix) OneElement(val, A.ind, size(A)) end +function *(A::OneElementMatrix, D::DiagonalZeros) + check_matmul_sizes(A, D) + Zeros{promote_type(eltype(A),eltype(D))}(size(A, 1), size(D, 2)) +end + +function *(D::DiagonalZeros, A::OneElementMatrix) + check_matmul_sizes(D, A) + Zeros{promote_type(eltype(A),eltype(D))}(size(D, 1), size(A, 2)) +end + +for type in (DiagonalFill, DiagonalOnes) + @eval begin + function *(A::OneElementMatrix, D::$type) + check_matmul_sizes(A, D) + getindex_value(D.diag) * A + end + + function *(D::$type, A::OneElementMatrix) + check_matmul_sizes(D, A) + getindex_value(D.diag) * A + end + end +end + # Inplace multiplication # We use this for out overloads for _mul! for OneElement because its more efficient diff --git a/test/runtests.jl b/test/runtests.jl index b16b3c5b..04be62f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -387,28 +387,351 @@ end end @testset "RectDiagonal Multiplication" begin - diag = Diagonal(rand(3)) - diagfill = Diagonal(Fill(1., 3)) - rectdiag = RectDiagonal(rand(3), 3, 3) - rectdiagfill = RectDiagonal(Fill(1., 3)) - mat = rand(3,3) - vec = rand(3) - arr_diagfill = Array(diagfill) - arr_rectdiag = Array(rectdiag) - arr_rectdiagfill = Array(rectdiagfill) - for a in (diag, diagfill, rectdiag, rectdiagfill, mat, vec) - arr_a = Array(a) - @test diagfill * a isa typeof(a) - @test diagfill * a == arr_diagfill * arr_a - @test rectdiag * a == arr_rectdiag * arr_a - @test rectdiagfill * a == arr_rectdiagfill * arr_a - end - for a in (diag, diagfill, rectdiag, rectdiagfill, mat) - arr_a = Array(a) - @test a * diagfill isa typeof(a) - @test a * diagfill == arr_a * arr_diagfill - @test a * rectdiag == arr_a * arr_rectdiag - @test a * rectdiagfill == arr_a * arr_rectdiagfill + using FillArrays: RectDiagonalFill, RectDiagonalZeros, RectDiagonalOnes, DiagonalFill, DiagonalOnes, DiagonalZeros + m = 3 + n = 3 + val = 2.0 + + # Create the instances as given. + instances = Dict( + :RectDiagonal => RectDiagonal(rand(m), m, n), + :RectDiagonalFill => RectDiagonal(Fill(val, min(m, n)), m, n), + :RectDiagonalZeros => RectDiagonal(Zeros(min(m, n)), m, n), + :RectDiagonalOnes => RectDiagonal(Ones(min(m, n)), m, n), + :Diagonal => Diagonal(rand(m)), + :DiagonalFill => Diagonal(Fill(val, m)), + :DiagonalZeros => Diagonal(Zeros(m)), + :DiagonalOnes => Diagonal(Ones(m)), + :Zeros => Zeros(m, n), + :Ones => Ones(m, n), + :Fill => Fill(val, m, n), + :Mat => rand(m, n), + ) + + mat_instances = Dict( + :TransZerosVec => Zeros(n), + :TransOnesVec => Ones(n), + :TransFillVec => Fill(1., n), + :TransVec => rand(n), + ) + + vec_instances = Dict( + :Vec => rand(n), + :ZerosVec => Zeros(n), + :OnesVec => Ones(n), + :FillVec => Fill(1., n) + ) + + # Expected outcome table. + # The header (in order) corresponds to the following instance symbols: + # :RectDiagonal, :RectDiagonalFill, :RectDiagonalZeros, :RectDiagonalOnes, + # :Diagonal, :DiagonalFill, :DiagonalZeros, :DiagonalOnes, :Zeros, :Ones, :Fill, :Mat + # Each row gives the expected resultant type when doing multiplication, + expected = Dict( + :RectDiagonal => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonal, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonal, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector + ), + :RectDiagonalFill => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonalFill, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector + ), + :RectDiagonalZeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros + ), + :RectDiagonalOnes => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => RectDiagonal, + :DiagonalFill => RectDiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => RectDiagonalOnes, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector + ), + :Diagonal => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonal, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonal, + :Diagonal => Diagonal, + :DiagonalFill => Diagonal, + :DiagonalZeros => Zeros, + :DiagonalOnes => Diagonal, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector + ), + :DiagonalFill => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonalFill, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonalFill, + :Diagonal => Diagonal, + :DiagonalFill => DiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => DiagonalFill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill + ), + :DiagonalZeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros + ), + :DiagonalOnes => Dict( + :RectDiagonal => RectDiagonal, + :RectDiagonalFill => RectDiagonalFill, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => RectDiagonalOnes, + :Diagonal => Diagonal, + :DiagonalFill => DiagonalFill, + :DiagonalZeros => Zeros, + :DiagonalOnes => DiagonalOnes, + :Zeros => Zeros, + :Ones => Ones, + :Fill => Fill, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Ones, + :FillVec => Fill + ), + :Zeros => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + :Vec => Zeros, + :ZerosVec => Zeros, + :OnesVec => Zeros, + :FillVec => Zeros + ), + :Ones => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Ones, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Fill, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill + ), + :Fill => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Fill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + :Vec => Fill, + :ZerosVec => Zeros, + :OnesVec => Fill, + :FillVec => Fill + ), + :Mat => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Array, + :DiagonalZeros => Zeros, + :DiagonalOnes => Array, + :Zeros => Zeros, + :Ones => Array, + :Fill => Array, + :Mat => Array, + :Vec => Vector, + :ZerosVec => Zeros, + :OnesVec => Vector, + :FillVec => Vector + ), + :TransZerosVec => Dict( + :RectDiagonal => Zeros, + :RectDiagonalFill => Zeros, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Zeros, + :Diagonal => Zeros, + :DiagonalFill => Zeros, + :DiagonalZeros => Zeros, + :DiagonalOnes => Zeros, + :Zeros => Zeros, + :Ones => Zeros, + :Fill => Zeros, + :Mat => Zeros, + ), + :TransOnesVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Ones, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + :TransFillVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Fill, + :DiagonalZeros => Zeros, + :DiagonalOnes => Fill, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + :TransVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Array, + :DiagonalZeros => Zeros, + :DiagonalOnes => Array, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), + ) + + for (k2, B) in instances + for op2 in (adjoint, transpose, identity) + for (k1, A) in instances + for op1 in (adjoint, transpose, identity) + @test typeof(op1(A) * op2(B)) <: expected[k1][k2] + end + end + + for (k1, A) in mat_instances + for op1 in (adjoint, transpose) + result = op1(A) * op2(B) + @test typeof(result) <: expected[k1][k2] || typeof(result.parent) <: expected[k1][k2] + if !(typeof(result) <: expected[k1][k2]) && !(typeof(result.parent) <: expected[k1][k2]) + @show op1 op2 typeof(op1(A)) typeof(op2(B)) typeof(result) expected[k1][k2] + end + end + end + end + end + + for (k1, A) in instances + for (k2, B) in vec_instances + for op1 in (adjoint, transpose, identity) + @test typeof(op1(A) * B) <: expected[k1][k2] + end + + if !(typeof(A*B) <: expected[k1][k2]) + @show k1 k2 expected[k1][k2] + end + end end end From 194df93fab126bf6a122e04dedd40f07466dbe86 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Sun, 16 Feb 2025 20:39:54 +0800 Subject: [PATCH 6/9] fix outer product --- src/fillalgebra.jl | 33 ++++++++++++++++++++++++++++++++- test/runtests.jl | 40 ++++++++++++++++++++++++++++++++-------- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index e3d0f248..1123b361 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -81,13 +81,44 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) +for type in (AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}) + @eval begin + function *(A::AbstractFillVector, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2)) + end + end +end # this treats a size (n,) vector as a nx1 matrix, so b needs to have 1 row # special cased, as OnesMatrix * OnesMatrix isn't a Ones *(a::AbstractOnesVector, b::AbstractOnesMatrix) = mult_ones(a, b) for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}) @eval begin - *(A::AbstractOnesVector, B::$type) = Ones{promote_type(eltype(A),eltype(B))}(size(A,1), size(B,2)) + *(A::AbstractOnesVector, B::$type) = mult_ones(A, B) + end +end + +for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}) + for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector) + @eval begin + function *(A::$type1, B::$type2) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end + end + end +end + +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}, ) + @eval begin + function *(A::AbstractZerosVector, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index b958ccf1..0254dace 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -386,7 +386,7 @@ end @test stringmime("text/plain", D) == "3×2 RectDiagonal{Float64, Vector{Float64}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n 1.0 ⋅ \n ⋅ 2.0\n ⋅ ⋅ " end -@testset "RectDiagonal Multiplication" begin +@testset "RectDiagonal multiplication" begin using FillArrays: RectDiagonalFill, RectDiagonalZeros, RectDiagonalOnes, DiagonalFill, DiagonalOnes, DiagonalZeros m = 3 n = 3 @@ -644,6 +644,30 @@ end :OnesVec => Vector, :FillVec => Vector ), + :Vec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, + ), + :ZerosVec => Dict( + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, + ), + :OnesVec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, + ), + :FillVec => Dict( + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, + ), :TransZerosVec => Dict( :RectDiagonal => Zeros, :RectDiagonalFill => Zeros, @@ -714,22 +738,22 @@ end for op1 in (adjoint, transpose) result = op1(A) * op2(B) @test typeof(result) <: expected[k1][k2] || typeof(result.parent) <: expected[k1][k2] - if !(typeof(result) <: expected[k1][k2]) && !(typeof(result.parent) <: expected[k1][k2]) - @show op1 op2 typeof(op1(A)) typeof(op2(B)) typeof(result) expected[k1][k2] - end end end end end - for (k1, A) in instances - for (k2, B) in vec_instances + for (k2, B) in vec_instances + for (k1, A) in instances for op1 in (adjoint, transpose, identity) @test typeof(op1(A) * B) <: expected[k1][k2] end - if !(typeof(A*B) <: expected[k1][k2]) - @show k1 k2 expected[k1][k2] + end + for (k1, A) in mat_instances + for op1 in (adjoint, transpose) + @test typeof(op1(A)*B) <: Number + @test typeof(B*op1(A)) <: expected[k2][k1] end end end From 9daec23fc5a18c8474019bef9a9e90a605c5db34 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Sun, 16 Feb 2025 21:14:46 +0800 Subject: [PATCH 7/9] Check matmul sizes --- src/fillalgebra.jl | 61 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 1123b361..099c40b7 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -648,7 +648,6 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...) end end -# TODO: add dim check to all abstract ones multiplication for type in linearalgebra_types @eval begin function *(A::$type, B::DiagonalFill) @@ -702,12 +701,18 @@ end for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix, AbstractOnesVector) @eval begin - *(A::DiagonalOnes, B::$type) = Ones{promote_type(eltype(A),eltype(B))}(size(B)) + function *(A::DiagonalOnes, B::$type) + check_matmul_sizes(A, B) + Ones{promote_type(eltype(A),eltype(B))}(size(B)) + end end end for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix) @eval begin - *(A::$type, B::DiagonalOnes) = Ones{promote_type(eltype(A),eltype(B))}(size(A)) + function *(A::$type, B::DiagonalOnes) + check_matmul_sizes(A, B) + Ones{promote_type(eltype(A),eltype(B))}(size(A)) + end end end @@ -715,14 +720,14 @@ for type in (DiagonalFill, DiagonalOnes) @eval begin function *(A::$type, B::RectDiagonalFill) check_matmul_sizes(A, B) - len = minimum(size(B)) - RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(B)) + len = Base.OneTo(minimum(size(B))) + RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(B)) end function *(A::RectDiagonalFill, B::$type) check_matmul_sizes(A, B) - len = minimum(size(A)) - RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(A)) + len = Base.OneTo(minimum(size(A))) + RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(A)) end end end @@ -730,25 +735,51 @@ end for type1 in (AbstractMatrix, Diagonal) for type2 in (Diagonal, DiagonalOnes, DiagonalFill) @eval begin - *(Da::DiagonalZeros, A::$type1, Db::$type2) = Zeros(Da) * A - *(Da::$type2, A::$type1, Db::DiagonalZeros) = A * Zeros(Db) + function *(Da::DiagonalZeros, A::$type1, Db::$type2) + check_matmul_sizes(A, Db) + Zeros(Da) * A + end + function *(Da::$type2, A::$type1, Db::DiagonalZeros) + check_matmul_sizes(Da, A) + A * Zeros(Db) + end end end for type2 in (Diagonal, DiagonalFill) @eval begin - *(Da::DiagonalOnes, A::$type1, Db::$type2) = A * Db - *(Da::$type2, A::$type1, Db::DiagonalOnes) = Da * A + function *(Da::DiagonalOnes, A::$type1, Db::$type2) + check_matmul_sizes(Da, A) + ones(eltype(Da)) * A * Db + end + function *(Da::$type2, A::$type1, Db::DiagonalOnes) + check_matmul_sizes(A, Db) + Da * A * ones(eltype(Db)) + end end end @eval begin *(Da::DiagonalZeros, A::$type1, Db::DiagonalZeros) = Zeros(Da) * A * Zeros(Db) - *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes) = A + function *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes) + check_matmul_sizes(Da, A) + check_matmul_sizes(A, Db) + (one(eltype(Da)) * one(eltype(Db))) * A + end - *(Da::DiagonalFill, A::$type1, Db::Diagonal) = getindex_value(Da.diag) * A * Db - *(Da::Diagonal, A::$type1, Db::DiagonalFill) = Da * A * getindex_value(Db.diag) - *(Da::DiagonalFill, A::$type1, Db::DiagonalFill) = getindex_value(Da.diag) * getindex_value(Db.diag) * A + function *(Da::DiagonalFill, A::$type1, Db::Diagonal) + check_matmul_sizes(Da, A) + getindex_value(Da.diag) * A * Db + end + function *(Da::Diagonal, A::$type1, Db::DiagonalFill) + check_matmul_sizes(A, Db) + Da * A * getindex_value(Db.diag) + end + function *(Da::DiagonalFill, A::$type1, Db::DiagonalFill) + check_matmul_sizes(Da, A) + check_matmul_sizes(A, Db) + (getindex_value(Da.diag) * getindex_value(Db.diag)) * A + end end end From 06a18c766ccaea807c4c0b558b00a59ee6b99ce0 Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Mon, 17 Feb 2025 16:34:06 +0800 Subject: [PATCH 8/9] optimize code & improve test coverage --- src/fillalgebra.jl | 130 +++++++---------------- src/fillbroadcast.jl | 15 +++ test/runtests.jl | 243 +++++++++++++++++++++++++++++++++---------- 3 files changed, 239 insertions(+), 149 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 099c40b7..2937c3ca 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -81,13 +81,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) *(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b) *(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b) -for type in (AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}) +for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector}) @eval begin function *(A::AbstractFillVector, B::$type) size(A,2) == size(B,1) || throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2)) end + function *(A::AbstractFillMatrix, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2)) + end end end @@ -97,11 +102,12 @@ end for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}) @eval begin *(A::AbstractOnesVector, B::$type) = mult_ones(A, B) + *(A::AbstractOnesMatrix, B::$type) = mult_ones(A, B) end end for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector}) - for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector) + for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix) @eval begin function *(A::$type1, B::$type2) size(A,2) == size(B,1) || @@ -119,6 +125,11 @@ for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<: throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) end + function *(A::AbstractZerosMatrix, B::$type) + size(A,2) == size(B,1) || + throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))")) + Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2)) + end end end @@ -630,11 +641,11 @@ end # DiagonalFill Multiplication const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V} const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V} -linearalgebra_types = (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix, +mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular, LowerTriangular, UpperTriangular, LinearAlgebra.AbstractTriangular, Symmetric, Hermitian, SymTridiagonal, UpperHessenberg, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement) -for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...) +for type in tuple(AbstractVector, AbstractZerosVector, mat_types...) @eval begin function *(A::DiagonalFill, B::$type) check_matmul_sizes(A, B) @@ -643,12 +654,13 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...) *(A::DiagonalZeros, B::$type) = Zeros(A) * B function *(A::DiagonalOnes, B::$type) check_matmul_sizes(A, B) - one(eltype(A)) * B + convert(AbstractArray{promote_type(eltype(A), eltype(B))}, deepcopy(B)) end end end +*(A::DiagonalOnes, B::AbstractRange) = one(eltype(A)) * B -for type in linearalgebra_types +for type in mat_types @eval begin function *(A::$type, B::DiagonalFill) check_matmul_sizes(A, B) @@ -657,7 +669,7 @@ for type in linearalgebra_types *(A::$type, B::DiagonalZeros) = A * Zeros(B) function *(A::$type, B::DiagonalOnes) check_matmul_sizes(A, B) - one(eltype(B)) * A + convert(AbstractMatrix{promote_type(eltype(A), eltype(B))}, deepcopy(A)) end end end @@ -669,53 +681,22 @@ for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros) *(A::$type1, B::$type2) = A * Zeros(B) end end -end - -for type in (DiagonalFill, DiagonalOnes, RectDiagonalFill) @eval begin - *(A::$type, B::DiagonalZeros) = A * Zeros(B) - *(A::DiagonalZeros, B::$type) = Zeros(A) * B + *(A::Diagonal, B::$type1) = Diagonal(A.diag .* B.diag) + *(A::$type1, B::Diagonal) = Diagonal(A.diag .* B.diag) end end -function *(A::DiagonalZeros, B::DiagonalZeros) - check_matmul_sizes(A, B) - Zeros{promote_type(eltype(A),eltype(B))}(A) -end -for type1 in (DiagonalFill, DiagonalOnes) - for type2 in (DiagonalFill, DiagonalOnes) - if type1 !== DiagonalOnes || type2 !== DiagonalOnes - @eval begin - function *(A::$type1, B::$type2) - check_matmul_sizes(A, B) - getindex_value(A.diag) * B - end - end - end - end -end -function *(A::DiagonalOnes, B::DiagonalOnes) - check_matmul_sizes(A, B) - Diagonal(Ones{promote_type(eltype(A), eltype(B))}(size(A, 1))) -end - -for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix, AbstractOnesVector) - @eval begin - function *(A::DiagonalOnes, B::$type) - check_matmul_sizes(A, B) - Ones{promote_type(eltype(A),eltype(B))}(size(B)) - end - end -end -for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix) - @eval begin - function *(A::$type, B::DiagonalOnes) - check_matmul_sizes(A, B) - Ones{promote_type(eltype(A),eltype(B))}(size(A)) +for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros) + @eval begin + *(A::$type1, B::$type2) = Diagonal(A.diag .* B.diag) end end end +*(A::RectDiagonalFill, B::DiagonalZeros) = A * Zeros(B) +*(A::DiagonalZeros, B::RectDiagonalFill) = Zeros(A) * B for type in (DiagonalFill, DiagonalOnes) @eval begin function *(A::$type, B::RectDiagonalFill) @@ -732,54 +713,15 @@ for type in (DiagonalFill, DiagonalOnes) end end -for type1 in (AbstractMatrix, Diagonal) - for type2 in (Diagonal, DiagonalOnes, DiagonalFill) - @eval begin - function *(Da::DiagonalZeros, A::$type1, Db::$type2) - check_matmul_sizes(A, Db) - Zeros(Da) * A - end - function *(Da::$type2, A::$type1, Db::DiagonalZeros) - check_matmul_sizes(Da, A) - A * Zeros(Db) - end - end - end - - for type2 in (Diagonal, DiagonalFill) - @eval begin - function *(Da::DiagonalOnes, A::$type1, Db::$type2) - check_matmul_sizes(Da, A) - ones(eltype(Da)) * A * Db - end - function *(Da::$type2, A::$type1, Db::DiagonalOnes) - check_matmul_sizes(A, Db) - Da * A * ones(eltype(Db)) - end - end - end - - @eval begin - *(Da::DiagonalZeros, A::$type1, Db::DiagonalZeros) = Zeros(Da) * A * Zeros(Db) - function *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes) - check_matmul_sizes(Da, A) - check_matmul_sizes(A, Db) - (one(eltype(Da)) * one(eltype(Db))) * A - end - - function *(Da::DiagonalFill, A::$type1, Db::Diagonal) - check_matmul_sizes(Da, A) - getindex_value(Da.diag) * A * Db - end - function *(Da::Diagonal, A::$type1, Db::DiagonalFill) - check_matmul_sizes(A, Db) - Da * A * getindex_value(Db.diag) - end - function *(Da::DiagonalFill, A::$type1, Db::DiagonalFill) - check_matmul_sizes(Da, A) - check_matmul_sizes(A, Db) - (getindex_value(Da.diag) * getindex_value(Db.diag)) * A - end +function *(Da::Diagonal, A::RectDiagonal, Db::Diagonal) + check_matmul_sizes(Da, A) + check_matmul_sizes(A, Db) + len = Base.OneTo(minimum(size(A))) + diag = view(Da.diag, len) .* view(A.diag, len) .* view(Db.diag, len) + if diag isa Zeros + Zeros{eltype(diag)}(axes(A)) + else + RectDiagonal(diag, axes(A)) end end diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 2b5ea59c..0b0b4c3e 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -259,6 +259,21 @@ broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where { broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r)) broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r)) +# ternary broadcasting +for type1 in (AbstractArray, AbstractFill, AbstractZeros) + for type2 in (AbstractArray, AbstractFill, AbstractZeros) + for type3 in (AbstractArray, AbstractFill, AbstractZeros) + if type1 === AbstractZeros || type2 === AbstractZeros || type3 === AbstractZeros + @eval begin + broadcasted(::DefaultArrayStyle, ::typeof(*), a::$type1, b::$type2, c::$type3) = Zeros{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c))) + end + end + end + end +end +broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractOnes, b::AbstractOnes, c::AbstractOnes) = Ones{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c))) +broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractFill, c::AbstractFill) = Fill(getindex_value(a)*getindex_value(b)*getindex_value(c), broadcast_shape(axes(a), axes(b), axes(c))) + # support AbstractFill .^ k broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r)) broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r)) diff --git a/test/runtests.jl b/test/runtests.jl index 0254dace..9035fa7e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -388,40 +388,83 @@ end @testset "RectDiagonal multiplication" begin using FillArrays: RectDiagonalFill, RectDiagonalZeros, RectDiagonalOnes, DiagonalFill, DiagonalOnes, DiagonalZeros - m = 3 - n = 3 + val = 2.0 - # Create the instances as given. - instances = Dict( + n = 3 + square_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(n), n, n), + :RectDiagonalFill => RectDiagonal(Fill(val, n), n, n), + :RectDiagonalZeros => RectDiagonal(Zeros(n), n, n), + :RectDiagonalOnes => RectDiagonal(Ones(n), n, n), + :Diagonal => Diagonal(rand(n)), + :DiagonalFill => Diagonal(Fill(val, n)), + :DiagonalZeros => Diagonal(Zeros(n)), + :DiagonalOnes => Diagonal(Ones(n)), + :Zeros => Zeros(n, n), + :Ones => Ones(n, n), + :Fill => Fill(val, n, n), + :Mat => rand(n, n), + ) + + m = 1 + n = 3 + row_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(m), m, n), + :RectDiagonalFill => RectDiagonal(Fill(val, min(m, n)), m, n), + :RectDiagonalZeros => RectDiagonal(Zeros(min(m, n)), m, n), + :RectDiagonalOnes => RectDiagonal(Ones(min(m, n)), m, n), + :Zeros => Zeros(m, n), + :Ones => Ones(m, n), + :Fill => Fill(val, m, n), + :Mat => rand(m, n), + ) + + m = 3 + n = 1 + col_mat_instances = Dict( :RectDiagonal => RectDiagonal(rand(m), m, n), :RectDiagonalFill => RectDiagonal(Fill(val, min(m, n)), m, n), :RectDiagonalZeros => RectDiagonal(Zeros(min(m, n)), m, n), :RectDiagonalOnes => RectDiagonal(Ones(min(m, n)), m, n), - :Diagonal => Diagonal(rand(m)), - :DiagonalFill => Diagonal(Fill(val, m)), - :DiagonalZeros => Diagonal(Zeros(m)), - :DiagonalOnes => Diagonal(Ones(m)), :Zeros => Zeros(m, n), :Ones => Ones(m, n), :Fill => Fill(val, m, n), :Mat => rand(m, n), ) - mat_instances = Dict( + n = 3 + trans_vec_instances = Dict( + :TransVec => rand(n), :TransZerosVec => Zeros(n), :TransOnesVec => Ones(n), - :TransFillVec => Fill(1., n), - :TransVec => rand(n), + :TransFillVec => Fill(val, n), ) vec_instances = Dict( :Vec => rand(n), :ZerosVec => Zeros(n), :OnesVec => Ones(n), - :FillVec => Fill(1., n) + :FillVec => Fill(val, n) + ) + + n = 1 + one_dim_mat_instances = Dict( + :RectDiagonal => RectDiagonal(rand(n), n, n), + :RectDiagonalFill => RectDiagonal(Fill(val, n), n, n), + :RectDiagonalZeros => RectDiagonal(Zeros(n), n, n), + :RectDiagonalOnes => RectDiagonal(Ones(n), n, n), + :Diagonal => Diagonal(rand(n)), + :DiagonalFill => Diagonal(Fill(val, n)), + :DiagonalZeros => Diagonal(Zeros(n)), + :DiagonalOnes => Diagonal(Ones(n)), + :Zeros => Zeros(n, n), + :Ones => Ones(n, n), + :Fill => Fill(val, n, n), + :Mat => rand(n, n), ) + # Expected outcome table. # The header (in order) corresponds to the following instance symbols: # :RectDiagonal, :RectDiagonalFill, :RectDiagonalZeros, :RectDiagonalOnes, @@ -444,7 +487,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Vector, - :FillVec => Vector + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, ), :RectDiagonalFill => Dict( :RectDiagonal => RectDiagonal, @@ -462,7 +509,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Vector, - :FillVec => Vector + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, ), :RectDiagonalZeros => Dict( :RectDiagonal => Zeros, @@ -480,7 +531,11 @@ end :Vec => Zeros, :ZerosVec => Zeros, :OnesVec => Zeros, - :FillVec => Zeros + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, ), :RectDiagonalOnes => Dict( :RectDiagonal => RectDiagonal, @@ -498,7 +553,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Vector, - :FillVec => Vector + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, ), :Diagonal => Dict( :RectDiagonal => RectDiagonal, @@ -507,7 +566,7 @@ end :RectDiagonalOnes => RectDiagonal, :Diagonal => Diagonal, :DiagonalFill => Diagonal, - :DiagonalZeros => Zeros, + :DiagonalZeros => DiagonalZeros, :DiagonalOnes => Diagonal, :Zeros => Zeros, :Ones => Array, @@ -516,7 +575,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Vector, - :FillVec => Vector + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, ), :DiagonalFill => Dict( :RectDiagonal => RectDiagonal, @@ -525,7 +588,7 @@ end :RectDiagonalOnes => RectDiagonalFill, :Diagonal => Diagonal, :DiagonalFill => DiagonalFill, - :DiagonalZeros => Zeros, + :DiagonalZeros => DiagonalZeros, :DiagonalOnes => DiagonalFill, :Zeros => Zeros, :Ones => Fill, @@ -534,17 +597,21 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Fill, - :FillVec => Fill + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, ), :DiagonalZeros => Dict( :RectDiagonal => Zeros, :RectDiagonalFill => Zeros, :RectDiagonalZeros => Zeros, :RectDiagonalOnes => Zeros, - :Diagonal => Zeros, - :DiagonalFill => Zeros, - :DiagonalZeros => Zeros, - :DiagonalOnes => Zeros, + :Diagonal => DiagonalZeros, + :DiagonalFill => DiagonalZeros, + :DiagonalZeros => DiagonalZeros, + :DiagonalOnes => DiagonalZeros, :Zeros => Zeros, :Ones => Zeros, :Fill => Zeros, @@ -552,7 +619,11 @@ end :Vec => Zeros, :ZerosVec => Zeros, :OnesVec => Zeros, - :FillVec => Zeros + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, ), :DiagonalOnes => Dict( :RectDiagonal => RectDiagonal, @@ -561,7 +632,7 @@ end :RectDiagonalOnes => RectDiagonalOnes, :Diagonal => Diagonal, :DiagonalFill => DiagonalFill, - :DiagonalZeros => Zeros, + :DiagonalZeros => DiagonalZeros, :DiagonalOnes => DiagonalOnes, :Zeros => Zeros, :Ones => Ones, @@ -570,7 +641,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Ones, - :FillVec => Fill + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, ), :Zeros => Dict( :RectDiagonal => Zeros, @@ -588,7 +663,11 @@ end :Vec => Zeros, :ZerosVec => Zeros, :OnesVec => Zeros, - :FillVec => Zeros + :FillVec => Zeros, + :TransVec => Zeros, + :TransZerosVec => Zeros, + :TransOnesVec => Zeros, + :TransFillVec => Zeros, ), :Ones => Dict( :RectDiagonal => Array, @@ -606,7 +685,11 @@ end :Vec => Fill, :ZerosVec => Zeros, :OnesVec => Fill, - :FillVec => Fill + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Ones, + :TransFillVec => Fill, ), :Fill => Dict( :RectDiagonal => Array, @@ -624,7 +707,11 @@ end :Vec => Fill, :ZerosVec => Zeros, :OnesVec => Fill, - :FillVec => Fill + :FillVec => Fill, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Fill, + :TransFillVec => Fill, ), :Mat => Dict( :RectDiagonal => Array, @@ -642,7 +729,11 @@ end :Vec => Vector, :ZerosVec => Zeros, :OnesVec => Vector, - :FillVec => Vector + :FillVec => Vector, + :TransVec => Array, + :TransZerosVec => Zeros, + :TransOnesVec => Array, + :TransFillVec => Array, ), :Vec => Dict( :TransVec => Array, @@ -668,6 +759,20 @@ end :TransOnesVec => Fill, :TransFillVec => Fill, ), + :TransVec => Dict( + :RectDiagonal => Array, + :RectDiagonalFill => Array, + :RectDiagonalZeros => Zeros, + :RectDiagonalOnes => Array, + :Diagonal => Array, + :DiagonalFill => Array, + :DiagonalZeros => Zeros, + :DiagonalOnes => Array, + :Zeros => Zeros, + :Ones => Fill, + :Fill => Fill, + :Mat => Array, + ), :TransZerosVec => Dict( :RectDiagonal => Zeros, :RectDiagonalFill => Zeros, @@ -710,53 +815,81 @@ end :Fill => Fill, :Mat => Array, ), - :TransVec => Dict( - :RectDiagonal => Array, - :RectDiagonalFill => Array, - :RectDiagonalZeros => Zeros, - :RectDiagonalOnes => Array, - :Diagonal => Array, - :DiagonalFill => Array, - :DiagonalZeros => Zeros, - :DiagonalOnes => Array, - :Zeros => Zeros, - :Ones => Fill, - :Fill => Fill, - :Mat => Array, - ), ) - for (k2, B) in instances + for (k2, B) in square_mat_instances for op2 in (adjoint, transpose, identity) - for (k1, A) in instances + for (k1, A) in square_mat_instances for op1 in (adjoint, transpose, identity) - @test typeof(op1(A) * op2(B)) <: expected[k1][k2] + result = op1(A) * op2(B) + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] end end - for (k1, A) in mat_instances + for (k1, A) in trans_vec_instances for op1 in (adjoint, transpose) result = op1(A) * op2(B) - @test typeof(result) <: expected[k1][k2] || typeof(result.parent) <: expected[k1][k2] + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] end end end end for (k2, B) in vec_instances - for (k1, A) in instances + for (k1, A) in square_mat_instances for op1 in (adjoint, transpose, identity) - @test typeof(op1(A) * B) <: expected[k1][k2] + @test op1(A) * B isa expected[k1][k2] end end - for (k1, A) in mat_instances + for (k1, A) in trans_vec_instances for op1 in (adjoint, transpose) - @test typeof(op1(A)*B) <: Number - @test typeof(B*op1(A)) <: expected[k2][k1] + @test op1(A)*B isa Number + @test B*op1(A) isa expected[k2][k1] + end + end + end + + for (k1, A) in Iterators.flatten((col_mat_instances, one_dim_mat_instances)) + for (k2, B) in trans_vec_instances + for op2 in (adjoint, transpose) + result = A * op2(B) + @test result isa expected[k1][k2] || result.parent isa expected[k1][k2] end end end + + for (k1, A) in row_mat_instances + for op1 in (adjoint, transpose) + for (k2, B) in trans_vec_instances + for op2 in (adjoint, transpose) + @test op1(A) * op2(B) isa expected[k1][k2] + end + end + end + end + + num = 3. + @test square_mat_instances[:RectDiagonalFill] * num == num *square_mat_instances[:RectDiagonalFill] == num * Matrix(square_mat_instances[:RectDiagonalFill]) + + for (k1, Da) in square_mat_instances + for (k2, Db) in square_mat_instances + for (k3, A) in square_mat_instances + @test typeof(Da * A * Db) === typeof((Da * A) * Db) === typeof((Da * (A * Db))) + if !(typeof(Da * A * Db) === typeof((Da * A) * Db) === typeof((Da * (A * Db)))) + @show typeof(Da) typeof(A) typeof(Db) typeof(Da*A*Db) typeof((Da*A)*Db) + end + end + end + end + + ind = (1, 2) + sz = (3, 3) + oneele = OneElement(val, ind, sz) + @test oneele * Diagonal(Zeros(3)) === Diagonal(Zeros(3)) * oneele === Zeros(3,3) + @test oneele * Diagonal(Fill(val, 3)) === Diagonal(Fill(val, 3)) * oneele === OneElement(val*val, ind, sz) + @test oneele * Diagonal(Ones(3)) === Diagonal(Ones(3)) * oneele === oneele + end # Check that all pair-wise combinations of + / - elements of As and Bs yield the correct From 8c374a1212ac74b5550693f8b1794001355a5a1e Mon Sep 17 00:00:00 2001 From: Chenyang Wu Date: Mon, 17 Feb 2025 17:35:08 +0800 Subject: [PATCH 9/9] fix julia-1.10.8 ambiguities --- src/fillalgebra.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 2937c3ca..0b81d71d 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -548,7 +548,7 @@ function *(A::Diagonal, B::RectDiagonal) RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), (size(A, 1), size(B, 2))) end -for type in (AbstractMatrix, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) +for type in (AbstractMatrix, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) @eval begin function *(A::RectDiagonal, B::$type) check_matmul_sizes(A, B) @@ -614,7 +614,7 @@ for type in (RectDiagonal, RectDiagonalZeros) end end -for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) +for type in (AbstractMatrix, RectDiagonal, Diagonal, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}, AbstractTriangular) @eval begin function *(A::$type, B::RectDiagonalZeros) check_matmul_sizes(A, B) @@ -643,8 +643,8 @@ const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V} const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V} mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix, AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular, - LowerTriangular, UpperTriangular, LinearAlgebra.AbstractTriangular, Symmetric, Hermitian, - SymTridiagonal, UpperHessenberg, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement) + LowerTriangular, UpperTriangular, AbstractTriangular, Symmetric, Hermitian, LinearAlgebra.HermOrSym, + SymTridiagonal, UpperHessenberg, LinearAlgebra.AdjOrTransAbsMat, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement) for type in tuple(AbstractVector, AbstractZerosVector, mat_types...) @eval begin function *(A::DiagonalFill, B::$type)