diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 64a2f857c..ee9850e20 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -208,16 +208,13 @@ end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) return [xi[s] for xi in x] end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] -end function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index e8e7640ed..dfb61f90c 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -1,19 +1,27 @@ using ModelingToolkit, OrdinaryDiffEq, Test -@variables t, x(t) +@variables t, x(t), y(t) D = Differential(t) -@named sys1 = ODESystem([D(x) ~ 1.1*x]) -@named sys2 = ODESystem([D(x) ~ 1.2*x]) +@named sys1 = ODESystem([D(x) ~ x, + D(y) ~ -y]) +@named sys2 = ODESystem([D(x) ~ 2x, + D(y) ~ -2y]) +@named sys3 = ODESystem([D(x) ~ 3x, + D(y) ~ -3y]) -prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0)) -prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0)) +prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0)) +prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0)) +prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2]) +ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) -@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4) +for i in 1:3 + @test sol[x, :][i] == sol[i][x] + @test sol[y, :][i] == sol[i][y] +end # Ensemble is a recursive array -@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1) +@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :]) # TODO: fix the interpolation -@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1) +@test only.(sol(1.0, idxs=[x])) ≈ last.(sol[x, :])