37
37
38
38
# ## Operator Forms
39
39
40
- struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffVecProd
40
+ """
41
+ VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff())
42
+
43
+ Returns SciMLOperators.FunctionOperator which computes vector-jacobian
44
+ product `df/du * v`.
45
+
46
+ ```
47
+ L = VecJac(f, u)
48
+
49
+ L * v # = df/du * v
50
+ mul!(w, L, v) # = df/du * v
51
+
52
+ L(v, p, t; VJP_input = w) # = df/dw * v
53
+ L(x, v, p, t; VJP_input = w) # = df/dw * v
54
+ ```
55
+ """
56
+ function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ;
57
+ autodiff = AutoFiniteDiff (), kwargs... )
58
+
59
+ L = _vecjac (f, u, autodiff)
60
+ IIP, OOP = get_iip_oop (L)
61
+
62
+ if isa (autodiff, AutoZygote) & ! OOP
63
+ msg = " Zygote requires an out of place method with signature f(u)."
64
+ throw (ArgumentError (msg))
65
+ end
66
+
67
+ FunctionOperator (L, u, u; isinplace = IIP, outofplace = OOP,
68
+ p = p, t = t, islinear = true ,
69
+ accepted_kwargs = (:VJP_input ,), kwargs... )
70
+ end
71
+
72
+ function _vecjac (f, u, autodiff:: AutoFiniteDiff )
73
+
74
+ cache = (similar (u), similar (u))
75
+ pullback = nothing
76
+
77
+ AutoDiffVJP (f, u, cache, autodiff, pullback)
78
+ end
79
+
80
+ mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd
81
+ """ Compute VJP of `f` at `u`, applied to vector `v`: `df/du' * u` """
41
82
f:: F
83
+ """ input to `f` """
42
84
u:: U
85
+ """ Cache for num_vecjac! when autodiff isa AutoFintieDiff """
43
86
cache:: C
44
- vecprod:: V
45
- vecprod!:: V!
87
+ """ Type of automatic differentiation algorithm """
88
+ autodiff:: AD
89
+ """ stores the result of Zygote.pullback for AutoZygote """
90
+ pullback:: PB
91
+
92
+ function AutoDiffVJP (f, u, cache, autodiff, pullback)
46
93
47
- function RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!;
48
- autodiff = AutoFiniteDiff (),
49
- isinplace = false , outofplace = true )
50
- @assert isinplace || outofplace
94
+ outofplace = static_hasmethod (f, typeof ((u,)))
95
+ isinplace = static_hasmethod (f, typeof ((u, u)))
96
+
97
+ if ! (isinplace) & ! (outofplace)
98
+ msg = " $f must have signature f(u), or f(du, u)"
99
+ throw (ArgumentError (msg))
100
+ end
51
101
52
102
new{
53
103
typeof (autodiff),
@@ -56,72 +106,58 @@ struct RevModeAutoDiffVecProd{ad, iip, oop, F, U, C, V, V!} <: AbstractAutoDiffV
56
106
typeof (f),
57
107
typeof (u),
58
108
typeof (cache),
59
- typeof (vecprod),
60
- typeof (vecprod!)
61
- }(f, u, cache, vecprod, vecprod!)
109
+ typeof (pullback),
110
+ }(
111
+ f, u, cache, autodiff, pullback,
112
+ )
62
113
end
63
114
end
64
115
65
- function update_coefficients (L:: RevModeAutoDiffVecProd , u, p, t)
66
- f = update_coefficients (L. f, u, p, t)
67
- RevModeAutoDiffVecProd (f, u, L. vecprod, L. vecprod!, L. cache)
116
+ function get_iip_oop (:: AutoDiffVJP{AD, IIP, OOP} ) where {AD, IIP, OOP}
117
+ IIP, OOP
68
118
end
69
119
70
- function update_coefficients! (L:: RevModeAutoDiffVecProd , u, p, t)
71
- update_coefficients! (L. f, u, p, t)
72
- copy! (L. u, u)
73
- L
120
+ function update_coefficients (L:: AutoDiffVJP{AD} , u, p, t; VJP_input = nothing ,
121
+ ) where {AD <: AutoFiniteDiff }
122
+
123
+ if ! isnothing (VJP_input)
124
+ @set! L. u = VJP_input
125
+ end
126
+
127
+ @set! L. f = update_coefficients (L. f, L. u, p, t)
74
128
end
75
129
76
- # Interpret the call as df/du' * u
77
- function (L:: RevModeAutoDiffVecProd )(v, p, t)
78
- L. vecprod (L. f, L. u, v)
130
+ function update_coefficients! (L:: AutoDiffVJP{AD} , u, p, t; VJP_input = nothing ,
131
+ ) where {AD <: AutoFiniteDiff }
132
+
133
+ if ! isnothing (VJP_input)
134
+ copy! (L. u, VJP_input)
135
+ end
136
+
137
+ update_coefficients! (L. f, L. u, p, t)
138
+
139
+ L
79
140
end
80
141
81
- # prefer non in-place method
82
- function (L:: RevModeAutoDiffVecProd{ad, iip, true} )(dv, v, p, t) where {ad, iip}
83
- L. vecprod! (dv, L. f, L. u, v, L. cache... )
142
+ # Interpret the call as df/du' * v
143
+ function (L:: AutoDiffVJP{AD} )(v, p, t; VJP_input = nothing ,) where {AD <: AutoFiniteDiff }
144
+ # ignore VJP_input as L.u was set in update_coefficients(...)
145
+ num_vecjac (L. f, L. u, v)
84
146
end
85
147
86
- function (L:: RevModeAutoDiffVecProd{ad, true, false} )(dv, v, p, t) where {ad}
87
- L. vecprod! (dv, L. f, L. u, v, L. cache... )
148
+ function (L:: AutoDiffVJP{AD} )(dv, v, p, t; VJP_input = nothing ,) where {AD <: AutoFiniteDiff }
149
+ # ignore VJP_input as L.u was set in update_coefficients!(...)
150
+ num_vecjac! (dv, L. f, L. u, v, L. cache... )
88
151
end
89
152
90
- function Base. resize! (L:: RevModeAutoDiffVecProd , n:: Integer )
153
+ function Base. resize! (L:: AutoDiffVJP , n:: Integer )
91
154
92
155
static_hasmethod (resize!, typeof ((L. f, n))) && resize! (L. f, n)
93
156
resize! (L. u, n)
94
157
95
158
for v in L. cache
96
159
resize! (v, n)
97
160
end
98
- end
99
-
100
- function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ; autodiff = AutoFiniteDiff (),
101
- kwargs... )
102
- vecprod, vecprod! = if autodiff isa AutoFiniteDiff
103
- num_vecjac, num_vecjac!
104
- elseif autodiff isa AutoZygote
105
- @assert static_hasmethod (auto_vecjac, typeof ((f, u, u))) " To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
106
161
107
- auto_vecjac, auto_vecjac!
108
- end
109
-
110
- cache = (similar (u), similar (u))
111
-
112
- outofplace = static_hasmethod (f, typeof ((u,)))
113
- isinplace = static_hasmethod (f, typeof ((u, u)))
114
-
115
- if ! (isinplace) & ! (outofplace)
116
- error (" $f must have signature f(u), or f(du, u)" )
117
- end
118
-
119
- L = RevModeAutoDiffVecProd (f, u, cache, vecprod, vecprod!; autodiff = autodiff,
120
- isinplace = isinplace, outofplace = outofplace)
121
-
122
- FunctionOperator (L, u, u;
123
- isinplace = isinplace, outofplace = outofplace,
124
- p = p, t = t, islinear = true ,
125
- kwargs... )
126
162
end
127
163
#
0 commit comments