From eefbdc5fb4d3ca32ebbb374e9b3b8ed0c87043c1 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 17 Jul 2023 17:08:39 -0400 Subject: [PATCH 1/2] fix ensemble indexing --- src/ensemble/ensemble_solutions.jl | 5 +---- test/downstream/ensemble_multi_prob.jl | 26 +++++++++++++++++--------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 64a2f857c..ee9850e20 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -208,16 +208,13 @@ end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) return [xi[s] for xi in x] end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] -end function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index e8e7640ed..96facd548 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -1,19 +1,27 @@ using ModelingToolkit, OrdinaryDiffEq, Test -@variables t, x(t) +@variables t, x(t), y(t) D = Differential(t) -@named sys1 = ODESystem([D(x) ~ 1.1*x]) -@named sys2 = ODESystem([D(x) ~ 1.2*x]) +@named sys1 = ODESystem([D(x) ~ x, + D(y) ~ -y]) +@named sys2 = ODESystem([D(x) ~ 2x, + D(y) ~ -2y]) +@named sys3 = ODESystem([D(x) ~ 3x, + D(y) ~ -3y]) -prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0)) -prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0)) +prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0)) +prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0)) +prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2]) +ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) -@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4) +for i in 1:3 + @test sol[x, :][i] == sol[i][x] + @test sol[y, :][i] == sol[i][y] +end # Ensemble is a recursive array -@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1) +@test Matrix(sol(0.0, idxs=[x])) == sol[1:1, 1, :] == Matrix(first(eachrow(sol[x, :]))') # TODO: fix the interpolation -@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1) +@test vec(sol(1.0, idxs=[x])) ≈ last.(sol[x, :].u) From 30c3f7484a793cbe41f5505bcd1dfb96fc0d5657 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 21 Jul 2023 09:16:36 -0400 Subject: [PATCH 2/2] fix test --- test/downstream/ensemble_multi_prob.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index 96facd548..dfb61f90c 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -22,6 +22,6 @@ for i in 1:3 @test sol[y, :][i] == sol[i][y] end # Ensemble is a recursive array -@test Matrix(sol(0.0, idxs=[x])) == sol[1:1, 1, :] == Matrix(first(eachrow(sol[x, :]))') +@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :]) # TODO: fix the interpolation -@test vec(sol(1.0, idxs=[x])) ≈ last.(sol[x, :].u) +@test only.(sol(1.0, idxs=[x])) ≈ last.(sol[x, :])