Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 7c06168

Browse files
Merge pull request #41 from JuliaDiffEq/refactor
Refactor and update README to include sparsity detection
2 parents 1317f16 + ffdc07d commit 7c06168

File tree

7 files changed

+61
-268
lines changed

7 files changed

+61
-268
lines changed

README.md

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,35 @@ end
2929
```
3030

3131
For this function, we know that the sparsity pattern of the Jacobian is a
32-
`Tridiagonal` matrix. We represent our sparsity by that matrix:
32+
`Tridiagonal` matrix. However, if we didn't know the sparsity pattern for
33+
the Jacobian, we could use the `sparsity!` function to automatically
34+
detect the sparsity pattern. We declare that it outputs a length 30 vector
35+
and takes in a length 30 vector, and it spits out a `Sparsity` object
36+
which we can turn into a `SparseMatrixCSC`:
3337

3438
```julia
35-
sparsity_pattern = Tridiagonal(ones(29),ones(30),ones(29))
39+
sparsity_pattern = sparsity!(f,output,input)
40+
jac = Float64.(sparse(sparsity_pattern))
3641
```
3742

3843
Now we call `matrix_colors` to get the color vector for that matrix:
3944

4045
```julia
41-
colors = matrix_colors(sparsity_pattern)
46+
colors = matrix_colors(jac)
4247
```
4348

4449
Since `maximum(colors)` is 3, this means that finite differencing can now
4550
compute the Jacobian in just 4 `f`-evaluations:
4651

4752
```julia
48-
J = DiffEqDiffTools.finite_difference_jacobian(f, rand(30), color=colors)
53+
DiffEqDiffTools.finite_difference_jacobian!(jac, f, rand(30), color=colors)
4954
@show fcalls # 4
5055
```
5156

5257
In addition, a faster forward-mode autodiff call can be utilized as well:
5358

5459
```julia
55-
forwarddiff_color_jacobian!(sparsity_pattern, f, x, color = colors)
60+
forwarddiff_color_jacobian!(jac, f, x, color = colors)
5661
```
5762

5863
If one only need to compute products, one can use the operators. For example,
@@ -83,6 +88,36 @@ gmres!(res,J,v)
8388

8489
## Documentation
8590

91+
### Automated Sparsity Detection
92+
93+
Automated sparsity detection is provided by the `sparsity!` function whose
94+
syntax is:
95+
96+
```julia
97+
`sparsity!(f, Y, X, args...; sparsity=Sparsity(length(X), length(Y)), verbose=true)`
98+
```
99+
100+
The arguments are:
101+
102+
- `f`: the function
103+
- `Y`: the output array
104+
- `X`: the input array
105+
- `args`: trailing arguments to `f`. They are considered subject to change, unless wrapped as `Fixed(arg)`
106+
- `S`: (optional) the sparsity pattern
107+
- `verbose`: (optional) whether to describe the paths taken by the sparsity detection.
108+
109+
The function `f` is assumed to take arguments of the form `f(dx,x,args...)`.
110+
`sparsity!` returns a `Sparsity` object which describes where the non-zeros
111+
of the Jacobian occur. `sparse(::Sparsity)` transforms the pattern into
112+
a sparse matrix.
113+
114+
This function utilizes non-standard interpretation, which we denote
115+
combinatoric concolic analysis, to directly realize the sparsity pattern from the program's AST. It requires that the function `f` is a Julia function. It does not
116+
work numerically, meaning that it is not prone to floating point error or
117+
cancelation. It allows for branching and will automatically check all of the
118+
branches. However, a while loop of indeterminate length which is dependent
119+
on the input argument is not allowed.
120+
86121
### Matrix Coloring
87122

88123
Matrix coloring allows you to reduce the number of times finite differencing

src/SparseDiffTools.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
module SparseDiffTools
22

33
using SparseArrays, LinearAlgebra, BandedMatrices, BlockBandedMatrices,
4-
LightGraphs, VertexSafeGraphs, DiffEqDiffTools, ForwardDiff, Zygote
4+
LightGraphs, VertexSafeGraphs, DiffEqDiffTools, ForwardDiff, Zygote,
5+
SparseArrays
56
using BlockBandedMatrices:blocksize,nblocks
67
using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
78

9+
using Cassette
10+
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
11+
import Cassette: tagged_new_tuple, ContextTagged, BindingMeta, DisableHooks, nametype
12+
import Core: SSAValue
13+
814
export contract_color,
915
greedy_d1,
1016
matrix2graph,
@@ -20,7 +26,8 @@ export contract_color,
2026
auto_hesvecgrad,auto_hesvecgrad!,
2127
numback_hesvec,numback_hesvec!,
2228
autoback_hesvec,autoback_hesvec!,
23-
JacVec,HesVec,HesVecGrad
29+
JacVec,HesVec,HesVecGrad,
30+
Sparsity, sparsity!
2431

2532

2633
include("coloring/high_level.jl")
@@ -30,5 +37,8 @@ include("coloring/matrix2graph.jl")
3037
include("differentiation/compute_jacobian_ad.jl")
3138
include("differentiation/jaches_products.jl")
3239
include("program_sparsity/program_sparsity.jl")
40+
include("program_sparsity/sparsity_tracker.jl")
41+
include("program_sparsity/path.jl")
42+
include("program_sparsity/take_all_branches.jl")
3343

3444
end # module

src/program_sparsity/program_sparsity.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
include("sparsity_tracker.jl")
2-
include("path.jl")
3-
include("take_all_branches.jl")
4-
51
struct Fixed
62
value
73
end
84

95
"""
10-
`sparsity!(f, Y, X, args...; sparsity=Sparsity(length(X), length(Y)))`
6+
`sparsity!(f, Y, X, args...; sparsity=Sparsity(length(X), length(Y)), verbose=true)`
117
128
Execute the program that figures out the sparsity pattern of
139
the jacobian of the function `f`.
@@ -18,10 +14,12 @@ the jacobian of the function `f`.
1814
- `X`: the input array
1915
- `args`: trailing arguments to `f`. They are considered subject to change, unless wrapped as `Fixed(arg)`
2016
- `S`: (optional) the sparsity pattern
17+
- `verbose`: (optional) whether to describe the paths taken by the sparsity detection.
2118
2219
Returns a `Sparsity`
2320
"""
24-
function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)))
21+
function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)),
22+
verbose = true)
2523
path = Path()
2624
ctx = SparsityContext(metadata=(sparsity, path), pass=BranchesPass)
2725
ctx = Cassette.enabletagging(ctx, f!)
@@ -36,7 +34,7 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)))
3634
map(arg -> arg isa Fixed ?
3735
arg.value : tag(arg, ctx, ProvinanceSet(())), args)...)
3836

39-
println("Explored path: ", path)
37+
verbose && println("Explored path: ", path)
4038
alldone(path) && break
4139
reset!(path)
4240
end

src/program_sparsity/sparsity_tracker.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
using Cassette
2-
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
3-
import Core: SSAValue
4-
using SparseArrays
5-
6-
export Sparsity, sparsity!
7-
81
"""
92
The sparsity pattern.
103

src/program_sparsity/tuple.jl

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)