Skip to content

Commit 1aee21d

Browse files
Merge pull request #472 from oscardssmith/ensemble-indexing-fixes
fix ensemble indexing
2 parents 091c91a + 30c3f74 commit 1aee21d

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

src/ensemble/ensemble_solutions.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,13 @@ end
208208
end
209209

210210

211-
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s)
211+
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
212212
return [xi[s] for xi in x]
213213
end
214214

215215
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
216216
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
217217
end
218-
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...)
219-
return [xi[args...] for xi in x]
220-
end
221218

222219
function (sol::AbstractEnsembleSolution)(args...; kwargs...)
223220
[s(args...; kwargs...) for s in sol]
Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
using ModelingToolkit, OrdinaryDiffEq, Test
22

3-
@variables t, x(t)
3+
@variables t, x(t), y(t)
44
D = Differential(t)
55

6-
@named sys1 = ODESystem([D(x) ~ 1.1*x])
7-
@named sys2 = ODESystem([D(x) ~ 1.2*x])
6+
@named sys1 = ODESystem([D(x) ~ x,
7+
D(y) ~ -y])
8+
@named sys2 = ODESystem([D(x) ~ 2x,
9+
D(y) ~ -2y])
10+
@named sys3 = ODESystem([D(x) ~ 3x,
11+
D(y) ~ -3y])
812

9-
prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0))
10-
prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0))
13+
prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0))
14+
prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0))
15+
prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0))
1116

1217
# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
13-
ensemble_prob = EnsembleProblem([prob1, prob2])
18+
ensemble_prob = EnsembleProblem([prob1, prob2, prob3])
1419
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
15-
@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4)
20+
for i in 1:3
21+
@test sol[x, :][i] == sol[i][x]
22+
@test sol[y, :][i] == sol[i][y]
23+
end
1624
# Ensemble is a recursive array
17-
@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1)
25+
@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :])
1826
# TODO: fix the interpolation
19-
@test sol(1.0, idxs=[x]) last.(sol[:, x], 1)
27+
@test only.(sol(1.0, idxs=[x])) last.(sol[x, :])

0 commit comments

Comments
 (0)