Skip to content

Commit eb171d8

Browse files
committed
aux input to f_loss
1 parent 9e6424b commit eb171d8

File tree

2 files changed

+89
-84
lines changed

2 files changed

+89
-84
lines changed

src/base/sensitivity_analysis.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ function get_sensitivity_functions(
196196
sys_reinit,
197197
callbacks,
198198
tstops,
199+
aux,
199200
)
200201
p_new = sim_inputs.parameters
201202
p_new[param_ixs] .= p
@@ -226,7 +227,7 @@ function get_sensitivity_functions(
226227
sol = solve_with_callback(prob_new, callbacks, solver)
227228
ix_t = unique(i -> sol.t[i], eachindex(sol.t))
228229
states = [sol[ix, ix_t] for ix in state_ixs]
229-
return f_loss(p, states, data)
230+
return f_loss(p, states, data, aux)
230231
end
231232
function f_Zygote(
232233
p,
@@ -238,6 +239,7 @@ function get_sensitivity_functions(
238239
init_level,
239240
sys_reinit,
240241
perts,
242+
aux,
241243
)
242244
p_new = sim_inputs.parameters
243245
p_new_buff = Zygote.Buffer(p_new)
@@ -286,9 +288,9 @@ function get_sensitivity_functions(
286288
ix_t = vact(1:length(sol.t))
287289
end
288290
states = [sol[ix, ix_t] for ix in state_ixs]
289-
return f_loss(p, states, data)
291+
return f_loss(p, states, data, aux)
290292
end
291-
function f_forward(p, perts, data)
293+
function f_forward(p, perts, data, aux)
292294
callbacks, tstops = convert_perturbations_to_callbacks(sys, sim_inputs, perts)
293295
f_enzyme(
294296
p,
@@ -301,9 +303,10 @@ function get_sensitivity_functions(
301303
sys_reinit,
302304
callbacks,
303305
tstops,
306+
aux,
304307
)
305308
end
306-
function f_grad(p, perts, data)
309+
function f_grad(p, perts, data, aux)
307310
callbacks, tstops = convert_perturbations_to_callbacks(sys, sim_inputs, perts)
308311
dp = Enzyme.make_zero(p)
309312
dx0 = Enzyme.make_zero(x0)
@@ -328,10 +331,11 @@ function get_sensitivity_functions(
328331
Enzyme.Const(sys_reinit),
329332
Enzyme.Duplicated(callbacks, dcallbacks),
330333
Enzyme.Duplicated(tstops, dtstops),
334+
Enzyme.Const(aux),
331335
)
332336
return dp
333337
end
334-
function f_forward_zygote(p, perts, data)
338+
function f_forward_zygote(p, perts, data, aux)
335339
f_Zygote(
336340
p,
337341
x0,
@@ -342,6 +346,7 @@ function get_sensitivity_functions(
342346
init_level,
343347
sys_reinit,
344348
perts,
349+
aux,
345350
)
346351
end
347352
f_forward, f_grad, f_forward_zygote

0 commit comments

Comments
 (0)