Skip to content

Commit 4b6e441

Browse files
Merge pull request #119 from JuliaDiffEq/hg/feature/varfuns
Refactor variables as functions
2 parents 28f75f1 + 98fcad0 commit 4b6e441

14 files changed

+311
-224
lines changed

README.md

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@ eqs = [D(x) ~ σ*(y-x),
4343

4444
Each operation builds an `Operation` type, and thus `eqs` is an array of
4545
`Operation` and `Variable`s. This holds a tree of the full system that can be
46-
analyzed by other programs. We can turn this into a `DiffEqSystem` via:
46+
analyzed by other programs. We can turn this into a `ODESystem` via:
4747

4848
```julia
49-
de = DiffEqSystem(eqs,t,[x,y,z],[σ,ρ,β])
50-
de = DiffEqSystem(eqs)
49+
de = ODESystem(eqs)
5150
```
5251

5352
where we tell it the variable types and ordering in the first version, or let it
@@ -56,7 +55,7 @@ This can then generate the function. For example, we can see the
5655
generated code via:
5756

5857
```julia
59-
generate_function(de)
58+
generate_function(de, [x,y,z], [σ,ρ,β])
6059

6160
## Which returns:
6261
:((##363, u, p, t)->begin
@@ -71,7 +70,7 @@ generate_function(de)
7170
and get the generated function via:
7271
7372
```julia
74-
f = ODEFunction(de)
73+
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
7574
```
7675
7776
### Example: Nonlinear System
@@ -88,8 +87,8 @@ derivatives are zero. We use (unknown) variables for our nonlinear system.
8887
eqs = [0 ~ σ*(y-x),
8988
0 ~ x*-z)-y,
9089
0 ~ x*y - β*z]
91-
ns = NonlinearSystem(eqs)
92-
nlsys_func = generate_function(ns)
90+
ns = NonlinearSystem(eqs, [x,y,z])
91+
nlsys_func = generate_function(ns, [x,y,z], [ρ,σ,β])
9392
```
9493
9594
which generates:
@@ -130,17 +129,49 @@ In this section we define the core pieces of the IR and what they mean.
130129
131130
### Variables
132131
133-
The most fundamental part of the IR is the `Variable`. The `Variable` is the
132+
The most fundamental part of the IR is the `Variable`. In order to mirror the
133+
intention of solving for variables and representing function-like parameters,
134+
we treat each instance of `Variable` as a function which is called on its
135+
arguments using the natural syntax. Rather than having additional mechanisms
136+
for handling constant variables and parameters, we simply represent them as
137+
constant functions.
138+
139+
The `Variable` is the
134140
context-aware single variable of the IR. Its fields are described as follows:
135141
136142
- `name`: the name of the `Variable`. Note that this is not necessarily
137143
the same as the name of the Julia variable. But this symbol itself is considered
138144
the core identifier of the `Variable` in the sense of equality.
139145
- `known`: the main denotation of context, storing whether or not the value of
140146
the variable is known.
141-
- `dependents`: the vector of variables on which the current variable
142-
is dependent. For example, `u(t,x)` has dependents `[t,x]`. Derivatives thus
143-
require this information in order to simplify down.
147+
148+
For example, the following code defines an independent variable `t`, a parameter
149+
`α`, a function parameter `σ`, a variable `x` which depends on `t`, a variable
150+
`y` with no dependents, and a variable `z` which depends on `t`, `α`, and `x(t)`.
151+
152+
```julia
153+
t = Variable(:t; known = true)() # independent variables are treated as known
154+
α = Variable(; known = true)() # parameters are known
155+
σ = Variable(; known = true) # left uncalled, since it is used as a function
156+
w = Variable(:w; known = false) # unknown, left uncalled
157+
x = Variable(:x; known = false)(t) # unknown, depends on `t`
158+
y = Variable(:y; known = false)() # unknown, no dependents
159+
z = Variable(:z; known = false)(t, α, x) # unknown, multiple arguments
160+
161+
expr = x + y^α + σ(3) * (z - t) - w(t - 1)
162+
```
163+
164+
We can rewrite this more concisely using macros. Note the difference between
165+
including and excluding empty parentheses. When in call format, variables are
166+
aliased to the given call, allowing implicit use of dependents for convenience.
167+
168+
```julia
169+
@parameters t() α() σ
170+
@variables w x(t) y() z(t, α, x)
171+
172+
expr = x + y^α + σ(3) * (z - t) - w(t - 1)
173+
```
174+
144175
145176
### Constants
146177
@@ -243,37 +274,37 @@ is accessible via a function-based interface. This means that all macros are
243274
syntactic sugar in some form. For example, the variable construction:
244275
245276
```julia
246-
@parameters t σ ρ β
277+
@parameters t() σ ρ() β()
247278
@variables x(t) y(t) z(t)
248279
@derivatives D'~t
249280
```
250281
251282
is syntactic sugar for:
252283
253284
```julia
254-
t = Variable(:t; known = true)
255-
x = Variable(:x, [t])
256-
y = Variable(:y, [t])
257-
z = Variable(:z, [t])
258-
D = Differential(t)
285+
t = Variable(:t; known = true)()
259286
σ = Variable(; known = true)
260-
ρ = Variable(; known = true)
261-
β = Variable(; known = true)
287+
ρ = Variable(; known = true)()
288+
β = Variable(; known = true)()
289+
x = Variable(:x)(t)
290+
y = Variable(:y)(t)
291+
z = Variable(:z)(t)
292+
D = Differential(t)
262293
```
263294
264295
### Intermediate Calculations
265296
266297
The system building functions can handle intermediate calculations. For example,
267298
268299
```julia
269-
@variables x y z
270-
@parameters σ ρ β
300+
@variables x() y() z()
301+
@parameters σ() ρ() β()
271302
a = y - x
272303
eqs = [0 ~ σ*a,
273304
0 ~ x*-z)-y,
274305
0 ~ x*y - β*z]
275-
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
276-
nlsys_func = generate_function(ns)
306+
ns = NonlinearSystem(eqs, [x,y,z])
307+
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
277308
```
278309
279310
expands to:

src/ModelingToolkit.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ModelingToolkit
22

33
export Operation, Expression
44
export calculate_jacobian, generate_jacobian, generate_function
5+
export independent_variables, dependent_variables, parameters
56
export @register
67

78

@@ -22,6 +23,10 @@ function calculate_jacobian end
2223
function generate_jacobian end
2324
function generate_function end
2425

26+
function independent_variables end
27+
function dependent_variables end
28+
function parameters end
29+
2530
@enum FunctionVersion ArrayFunction=1 SArrayFunction=2
2631

2732
include("variables.jl")

src/differentials.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,27 @@ export Differential, expand_derivatives, @derivatives
44
struct Differential <: Function
55
x::Expression
66
end
7+
(D::Differential)(x) = Operation(D, Expression[x])
78

89
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
910
Base.convert(::Type{Expr}, D::Differential) = D
1011

11-
(D::Differential)(x::Operation) = Operation(D, Expression[x])
12-
function (D::Differential)(x::Variable)
13-
D.x === x && return Constant(1)
14-
has_dependent(x, D.x) || return Constant(0)
15-
return Operation(D, Expression[x])
16-
end
17-
(::Differential)(::Any) = Constant(0)
1812
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
1913

2014
function expand_derivatives(O::Operation)
2115
@. O.args = expand_derivatives(O.args)
2216

23-
if O.op isa Differential
24-
D = O.op
25-
o = O.args[1]
26-
isa(o, Operation) || return O
27-
return simplify_constants(sum(i->derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args)))
17+
if isa(O.op, Differential)
18+
(D, o) = (O.op, O.args[1])
19+
20+
isequal(o, D.x) && return Constant(1)
21+
occursin(D.x, o) || return Constant(0)
22+
isa(o, Operation) || return O
23+
isa(o.op, Variable) && return O
24+
25+
return sum(1:length(o.args)) do i
26+
derivative(o, i) * expand_derivatives(D(o.args[i]))
27+
end |> simplify_constants
2828
end
2929

3030
return O
@@ -80,6 +80,6 @@ macro derivatives(x...)
8080
esc(_differential_macro(x))
8181
end
8282

83-
function calculate_jacobian(eqs,vars)
84-
Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)]
83+
function calculate_jacobian(eqs, dvs)
84+
Expression[Differential(dv)(eq) for eq eqs, dv dvs]
8585
end

src/equations.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,3 @@ Base.:(==)(a::Equation, b::Equation) = isequal((a.lhs, a.rhs), (b.lhs, b.rhs))
1010
Base.:~(lhs::Expression, rhs::Expression) = Equation(lhs, rhs)
1111
Base.:~(lhs::Expression, rhs::Number ) = Equation(lhs, rhs)
1212
Base.:~(lhs::Number , rhs::Expression) = Equation(lhs, rhs)
13-
14-
15-
_is_dependent(x::Variable) = !x.known && !isempty(x.dependents)
16-
_is_parameter(iv) = x -> x.known && !isequal(x, iv)
17-
_is_known(x::Variable) = x.known
18-
_is_unknown(x::Variable) = !x.known
19-
20-
function extract_elements(eqs, predicates)
21-
result = [Variable[] for p predicates]
22-
vars = foldl(vars!, eqs; init=Set{Variable}())
23-
24-
for var vars
25-
for (i, p) enumerate(predicates)
26-
p(var) && (push!(result[i], var); break)
27-
end
28-
end
29-
30-
return result
31-
end
32-
33-
get_args(O::Operation) = O.args
34-
get_args(eq::Equation) = Expression[eq.lhs, eq.rhs]
35-
function vars!(vars, op)
36-
for arg get_args(op)
37-
if isa(arg, Operation)
38-
vars!(vars, arg)
39-
elseif isa(arg, Variable)
40-
push!(vars, arg)
41-
for dep arg.dependents
42-
push!(vars, dep)
43-
end
44-
end
45-
end
46-
47-
return vars
48-
end

src/operations.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ struct Operation <: Expression
33
args::Vector{Expression}
44
end
55

6-
# Recursive ==
7-
function Base.isequal(x::Operation,y::Operation)
6+
Base.isequal(x::Operation,y::Operation) =
87
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
9-
end
108
Base.isequal(::Operation, ::Number ) = false
119
Base.isequal(::Number , ::Operation) = false
1210
Base.isequal(::Operation, ::Variable ) = false

0 commit comments

Comments
 (0)