-
Notifications
You must be signed in to change notification settings - Fork 39
Zeros and OneElement addition/subtraction #422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ba5a543
90e4cfa
a99e4aa
159119f
6d37f34
33466e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ using SparseArrays | |
| using SparseArrays: SparseVectorUnion | ||
| import Base: convert, kron | ||
| using FillArrays | ||
| using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot | ||
| using FillArrays: RectDiagonalFill, RectOrDiagonalFill, ZerosVector, ZerosMatrix, getindex_value, AbstractFillVector, _fill_dot, OneElementVector, OneElementMatrix | ||
| # Specifying the full namespace is necessary because of https://github.com/JuliaLang/julia/issues/48533 | ||
| # See https://github.com/JuliaStats/LogExpFunctions.jl/pull/63 | ||
| using FillArrays.LinearAlgebra | ||
|
|
@@ -63,4 +63,16 @@ end | |
| # Ambiguity. see #178 | ||
| dot(x::AbstractFillVector, y::SparseVectorUnion) = _fill_dot(x, y) | ||
|
|
||
| # OneElements with different indices should return | ||
| # SparseArrays under addition | ||
| function FillArrays.oneelement_addsub(a::OneElementVector, b::OneElementVector, aval, bval) | ||
| return sparsevec([a.ind[1], b.ind[1]], [aval, bval], length(a)) | ||
| end | ||
|
|
||
| function FillArrays.oneelement_addsub(a::OneElementMatrix, b::OneElementMatrix, aval, bval) | ||
| nzval = [aval, bval] | ||
| nzind = ntuple(i -> [a.ind[i], b.ind[i]], ndims(a)) | ||
| return sparse(nzind..., nzval, size(a)...) | ||
| end | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ADd one element + sparse array special case |
||
| end # module | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| module FillArraysStaticArraysExt | ||
|
|
||
| using FillArrays | ||
| using StaticArrays | ||
|
|
||
| import Base: promote_op | ||
| import FillArrays: elconvert | ||
|
|
||
| # Disambiguity methods for StaticArrays | ||
|
|
||
| function Base.:+(a::FillArrays.Zeros, b::StaticArray) | ||
| promote_shape(a,b) | ||
| return elconvert(promote_op(+,eltype(a),eltype(b)),b) | ||
| end | ||
| function Base.:+(a::StaticArray, b::FillArrays.Zeros) | ||
| promote_shape(a,b) | ||
| return elconvert(promote_op(+,eltype(a),eltype(b)),a) | ||
| end | ||
| function Base.:-(a::StaticArray, b::FillArrays.Zeros) | ||
| promote_shape(a,b) | ||
| return elconvert(promote_op(-,eltype(a),eltype(b)),a) | ||
| end | ||
| function Base.:-(a::FillArrays.Zeros, b::StaticArray) | ||
| promote_shape(a,b) | ||
| return elconvert(promote_op(-,eltype(a),eltype(b)),-b) | ||
| end | ||
|
|
||
| end # module |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,6 +148,90 @@ end | |
| /(x::OneElement, b::Number) = OneElement(x.val / b, x.ind, x.axes) | ||
| \(b::Number, x::OneElement) = OneElement(b \ x.val, x.ind, x.axes) | ||
|
|
||
| # Addition | ||
|
|
||
| # O(1) addition with arbitrary array types | ||
| function add_one_elem(a::OneElement, b::AbstractArray) | ||
| axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) | ||
|
|
||
| ret = copy(b) | ||
| try | ||
| ret[a.ind...] += getindex_value(a) | ||
| catch | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a non-standard idiom and introduces a type instability |
||
| # Fallback to materialising dense array if setindex! | ||
| # goes wrong (e.g on a Diagonal) | ||
| ret = Array(ret) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type unstable |
||
| checkbounds(Bool, ret, a.ind...) && (ret[a.ind...] += getindex_value(a)) | ||
| end | ||
| return ret | ||
| end | ||
|
|
||
| function sub_one_elem(a::AbstractArray, b::OneElement) | ||
| axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) | ||
| ret = copy(a) | ||
| try | ||
| ret[b.ind...] -= getindex_value(b) | ||
| catch | ||
| # Fallback to materialising dense array if setindex! | ||
| # goes wrong (e.g on a Diagonal) | ||
| ret = Array(ret) | ||
| checkbounds(Bool, ret, b.ind...) && (ret[b.ind...] -= getindex_value(b)) | ||
| end | ||
| return ret | ||
| end | ||
|
|
||
| function sub_one_elem(a::OneElement, b::AbstractArray) | ||
| axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) | ||
| ret = copy(-b) | ||
| try | ||
| ret[a.ind...] += getindex_value(a) | ||
| catch | ||
| # Fallback to materialising dense array if setindex! | ||
| # goes wrong (e.g on a Diagonal) | ||
| ret = Array(ret) | ||
| checkbounds(Bool, ret, a.ind...) && (ret[a.ind...] += getindex_value(a)) | ||
| end | ||
| return ret | ||
| end | ||
|
|
||
| +(a::OneElement, b::AbstractArray) = add_one_elem(a, b) | ||
| +(a::AbstractArray, b::OneElement) = add_one_elem(b, a) | ||
| -(a::AbstractArray, b::OneElement) = sub_one_elem(a, b) | ||
| -(a::OneElement, b::AbstractArray) = sub_one_elem(a, b) | ||
| # disambiguity | ||
| function +(a::AbstractZeros, b::OneElement) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we be overloading broadcast + too? |
||
| promote_shape(a,b) | ||
| return elconvert(promote_op(+,eltype(a),eltype(b)),b) | ||
| end | ||
|
|
||
| +(a::OneElement, b::AbstractZeros) = b + a | ||
|
|
||
| # Adding/subtracting OneElements | ||
| # (Without SparseArrays) materialise dense vector if indices are different | ||
|
|
||
| # Sparse Arrays extension overrides this for OneElementVector and OneElementMatrix | ||
| function oneelement_addsub(a::OneElement, b::OneElement, aval, bval) | ||
| ret = similar(a) | ||
| fill!(ret, zero(eltype(ret))) | ||
| ret[a.ind...] = aval | ||
| ret[b.ind...] = bval | ||
| return ret | ||
| end | ||
|
|
||
| for (op, bop) in (:+ => :(getindex_value(b)), | ||
| :- => :(-getindex_value(b))) | ||
| @eval begin | ||
| function $op(a::OneElement, b::OneElement) | ||
| axes(a) == axes(b) || throw(DimensionMismatch(LazyString("A has dimensions ", size(a), " but B has dimensions ", size(b)))) | ||
| if a.ind == b.ind | ||
| return OneElement($op(getindex_value(a), getindex_value(b)), a.ind, axes(a)) | ||
| else | ||
| return oneelement_addsub(a, b, getindex_value(a), $bop) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type unstable |
||
| end | ||
| end | ||
| end | ||
| end | ||
|
|
||
| # matrix-vector and matrix-matrix multiplication | ||
|
|
||
| # Fill and OneElement | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -505,6 +505,14 @@ end | |||||
| for A in As, Z in (TZ -> Zeros{TZ}(3)).((Int, Float64, Int8, ComplexF64)) | ||||||
| test_addition_and_subtraction_dim_mismatch(A, Z) | ||||||
| end | ||||||
|
|
||||||
| # Zeros should be additive identity for matrices | ||||||
| D = Diagonal([1, 1]) | ||||||
| Z = Zeros(2, 2) | ||||||
| @test D + Z isa Diagonal | ||||||
| @test D + Z == D | ||||||
| @test D - Z == D | ||||||
| @test Z - D == -D | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
|
|
@@ -588,6 +596,12 @@ end | |||||
| testsparsediag(E) | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
| # Adding one elements with different indices should return a sparse array. | ||||||
| A = OneElement(2, 5) | ||||||
| B = OneElement(3, 5) | ||||||
| @test A + B isa SparseVector | ||||||
| @test A + B == [0, 1, 1, 0, 0] | ||||||
| end | ||||||
|
|
||||||
| @testset "==" begin | ||||||
|
|
@@ -2374,6 +2388,19 @@ end | |||||
| @test_throws ArgumentError isassigned(f, true) | ||||||
| end | ||||||
|
|
||||||
| @testset "Addition/Subtraction" begin | ||||||
| A = OneElement(2, 5) | ||||||
| B = OneElement(3, 5) | ||||||
|
|
||||||
| @test A + A isa OneElement | ||||||
| @test A + A == OneElement(2, 2, 5) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| @test A + B == [0, 1, 1, 0, 0] | ||||||
|
|
||||||
| @test B - B isa OneElement | ||||||
| @test B - B == OneElement(0, 2, 5) | ||||||
| @test B - A == [0, -1, 1, 0, 0] | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add test with adding |
||||||
| end | ||||||
|
|
||||||
| @testset "matmul" begin | ||||||
| A = reshape(Float64[1:9;], 3, 3) | ||||||
| v = reshape(Float64[1:3;], 3) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure whether having the behaviour depend on loading an extension is a good idea...