Skip to content

Commit 8f7aea3

Browse files
committed
...
1 parent 947aba7 commit 8f7aea3

File tree

9 files changed

+117
-1
lines changed

9 files changed

+117
-1
lines changed

include/amici/misc.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,18 @@ class CpuTimer {
344344
};
345345
#endif
346346

347+
/**
348+
* @brief The sign function.
349+
*
350+
* @param x The value to determine the sign of.
351+
* @return -1, 0, or 1 depending on the sign of x.
352+
*/
353+
template <typename T>
354+
int sign(T x) {
355+
return (T(0) < x) - (x < T(0));
356+
}
357+
358+
347359
} // namespace amici
348360

349361
#endif // AMICI_MISC_H

include/amici/model.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,17 @@ class Model : public AbstractModel, public ModelDimensions {
13621362
*/
13631363
void updateHeaviside(std::vector<int> const& rootsfound);
13641364

1365+
/**
1366+
* @brief Disable the event with index `ie` because it just triggered.
1367+
*
1368+
* Not to be called by user code.
1369+
*
1370+
* @param ie Event index.
1371+
*/
1372+
void register_root(int const ie, int direction) {
1373+
state_.root_enabled.at(ie) = false;
1374+
state_.root_last_sign.at(ie) = direction;
1375+
}
13651376
/**
13661377
* @brief Check if the given array has only finite elements.
13671378
*

include/amici/model_state.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct ModelState {
2929
stotal_cl.resize((dim.nx_rdata - dim.nx_solver) * dim.np, 0.0);
3030
unscaledParameters.resize(dim.np);
3131
fixedParameters.resize(dim.nk);
32+
root_enabled.resize(dim.ne, true);
33+
root_last_sign.resize(dim.ne, 0);
3234
}
3335

3436
/**
@@ -56,14 +58,26 @@ struct ModelState {
5658
* (dimension: nplist)
5759
*/
5860
std::vector<int> plist;
61+
62+
/**
63+
* Flags indicating whether a root function element is enabled
64+
* (dimension: `ne`)
65+
*/
66+
std::vector<bool> root_enabled;
67+
68+
/**
69+
* The sign of the root function elements at the last root function call
70+
* (dimension: `ne`).
71+
*/
72+
std::vector<int> root_last_sign;
5973
};
6074

6175
inline bool operator==(ModelState const& a, ModelState const& b) {
6276
return is_equal(a.h, b.h) && is_equal(a.total_cl, b.total_cl)
6377
&& is_equal(a.stotal_cl, b.stotal_cl)
6478
&& is_equal(a.unscaledParameters, b.unscaledParameters)
6579
&& is_equal(a.fixedParameters, b.fixedParameters)
66-
&& a.plist == b.plist;
80+
&& a.plist == b.plist && a.root_enabled == b.root_enabled && a.root_last_sign == b.root_last_sign;
6781
}
6882

6983
/**

include/amici/serialization.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ void serialize(Archive& ar, amici::Model& m, unsigned int const /*version*/) {
145145
ar & m.state_.unscaledParameters;
146146
ar & m.state_.fixedParameters;
147147
ar & m.state_.plist;
148+
ar & m.state_.root_enabled;
149+
ar & m.state_.root_last_sign;
148150
ar & m.x0data_;
149151
ar & m.sx0data_;
150152
ar & m.nmaxevent_;

python/tests/test_events.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,3 +1072,47 @@ def test_event_uses_values_from_trigger_time(tempdir):
10721072
)
10731073

10741074
# TODO: test ASA after https://github.com/AMICI-dev/AMICI/pull/1539
1075+
1076+
1077+
def test_simultaneous_events(tempdir):
1078+
"""Test simultaneously firing events with different trigger functions."""
1079+
from amici.antimony_import import antimony2amici
1080+
1081+
model_name = "test_simultaneous_events"
1082+
antimony2amici(
1083+
r"""
1084+
target1_0 = 1
1085+
target1 = target1_0
1086+
one = 1
1087+
target1' = one
1088+
two = 2
1089+
target2_0 = two
1090+
target2 = target2_0
1091+
target2' = 1
1092+
some_time = time
1093+
some_time' = 1
1094+
trigger_time = 1000
1095+
1096+
E1: at some_time >= trigger_time, priority=10, fromTrigger=false:
1097+
target1 = target1 + 10;
1098+
E2: at time >= trigger_time, priority=20, fromTrigger=false:
1099+
target2 = target2 + 10;
1100+
""",
1101+
model_name=model_name,
1102+
output_dir=tempdir,
1103+
)
1104+
1105+
model_module = import_model_module(model_name, tempdir)
1106+
1107+
model = model_module.get_model()
1108+
model.setTimepoints([0, 2])
1109+
solver = model.getSolver()
1110+
solver.setRelativeTolerance(1e-6)
1111+
solver.setAbsoluteTolerance(1e-6)
1112+
solver.setSensitivityOrder(SensitivityOrder.first)
1113+
solver.setSensitivityMethod(SensitivityMethod.forward)
1114+
1115+
rdata = amici.runAmiciSimulation(model, solver)
1116+
assert rdata.status == amici.AMICI_SUCCESS
1117+
assert_allclose(rdata.by_id("target1"), [1.0, 13.0])
1118+
assert_allclose(rdata.by_id("target2"), [2.0, 14.0])

src/forwardproblem.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,10 @@ void EventHandlingSimulator::handle_event(
359359
? std::optional<SimulationState>(get_simulation_state())
360360
: std::nullopt)}
361361
);
362+
363+
}
364+
if(ws_->roots_found.at(ie) != 0) {
365+
model_->register_root(ie, ws_->roots_found.at(ie));
362366
}
363367
}
364368

@@ -533,6 +537,7 @@ int EventHandlingSimulator::detect_secondary_events() {
533537
} else {
534538
ws_->roots_found.at(ie) = -1;
535539
}
540+
model_->register_root(ie, ws_->roots_found.at(ie));
536541
secondevent++;
537542
} else {
538543
ws_->roots_found.at(ie) = 0;

src/model.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,8 @@ void Model::initEvents(
437437
roots_found.at(ie) = 1;
438438
}
439439
}
440+
state_.root_enabled[ie] = rootvals[ie] != 0;
441+
state_.root_last_sign[ie] = events_[ie].get_initial_value()?1:-1;
440442
}
441443
}
442444

src/model_dae.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,19 @@ void Model_DAE::froot(
110110
state_.unscaledParameters.data(), state_.fixedParameters.data(),
111111
state_.h.data(), N_VGetArrayPointerConst(dx)
112112
);
113+
114+
for (int ie = 0; ie < ne; ++ie) {
115+
if (!state_.root_enabled[ie]) {
116+
if (root[ie] < 0.0) {
117+
// If the disabled root function becomes negative,
118+
// re-enable it.
119+
state_.root_enabled[ie] = true;
120+
} else {
121+
// If the root function is disabled, mask it
122+
root[ie] = 1.0;
123+
}
124+
}
125+
}
113126
}
114127

115128
void Model_DAE::fxdot(

src/model_ode.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ void Model_ODE::froot(realtype t, const_N_Vector x, gsl::span<realtype> root) {
9797
state_.unscaledParameters.data(), state_.fixedParameters.data(),
9898
state_.h.data(), state_.total_cl.data()
9999
);
100+
101+
for (int ie = 0; ie < ne; ++ie) {
102+
auto sgn = sign(root[ie]);
103+
if (!state_.root_enabled[ie] && sgn != 0 && sgn != state_.root_last_sign[ie]) {
104+
// The sign flipped, so we re-enable the root function
105+
state_.root_enabled[ie] = true;
106+
}
107+
108+
if(!state_.root_enabled[ie]) {
109+
// If the root function is disabled, mask it
110+
root[ie] = (state_.root_last_sign[ie] > 0) ? 1.0 : -1.0;
111+
}
112+
}
100113
}
101114

102115
void Model_ODE::fxdot(

0 commit comments

Comments
 (0)