Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ version = "0.4.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -15,8 +19,8 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
SSGraphBLAS_jll = "5.1.2"
julia = "1.6"
CEnum = "0.4.1"
ContextVariablesX = "0.1.1"
MacroTools = "0.5.6"
SSGraphBLAS_jll = "5.1.2"
julia = "1.6"
9 changes: 7 additions & 2 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,14 @@ include("operations/kronecker.jl")
include("print.jl")
include("import.jl")
include("export.jl")

#EXPERIMENTAL
include("options.jl")
#EXPERIMENTAL
include("chainrules/chainruleutils.jl")
include("chainrules/mulrules.jl")
include("chainrules/ewiserules.jl")
include("chainrules/maprules.jl")
include("chainrules/reducerules.jl")
include("chainrules/selectrules.jl")
#include("random.jl")
include("misc.jl")
export libgb
Expand Down
46 changes: 46 additions & 0 deletions src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import FiniteDifferences
import LinearAlgebra
import ChainRulesCore: frule, rrule
using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

#Required for ChainRulesTestUtils
function FiniteDifferences.to_vec(M::GBMatrix)
I, J, X = findnz(M)
function backtomat(xvec)
return GBMatrix(I, J, xvec; nrows = size(M, 1), ncols = size(M, 2))
end
return X, backtomat
end

function FiniteDifferences.to_vec(v::GBVector)
i, x = findnz(v)
function backtovec(xvec)
return GBVector(i, xvec; nrows=size(v, 1))
end
return x, backtovec
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBMatrix{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, J, _ = findnz(x)
return GBMatrix(I, J, v; nrows = size(x, 1), ncols = size(x, 2))
end

function FiniteDifferences.rand_tangent(
rng::AbstractRNG,
x::GBVector{T}
) where {T <: Union{AbstractFloat, Complex}}
n = nnz(x)
v = rand(rng, -9:0.01:9, n)
I, _ = findnz(x)
return GBVector(I, v; nrows = size(x, 1))
end

FiniteDifferences.rand_tangent(rng::AbstractRNG, x::AbstractOp) = NoTangent()
# LinearAlgebra.norm freaks over the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
64 changes: 64 additions & 0 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#emul TIMES
function frule(
(_, ΔA, ΔB, _),
::typeof(emul),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.TIMES)
)
Ω = emul(A, B, BinaryOps.TIMES)
∂Ω = emul(ΔA, B, BinaryOps.TIMES) + emul(ΔB, A, BinaryOps.TIMES)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(emul), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, BinaryOps.TIMES)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(BinaryOps.TIMES))
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
function timespullback(ΔΩ)
∂A = emul(ΔΩ, B)
∂B = emul(ΔΩ, A)
return NoTangent(), ∂A, ∂B
end
return emul(A, B, BinaryOps.TIMES), timespullback
end

#eadd PLUS
function frule(
(_, ΔA, ΔB, _),
::typeof(eadd),
A::GBArray,
B::GBArray,
::typeof(BinaryOps.PLUS)
)
Ω = eadd(A, B, BinaryOps.PLUS)
∂Ω = eadd(ΔA, ΔB, BinaryOps.PLUS)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB), ::typeof(eadd), A::GBArray, B::GBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, BinaryOps.PLUS)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(BinaryOps.PLUS))
function pluspullback(ΔΩ)
return NoTangent(), ΔΩ, ΔΩ, NoTangent()
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end

# Do I have to duplicate this? I get 4 tangents instead of 3 if I call the previous rule.
function rrule(::typeof(eadd), A::GBArray, B::GBArray)
function pluspullback(ΔΩ)
return NoTangent(), ΔΩ, ΔΩ
end
return eadd(A, B, BinaryOps.PLUS), pluspullback
end
17 changes: 17 additions & 0 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Per Lyndon. Needs adaptation, and/or needs redefinition of map to use functions rather
# than AbstractOp.
#function rrule(map, f, xs)
# # Rather than 3 maps really want 1 multimap
# ys_and_pullbacks = map(x->rrule(f, x), xs) #Take this to ys = map(f, x)
# ys = map(first, ys_and_pullbacks)
# pullbacks = map(last, ys_and_pullbacks)
# function map_pullback(dys)
# _call(f, x) = f(x)
# dfs_and_dxs = map(_call, pullbacks, dys)
# # but in your case you know it will be NoTangent() so can skip
# df = sum(first, dfs_and_dxs)
# dxs = map(last, dfs_and_dxs)
# return NoTangent(), df, dxs
# end
# return ys, map_pullback
#end
54 changes: 54 additions & 0 deletions src/chainrules/mulrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Standard arithmetic mul:
function frule(
(_, ΔA, ΔB),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
frule((nothing, ΔA, ΔB, nothing), mul, A, B, Semirings.PLUS_TIMES)
end
function frule(
(_, ΔA, ΔB, _),
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
Ω = mul(A, B)
∂Ω = mul(ΔA, B) + mul(A, ΔB)
return Ω, ∂Ω
end
# Tests will not pass for this. For two reasons.
# First is #25, the output inference is not type stable.
# That's it's own issue.

# Second, to_vec currently works by mapping materialized values back and forth, ie. it knows nothing about nothings.
# This means they give different answers. FiniteDifferences is probably "incorrect", but I have no proof.

function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose,
::typeof(Semirings.PLUS_TIMES)
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B')
∂B = mul(A', ΔΩ)
return NoTangent(), ∂A, ∂B, NoTangent()
end
return mul(A, B), mulpullback
end

# Do I have to duplicate this? :/
function rrule(
::typeof(mul),
A::GBMatOrTranspose,
B::GBMatOrTranspose
)
function mulpullback(ΔΩ)
∂A = mul(ΔΩ, B')
∂B = mul(A', ΔΩ)
return NoTangent(), ∂A, ∂B
end
return mul(A, B), mulpullback
end
Empty file added src/chainrules/reducerules.jl
Empty file.
Empty file added src/chainrules/selectrules.jl
Empty file.
22 changes: 11 additions & 11 deletions src/lib/LibGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,27 @@ macro wraperror(code)
elseif info == GrB_NO_VALUE
return nothing
else
if info == GrB_UNINITIALIZED_OBJECT
if info == GrB_UNINITIALIZED_OBJECT
throw(UninitializedObjectError)
elseif info == GrB_INVALID_OBJECT
elseif info == GrB_INVALID_OBJECT
throw(InvalidObjectError)
elseif info == GrB_NULL_POINTER
elseif info == GrB_NULL_POINTER
throw(NullPointerError)
elseif info == GrB_INVALID_VALUE
elseif info == GrB_INVALID_VALUE
throw(InvalidValueError)
elseif info == GrB_INVALID_INDEX
elseif info == GrB_INVALID_INDEX
throw(InvalidIndexError)
elseif info == GrB_DOMAIN_MISMATCH
elseif info == GrB_DOMAIN_MISMATCH
throw(DomainError(nothing, "GraphBLAS Domain Mismatch"))
elseif info == GrB_DIMENSION_MISMATCH
throw(DimensionMismatch())
elseif info == GrB_OUTPUT_NOT_EMPTY
elseif info == GrB_OUTPUT_NOT_EMPTY
throw(OutputNotEmptyError)
elseif info == GrB_OUT_OF_MEMORY
elseif info == GrB_OUT_OF_MEMORY
throw(OutOfMemoryError())
elseif info == GrB_INSUFFICIENT_SPACE
elseif info == GrB_INSUFFICIENT_SPACE
throw(InsufficientSpaceError)
elseif info == GrB_INDEX_OUT_OF_BOUNDS
elseif info == GrB_INDEX_OUT_OF_BOUNDS
throw(BoundsError())
elseif info == GrB_PANIC
throw(PANIC)
Expand Down Expand Up @@ -843,7 +843,7 @@ for T ∈ valid_vec
nvals = GrB_Vector_nvals(v)
I = Vector{GrB_Index}(undef, nvals)
X = Vector{$type}(undef, nvals)
nvals = Ref{GrB_Index}()
nvals = Ref{GrB_Index}(nvals)
$func(I, X, nvals, v)
nvals[] == length(I) == length(X) || throw(DimensionMismatch())
return I .+ 1, X
Expand Down
3 changes: 2 additions & 1 deletion src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ function Base.show(io::IO, ::MIME"text/plain", A::GBMatrix)
gxbprint(io, A)
end

SparseArrays.nonzeros(A::GBArray) = findnz(A)[3]
SparseArrays.nonzeros(A::GBArray) = findnz(A)[end]


# Indexing functions
####################
Expand Down
8 changes: 7 additions & 1 deletion src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function emul!(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.TIMES)

size(w) == size(u) == size(v) || throw(DimensionMismatch())
op = getoperator(op, optype(u, v))
accum = getoperator(accum, eltype(w))
Expand Down Expand Up @@ -291,6 +290,13 @@ function eadd(
return eadd!(C, A, B, op; mask, accum, desc)
end

function Base.:+(A::GBArray, B::GBArray)
eadd(A, B, nothing)
end

function Base.:-(A::GBArray, B::GBArray)
eadd(A, B, BinaryOps.MINUS)
end
#Elementwise Broadcasts
#######################

Expand Down
1 change: 0 additions & 1 deletion src/operations/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ function LinearAlgebra.mul!(
return w
end


"""
mul(A::GBArray, B::GBArray; kwargs...)::GBArray

Expand Down
4 changes: 2 additions & 2 deletions src/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ GBVector{T}(dims::Dims{1}) where {T} = GBVector{T}(dims...)

Create a GBVector from a vector of indices `I` and a vector of values `X`.
"""
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS) where {T}
x = GBVector{T}(maximum(I))
function GBVector(I::Vector, X::Vector{T}; dup = BinaryOps.PLUS, nrows = maximum(I)) where {T}
x = GBVector{T}(nrows)
build(x, I, X, dup = dup)
return x
end
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using SuiteSparseGraphBLAS
using SparseArrays
using Test
using Random

using ChainRulesTestUtils
Random.seed!(1)

function include_test(path)
Expand All @@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl")
@testset "SuiteSparseGraphBLAS" begin
include_test("gbarray.jl")
include_test("operations.jl")
include_test("testrules.jl")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually the structure of test folder mirrors the src/ folder, which makes it easier to find things when the package grows.

end
40 changes: 40 additions & 0 deletions test/testrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@testset "Dense" begin
@testset "arithmetic semiring" begin
#dense first
M = GBMatrix(rand(-10.0:0.05:10.0, 10, 10))
Y = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
X = GBMatrix(rand(-10.0:0.05:10.0, 10))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end

@testset "Sparse" begin
@testset "arithmetic semiring" begin
M = GBMatrix(sprand(10, 10, 0.5))
Y = GBMatrix(sprand(10, 0.5)) #using matrix for now until I work out transpose(v::GBVector)
test_frule(mul, M, Y; check_inferred=false)
test_frule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
test_rrule(mul, M, Y; check_inferred=false)
test_rrule(mul, M, Y, Semirings.PLUS_TIMES; check_inferred=false)
X = GBMatrix(sprand(10, 0.5))
test_frule(eadd, X, Y; check_inferred=false)
test_frule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_rrule(eadd, X, Y; check_inferred=false)
test_rrule(eadd, X, Y, BinaryOps.PLUS; check_inferred=false)
test_frule(emul, X, Y; check_inferred=false)
test_frule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
test_rrule(emul, X, Y; check_inferred=false)
test_rrule(emul, X, Y, BinaryOps.TIMES; check_inferred=false)
end
end