Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 1317f16

Browse files
Merge pull request #13 from shashi/s/try-all-paths
sparsity detection: a pass to try all tainted branches
2 parents 26a3fc0 + 5188a21 commit 1317f16

File tree

14 files changed

+542
-1
lines changed

14 files changed

+542
-1
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ os:
44
- linux
55
- osx
66
julia:
7-
- 1.1
7+
- 1.2
88
- nightly
99
matrix:
1010
allow_failures:

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.1.0"
66
[deps]
77
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
88
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
9+
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
910
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
1011
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1112
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"

src/SparseDiffTools.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ export contract_color,
2222
autoback_hesvec,autoback_hesvec!,
2323
JacVec,HesVec,HesVecGrad
2424

25+
2526
include("coloring/high_level.jl")
2627
include("coloring/contraction_coloring.jl")
2728
include("coloring/greedy_d1_coloring.jl")
2829
include("coloring/matrix2graph.jl")
2930
include("differentiation/compute_jacobian_ad.jl")
3031
include("differentiation/jaches_products.jl")
32+
include("program_sparsity/program_sparsity.jl")
3133

3234
end # module

src/program_sparsity/path.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#### Path
2+
3+
# First just do it for the case where there we assume
4+
# tainted gotoifnots do not go in a loop!
5+
# TODO: write a thing to detect this! (overdub predicates only in tainted ifs)
6+
# implement snapshotting function state as an optimization for branch exploration
7+
mutable struct Path
8+
path::BitVector
9+
cursor::Int
10+
end
11+
12+
Path() = Path([], 1)
13+
14+
function increment!(bitvec)
15+
for i=1:length(bitvec)
16+
if bitvec[i] === true
17+
bitvec[i] = false
18+
else
19+
bitvec[i] = true
20+
break
21+
end
22+
end
23+
end
24+
25+
function reset!(p::Path)
26+
p.cursor=1
27+
increment!(p.path)
28+
nothing
29+
end
30+
31+
function alldone(p::Path) # must be called at the end of the function!
32+
all(identity, p.path)
33+
end
34+
35+
function this_here_predicate!(p::Path)
36+
if p.cursor > length(p.path)
37+
push!(p.path, false)
38+
else
39+
p.path[p.cursor]
40+
end
41+
val = p.path[p.cursor]
42+
p.cursor+=1
43+
val
44+
end
45+
46+
alldone(c::SparsityContext) = alldone(c.metadata[2])
47+
reset!(c::SparsityContext) = reset!(c.metadata[2])
48+
this_here_predicate!(c::SparsityContext) = this_here_predicate!(c.metadata[2])
49+
50+
#=
51+
julia> p=Path()
52+
Path(Bool[], 1)
53+
54+
julia> alldone(p) # must be called at the end of a full run
55+
true
56+
57+
julia> this_here_predicate!(p)
58+
false
59+
60+
julia> alldone(p) # must be called at the end of a full run
61+
false
62+
63+
julia> this_here_predicate!(p)
64+
false
65+
66+
julia> p
67+
Path(Bool[false, false], 3)
68+
69+
julia> alldone(p)
70+
false
71+
72+
julia> reset!(p)
73+
74+
julia> p
75+
Path(Bool[true, false], 1)
76+
77+
julia> this_here_predicate!(p)
78+
true
79+
80+
julia> this_here_predicate!(p)
81+
false
82+
83+
julia> alldone(p)
84+
false
85+
86+
julia> reset!(p)
87+
88+
julia> p
89+
Path(Bool[false, true], 1)
90+
91+
julia> this_here_predicate!(p)
92+
false
93+
94+
julia> this_here_predicate!(p)
95+
true
96+
97+
julia> reset!(p)
98+
99+
julia> this_here_predicate!(p)
100+
true
101+
102+
julia> this_here_predicate!(p)
103+
true
104+
105+
julia> alldone(p)
106+
true
107+
=#
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
include("sparsity_tracker.jl")
2+
include("path.jl")
3+
include("take_all_branches.jl")
4+
5+
struct Fixed
6+
value
7+
end
8+
9+
"""
10+
`sparsity!(f, Y, X, args...; sparsity=Sparsity(length(X), length(Y)))`
11+
12+
Execute the program that figures out the sparsity pattern of
13+
the jacobian of the function `f`.
14+
15+
# Arguments:
16+
- `f`: the function
17+
- `Y`: the output array
18+
- `X`: the input array
19+
- `args`: trailing arguments to `f`. They are considered subject to change, unless wrapped as `Fixed(arg)`
20+
- `S`: (optional) the sparsity pattern
21+
22+
Returns a `Sparsity`
23+
"""
24+
function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)))
25+
path = Path()
26+
ctx = SparsityContext(metadata=(sparsity, path), pass=BranchesPass)
27+
ctx = Cassette.enabletagging(ctx, f!)
28+
ctx = Cassette.disablehooks(ctx)
29+
30+
while true
31+
Cassette.recurse(ctx,
32+
f!,
33+
tag(Y, ctx, Output()),
34+
tag(X, ctx, Input()),
35+
# TODO: make this recursive
36+
map(arg -> arg isa Fixed ?
37+
arg.value : tag(arg, ctx, ProvinanceSet(())), args)...)
38+
39+
println("Explored path: ", path)
40+
alldone(path) && break
41+
reset!(path)
42+
end
43+
sparsity
44+
end
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
using Cassette
2+
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
3+
import Core: SSAValue
4+
using SparseArrays
5+
6+
export Sparsity, sparsity!
7+
8+
"""
9+
The sparsity pattern.
10+
11+
- `I`: Input index
12+
- `J`: Ouput index
13+
14+
`(i, j)` means the `j`th element of the output depends on
15+
the `i`th element of the input. Therefore `length(I) == length(J)`
16+
"""
17+
struct Sparsity
18+
m::Int
19+
n::Int
20+
I::Vector{Int} # Input
21+
J::Vector{Int} # Output
22+
end
23+
24+
SparseArrays.sparse(s::Sparsity) = sparse(s.I, s.J, true, s.m, s.n)
25+
26+
Sparsity(m, n) = Sparsity(m, n, Int[], Int[])
27+
28+
function Base.push!(S::Sparsity, i::Int, j::Int)
29+
push!(S.I, i)
30+
push!(S.J, j)
31+
end
32+
33+
# Tags:
34+
struct Input end
35+
struct Output end
36+
37+
struct ProvinanceSet{T}
38+
set::T # Set, Array, Int, Tuple, anything!
39+
end
40+
41+
# note: this is not strictly set union, just some efficient way of concating
42+
Base.union(p::ProvinanceSet{<:Tuple},
43+
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set..., q.set,))
44+
Base.union(p::ProvinanceSet{<:Integer},
45+
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set, q.set...,))
46+
Base.union(p::ProvinanceSet{<:Integer},
47+
q::ProvinanceSet{<:Integer}) = ProvinanceSet((p.set, q.set,))
48+
Base.union(p::ProvinanceSet{<:Tuple},
49+
q::ProvinanceSet{<:Tuple}) = ProvinanceSet((p.set..., q.set...,))
50+
Base.union(p::ProvinanceSet,
51+
q::ProvinanceSet) = ProvinanceSet(union(p.set, q.set))
52+
Base.union(p::ProvinanceSet,
53+
q::ProvinanceSet,
54+
rs::ProvinanceSet...) = union(union(p, q), rs...)
55+
Base.union(p::ProvinanceSet) = p
56+
57+
function Base.push!(S::Sparsity, i::Int, js::ProvinanceSet)
58+
for j in js.set
59+
push!(S, i, j)
60+
end
61+
end
62+
63+
Cassette.@context SparsityContext
64+
65+
const TagType = Union{Input, Output, ProvinanceSet}
66+
Cassette.metadatatype(::Type{<:SparsityContext}, ::DataType) = TagType
67+
68+
metatype(x, ctx) = hasmetadata(x, ctx) && istagged(x, ctx) && typeof(metadata(x, ctx))
69+
function ismetatype(x, ctx, T)
70+
hasmetadata(x, ctx) && istagged(x, ctx) && (metadata(x, ctx) isa T)
71+
end
72+
73+
# Dummy type when you getindex
74+
struct Tainted end
75+
76+
# getindex on the input
77+
function Cassette.overdub(ctx::SparsityContext,
78+
f::typeof(getindex),
79+
X::Tagged,
80+
idx::Int...)
81+
if ismetatype(X, ctx, Input)
82+
i = LinearIndices(untag(X, ctx))[idx...]
83+
val = Cassette.fallback(ctx, f, X, idx...)
84+
tag(val, ctx, ProvinanceSet(i))
85+
else
86+
Cassette.recurse(ctx, f, X, idx...)
87+
end
88+
end
89+
90+
# setindex! on the output
91+
function Cassette.overdub(ctx::SparsityContext,
92+
f::typeof(setindex!),
93+
Y::Tagged,
94+
val::Tagged,
95+
idx::Int...)
96+
S, path = ctx.metadata
97+
if ismetatype(Y, ctx, Output)
98+
set = metadata(val, ctx)
99+
if set isa ProvinanceSet
100+
i = LinearIndices(untag(Y, ctx))[idx...]
101+
push!(S, i, set)
102+
end
103+
Cassette.fallback(ctx, f, Y, val, idx...)
104+
else
105+
Cassette.recurse(ctx, f, Y, val, idx...)
106+
end
107+
end
108+
109+
function get_provinance(ctx, arg::Tagged)
110+
if metadata(arg, ctx) isa ProvinanceSet
111+
metadata(arg, ctx)
112+
else
113+
ProvinanceSet(())
114+
end
115+
end
116+
117+
get_provinance(ctx, arg) = ProvinanceSet(())
118+
119+
# Any function acting on a value tagged with ProvinanceSet
120+
function _overdub_union_provinance(::Val{eval}, ctx::SparsityContext, f, args...) where {eval}
121+
idxs = findall(x->ismetatype(x, ctx, ProvinanceSet), args)
122+
if isempty(idxs)
123+
Cassette.fallback(ctx, f, args...)
124+
else
125+
provinance = union(map(arg->get_provinance(ctx, arg), args[idxs])...)
126+
if eval
127+
val = Cassette.fallback(ctx, f, args...)
128+
tag(val, ctx, provinance)
129+
else
130+
tag(Tainted(), ctx, provinance)
131+
end
132+
end
133+
end
134+
135+
function Cassette.overdub(ctx::SparsityContext, f, args...)
136+
haspsets = any(x->ismetatype(x, ctx, ProvinanceSet), args)
137+
hasinput = any(x->ismetatype(x, ctx, Input), args)
138+
if haspsets && !hasinput # && !canrecurse(ctx, f, args...)
139+
_overdub_union_provinance(Val{true}(), ctx, f, args...)
140+
else
141+
Cassette.recurse(ctx, f, args...)
142+
end
143+
end
144+
145+
#=
146+
# Examples:
147+
#
148+
using UnicodePlots
149+
150+
sspy(s::Sparsity) = spy(sparse(s))
151+
152+
julia> sparsity!([0,0,0], [23,53,83]) do Y, X
153+
Y[:] .= X
154+
Y == X
155+
end
156+
(true, Sparsity([1, 2, 3], [1, 2, 3]))
157+
158+
julia> sparsity!([0,0,0], [23,53,83]) do Y, X
159+
for i=1:3
160+
for j=i:3
161+
Y[j] += X[i]
162+
end
163+
end; Y
164+
end
165+
([23, 76, 159], Sparsity(3, 3, [1, 2, 3, 2, 3, 3], [1, 1, 1, 2, 2, 3]))
166+
167+
julia> sspy(ans[2])
168+
Sparsity Pattern
169+
┌─────┐
170+
1 │⠀⠄⠀⠀⠀│ > 0
171+
3 │⠀⠅⠨⠠⠀│ < 0
172+
└─────┘
173+
1 3
174+
nz = 6
175+
176+
julia> sparsity!(f, zeros(Int, 3,3), [23,53,83])
177+
([23, 53, 83], Sparsity(9, 3, [2, 5, 8], [1, 2, 3]))
178+
179+
julia> sspy(ans[2])
180+
Sparsity Pattern
181+
┌─────┐
182+
1 │⠀⠄⠀⠀⠀│ > 0
183+
│⠀⠀⠠⠀⠀│ < 0
184+
9 │⠀⠀⠀⠐⠀│
185+
└─────┘
186+
1 3
187+
nz = 3
188+
=#

0 commit comments

Comments
 (0)