@@ -196,6 +196,7 @@ function get_sensitivity_functions(
196196 sys_reinit,
197197 callbacks,
198198 tstops,
199+ aux,
199200 )
200201 p_new = sim_inputs. parameters
201202 p_new[param_ixs] .= p
@@ -226,7 +227,7 @@ function get_sensitivity_functions(
226227 sol = solve_with_callback (prob_new, callbacks, solver)
227228 ix_t = unique (i -> sol. t[i], eachindex (sol. t))
228229 states = [sol[ix, ix_t] for ix in state_ixs]
229- return f_loss (p, states, data)
230+ return f_loss (p, states, data, aux )
230231 end
231232 function f_Zygote (
232233 p,
@@ -238,6 +239,7 @@ function get_sensitivity_functions(
238239 init_level,
239240 sys_reinit,
240241 perts,
242+ aux,
241243 )
242244 p_new = sim_inputs. parameters
243245 p_new_buff = Zygote. Buffer (p_new)
@@ -286,9 +288,9 @@ function get_sensitivity_functions(
286288 ix_t = vact (1 : length (sol. t))
287289 end
288290 states = [sol[ix, ix_t] for ix in state_ixs]
289- return f_loss (p, states, data)
291+ return f_loss (p, states, data, aux )
290292 end
291- function f_forward (p, perts, data)
293+ function f_forward (p, perts, data, aux )
292294 callbacks, tstops = convert_perturbations_to_callbacks (sys, sim_inputs, perts)
293295 f_enzyme (
294296 p,
@@ -301,9 +303,10 @@ function get_sensitivity_functions(
301303 sys_reinit,
302304 callbacks,
303305 tstops,
306+ aux,
304307 )
305308 end
306- function f_grad (p, perts, data)
309+ function f_grad (p, perts, data, aux )
307310 callbacks, tstops = convert_perturbations_to_callbacks (sys, sim_inputs, perts)
308311 dp = Enzyme. make_zero (p)
309312 dx0 = Enzyme. make_zero (x0)
@@ -328,10 +331,11 @@ function get_sensitivity_functions(
328331 Enzyme. Const (sys_reinit),
329332 Enzyme. Duplicated (callbacks, dcallbacks),
330333 Enzyme. Duplicated (tstops, dtstops),
334+ Enzyme. Const (aux),
331335 )
332336 return dp
333337 end
334- function f_forward_zygote (p, perts, data)
338+ function f_forward_zygote (p, perts, data, aux )
335339 f_Zygote (
336340 p,
337341 x0,
@@ -342,6 +346,7 @@ function get_sensitivity_functions(
342346 init_level,
343347 sys_reinit,
344348 perts,
349+ aux,
345350 )
346351 end
347352 f_forward, f_grad, f_forward_zygote
0 commit comments