diff --git a/Project.toml b/Project.toml index 17c28a38..0a5c562e 100644 --- a/Project.toml +++ b/Project.toml @@ -9,10 +9,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] FillArraysPDMatsExt = "PDMats" FillArraysSparseArraysExt = "SparseArrays" +FillArraysStaticArraysExt = "StaticArrays" FillArraysStatisticsExt = "Statistics" [compat] diff --git a/ext/FillArraysSparseArraysExt.jl b/ext/FillArraysSparseArraysExt.jl index 4cb29dcd..68bfc5dd 100644 --- a/ext/FillArraysSparseArraysExt.jl +++ b/ext/FillArraysSparseArraysExt.jl @@ -4,7 +4,7 @@ using SparseArrays using SparseArrays: SparseVectorUnion import Base: convert, kron using FillArrays -using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot +using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot, OneElementVector, OneElementMatrix # Specifying the full namespace is necessary because of https://github.com/JuliaLang/julia/issues/48533 # See https://github.com/JuliaStats/LogExpFunctions.jl/pull/63 using FillArrays.LinearAlgebra @@ -63,4 +63,16 @@ end # Ambiguity. see #178 dot(x::AbstractFillVector, y::SparseVectorUnion) = _fill_dot(x, y) +# OneElements with different indices should return +# SparseArrays under addition +function FillArrays.oneelement_addsub(a::OneElementVector, b::OneElementVector, aval, bval) + return sparsevec([a.ind[1], b.ind[1]], [aval, bval], length(a)) +end + +function FillArrays.oneelement_addsub(a::OneElementMatrix, b::OneElementMatrix, aval, bval) + nzval = [aval, bval] + nzind = ntuple(i -> [a.ind[i], b.ind[i]], ndims(a)) + return sparse(nzind..., nzval, size(a)...) +end + end # module diff --git a/ext/FillArraysStaticArraysExt.jl b/ext/FillArraysStaticArraysExt.jl new file mode 100644 index 00000000..f120be1a --- /dev/null +++ b/ext/FillArraysStaticArraysExt.jl @@ -0,0 +1,28 @@ +module FillArraysStaticArraysExt + +using FillArrays +using StaticArrays + +import Base: promote_op +import FillArrays: elconvert + +# Disambiguity methods for StaticArrays + +function Base.:+(a::FillArrays.Zeros, b::StaticArray) + promote_shape(a,b) + return elconvert(promote_op(+,eltype(a),eltype(b)),b) +end +function Base.:+(a::StaticArray, b::FillArrays.Zeros) + promote_shape(a,b) + return elconvert(promote_op(+,eltype(a),eltype(b)),a) +end +function Base.:-(a::StaticArray, b::FillArrays.Zeros) + promote_shape(a,b) + return elconvert(promote_op(-,eltype(a),eltype(b)),a) +end +function Base.:-(a::FillArrays.Zeros, b::StaticArray) + promote_shape(a,b) + return elconvert(promote_op(-,eltype(a),eltype(b)),-b) +end + +end # module diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index deb7e74d..62e390ef 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -436,16 +436,29 @@ function +(a::AbstractZeros{T}, b::AbstractZeros{V}) where {T, V} # for disambig promote_shape(a,b) return elconvert(promote_op(+,T,V),a) end -# no AbstractArray. Otherwise incompatible with StaticArrays.jl -# AbstractFill for disambiguity -for TYPE in (:Array, :AbstractFill, :AbstractRange, :Diagonal) + +# AbstractFill and Array for disambiguity +for TYPE in (:Array, :AbstractFill, :AbstractRange, :AbstractArray) @eval function +(a::$TYPE{T}, b::AbstractZeros{V}) where {T, V} promote_shape(a,b) return elconvert(promote_op(+,T,V),a) end + @eval function -(a::$TYPE{T}, b::AbstractZeros{V}) where {T, V} + promote_shape(a,b) + return elconvert(promote_op(-,T,V),a) + end + @eval function -(a::AbstractZeros{T}, b::$TYPE{V}) where {T, V} + promote_shape(a,b) + return elconvert(promote_op(-,T,V),-b) + end @eval +(a::AbstractZeros, b::$TYPE) = b + a end +function -(a::AbstractZeros, b::AbstractZeros) + promote_shape(a,b) + return elconvert(promote_type(eltype(a), eltype(b)),-b) +end + # for VERSION other than 1.6, could use ZerosMatrix only function +(a::AbstractFillMatrix{T}, b::UniformScaling) where {T} n = checksquare(a) diff --git a/src/oneelement.jl b/src/oneelement.jl index f2db8565..dabfbf06 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -148,6 +148,90 @@ end /(x::OneElement, b::Number) = OneElement(x.val / b, x.ind, x.axes) \(b::Number, x::OneElement) = OneElement(b \ x.val, x.ind, x.axes) +# Addition + +# O(1) addition with arbitrary array types +function add_one_elem(a::OneElement, b::AbstractArray) + axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) + + ret = copy(b) + try + ret[a.ind...] += getindex_value(a) + catch + # Fallback to materialising dense array if setindex! + # goes wrong (e.g on a Diagonal) + ret = Array(ret) + checkbounds(Bool, ret, a.ind...) && (ret[a.ind...] += getindex_value(a)) + end + return ret +end + +function sub_one_elem(a::AbstractArray, b::OneElement) + axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) + ret = copy(a) + try + ret[b.ind...] -= getindex_value(b) + catch + # Fallback to materialising dense array if setindex! + # goes wrong (e.g on a Diagonal) + ret = Array(ret) + checkbounds(Bool, ret, b.ind...) && (ret[b.ind...] -= getindex_value(b)) + end + return ret +end + +function sub_one_elem(a::OneElement, b::AbstractArray) + axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) + ret = copy(-b) + try + ret[a.ind...] += getindex_value(a) + catch + # Fallback to materialising dense array if setindex! + # goes wrong (e.g on a Diagonal) + ret = Array(ret) + checkbounds(Bool, ret, a.ind...) && (ret[a.ind...] += getindex_value(a)) + end + return ret +end + ++(a::OneElement, b::AbstractArray) = add_one_elem(a, b) ++(a::AbstractArray, b::OneElement) = add_one_elem(b, a) +-(a::AbstractArray, b::OneElement) = sub_one_elem(a, b) +-(a::OneElement, b::AbstractArray) = sub_one_elem(a, b) +# disambiguity +function +(a::AbstractZeros, b::OneElement) + promote_shape(a,b) + return elconvert(promote_op(+,eltype(a),eltype(b)),b) +end + ++(a::OneElement, b::AbstractZeros) = b + a + +# Adding/subtracting OneElements +# (Without SparseArrays) materialise dense vector if indices are different + +# Sparse Arrays extension overrides this for OneElementVector and OneElementMatrix +function oneelement_addsub(a::OneElement, b::OneElement, aval, bval) + ret = similar(a) + fill!(ret, zero(eltype(ret))) + ret[a.ind...] = aval + ret[b.ind...] = bval + return ret +end + +for (op, bop) in (:+ => :(getindex_value(b)), + :- => :(-getindex_value(b))) + @eval begin + function $op(a::OneElement, b::OneElement) + axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) + if a.ind == b.ind + return OneElement($op(getindex_value(a), getindex_value(b)), a.ind, axes(a)) + else + return oneelement_addsub(a, b, getindex_value(a), $bop) + end + end + end +end + # matrix-vector and matrix-matrix multiplication # Fill and OneElement diff --git a/test/runtests.jl b/test/runtests.jl index 174075cc..7ab9585a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -505,6 +505,14 @@ end for A in As, Z in (TZ -> Zeros{TZ}(3)).((Int, Float64, Int8, ComplexF64)) test_addition_and_subtraction_dim_mismatch(A, Z) end + + # Zeros should be additive identity for matrices + D = Diagonal([1, 1]) + Z = Zeros(2, 2) + @test D + Z isa Diagonal + @test D + Z == D + @test D - Z == D + @test Z - D == -D end end @@ -588,6 +596,12 @@ end testsparsediag(E) end end + + # Adding one elements with different indices should return a sparse array. + A = OneElement(2, 5) + B = OneElement(3, 5) + @test A + B isa SparseVector + @test A + B == [0, 1, 1, 0, 0] end @testset "==" begin @@ -2374,6 +2388,19 @@ end @test_throws ArgumentError isassigned(f, true) end + @testset "Addition/Subtraction" begin + A = OneElement(2, 5) + B = OneElement(3, 5) + + @test A + A isa OneElement + @test A + A == OneElement(2, 2, 5) + @test A + B == [0, 1, 1, 0, 0] + + @test B - B isa OneElement + @test B - B == OneElement(0, 2, 5) + @test B - A == [0, -1, 1, 0, 0] + end + @testset "matmul" begin A = reshape(Float64[1:9;], 3, 3) v = reshape(Float64[1:3;], 3)