diff --git a/Manifest.toml b/Manifest.toml index e8a588d..d81184c 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.13.0-DEV" manifest_format = "2.0" -project_hash = "d2c28a8e33664424dc750db4dae46c782768f682" +project_hash = "746cb775f4faad2538ec3bf8181fbd2c66618df8" [[deps.ADTypes]] git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32" diff --git a/benchmark/main.jl b/benchmark/main.jl new file mode 100644 index 0000000..1466230 --- /dev/null +++ b/benchmark/main.jl @@ -0,0 +1,10 @@ +using Test +using SciMLBase, Sundials + +include("thermalfluid.jl") + +Benchmark{3}()() +@test isa(code_lowered(DAECompiler.factory, Tuple{Val{DAECompiler.Settings(mode=DAECompiler.DAENoInit)}, Benchmark{3}})[1], Core.CodeInfo) +let sol = solve(DAECProblem(Benchmark{3}(), [1:9;] .=> 0.), IDA()) + @test_broken sol.retcode == ReturnCode.Success +end diff --git a/benchmark/thermalfluid.jl b/benchmark/thermalfluid.jl index b7ea3fb..1ed8cef 100644 --- a/benchmark/thermalfluid.jl +++ b/benchmark/thermalfluid.jl @@ -2,8 +2,6 @@ using DAECompiler using DAECompiler.Intrinsics using Polynomials: fit using XSteam: my_pT, rho_pT, Cp_pT, tc_pT -using SciMLBase, Sundials -using Test struct Sink; end function (::Sink)(inlet) @@ -282,9 +280,3 @@ function (::Benchmark{N})() where {N} PreinsulatedPipe{N}()(in[2], out[1]) Sink()(out[2]) end - -Benchmark{3}()() -@test isa(code_lowered(DAECompiler.factory, Tuple{Val{DAECompiler.Settings(mode=DAECompiler.DAENoInit)}, Benchmark{3}})[1], Core.CodeInfo) -let sol = solve(DAECProblem(Benchmark{3}(), [1:9;] .=> 0.), IDA()) - @test_broken sol.retcode == ReturnCode.Success -end diff --git a/src/DAECompiler.jl b/src/DAECompiler.jl index e77ab60..a1d7d7f 100644 --- a/src/DAECompiler.jl +++ b/src/DAECompiler.jl @@ -6,7 +6,7 @@ module DAECompiler using Diffractor using OrderedCollections using Compiler - using Compiler: IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst + using Compiler: AbstractLattice, IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst using Core.IR using SciMLBase using AutoHashEquals @@ -16,18 +16,21 @@ module DAECompiler include("settings.jl") include("utils.jl") include("intrinsics.jl") + include("reflection.jl") include("analysis/utils.jl") include("analysis/lattice.jl") include("analysis/ADAnalyzer.jl") include("analysis/scopes.jl") + include("analysis/flattening.jl") include("analysis/cache.jl") include("analysis/refiner.jl") include("analysis/ipoincidence.jl") include("analysis/structural.jl") - include("analysis/flattening.jl") include("transform/state_selection.jl") include("transform/common.jl") include("transform/runtime.jl") + include("transform/unoptimized.jl") + include("transform/reconstruct.jl") include("transform/tearing/schedule.jl") include("transform/codegen/dae_factory.jl") include("transform/codegen/ode_factory.jl") @@ -40,5 +43,4 @@ module DAECompiler include("analysis/consistency.jl") include("interface.jl") include("problem_interface.jl") - include("reflection.jl") end diff --git a/src/analysis/ADAnalyzer.jl b/src/analysis/ADAnalyzer.jl index a47756f..70ab31e 100644 --- a/src/analysis/ADAnalyzer.jl +++ b/src/analysis/ADAnalyzer.jl @@ -62,6 +62,7 @@ end struct AnalyzedSource ir::Compiler.IRCode + slotnames::Vector{Any} inline_cost::Compiler.InlineCostType nargs::UInt isva::Bool @@ -72,10 +73,16 @@ end Core.svec(edges..., interp.edges...) 
end +function get_slotnames(def::Method) + names = split(def.slot_syms, '\0') + return map(Symbol, names) +end + @override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector) ir = result.src.optresult.ir + slotnames = get_slotnames(result.linfo.def) params = Compiler.OptimizationParams(interp) - return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva) + return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva) end @override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult) @@ -83,8 +90,9 @@ end return nothing end ir = result.src.optresult.ir + slotnames = get_slotnames(result.linfo.def) params = Compiler.OptimizationParams(interp) - return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva) + return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva) end function Compiler.retrieve_ir_for_inlining(ci::CodeInstance, result::AnalyzedSource) diff --git a/src/analysis/cache.jl b/src/analysis/cache.jl index 0228f77..c5d219e 100644 --- a/src/analysis/cache.jl +++ b/src/analysis/cache.jl @@ -21,6 +21,19 @@ end var_schedule::Vector{Pair{BitSet, BitSet}} end +struct VariableReplacement + replaced::Incidence + by::Int + equation::Int +end + +mutable struct NonlinearReplacementMap + const variables::Vector{VariableReplacement} + variable_counter::Int + equation_counter::Int +end +NonlinearReplacementMap() = NonlinearReplacementMap(VariableReplacement[], 0, 0) + """ StructuralSSARef @@ -35,6 +48,7 @@ struct DAEIPOResult opaque_eligible::Bool extended_rt::Any argtypes + argmap::ArgumentMap nexternalargvars::Int # total vars is length(var_to_diff) nsysmscopes::Int nexternaleqs::Int @@ -44,6 +58,7 @@ struct DAEIPOResult total_incidence::Vector{Any} eqclassification::Vector{VarEqClassification} eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{StructuralSSARef, Int}}}} + replacement_map::NonlinearReplacementMap names::OrderedDict{Any, ScopeDictEntry} # TODO: OrderedIdDict varkinds::Vector{Union{Intrinsics.VarKind, Nothing}} eqkinds::Vector{Union{Intrinsics.EqKind, Nothing}} diff --git a/src/analysis/flattening.jl b/src/analysis/flattening.jl index 95a018a..4e7bb3d 100644 --- a/src/analysis/flattening.jl +++ b/src/analysis/flattening.jl @@ -1,117 +1,187 @@ -function _flatten_parameter!(๐•ƒ, compact, argtypes, ntharg, line, settings) - list = Any[] - for (argn, argt) in enumerate(argtypes) - if isa(argt, Const) - continue - elseif Base.issingletontype(argt) - continue - elseif Base.isprimitivetype(argt) || isa(argt, Incidence) - push!(list, ntharg(argn)) - elseif argt === equation || isa(argt, Eq) - continue - elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope - continue - elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct)) - continue - else - if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing - continue - end - this = ntharg(argn) - nthfield(i) = @insert_instruction_here(compact, line, settings, getfield(this, i)::Compiler.getfield_tfunc(๐•ƒ, argextype(this, compact), Const(i))) - if isa(argt, PartialStruct) - fields = _flatten_parameter!(๐•ƒ, compact, argt.fields, nthfield, line, settings) - else - fields = _flatten_parameter!(๐•ƒ, compact, fieldtypes(argt), 
nthfield, line, settings) - end - append!(list, fields) - end - end - return list +const CompositeIndex = Vector{Int} + +struct ArgumentMap + variables::Vector{CompositeIndex} # index into argument tuple type + equations::Vector{CompositeIndex} # index into argument tuple type end +ArgumentMap() = ArgumentMap(CompositeIndex[], CompositeIndex[]) -function flatten_parameter!(๐•ƒ, compact, argtypes, ntharg, line, settings) - return @insert_instruction_here(compact, line, settings, tuple(_flatten_parameter!(๐•ƒ, compact, argtypes, ntharg, line, settings)...)::Tuple) +function ArgumentMap(argtypes::Vector{Any}) + map = ArgumentMap() + index = CompositeIndex() + fill_argument_map!(map, index, argtypes) + return map end -# Needs to match flatten_arguments! -function process_template_arg!(๐•ƒ, coeffs, eq_mapping, applied_scopes, argt, template_argt, offset=0, eqoffset=0)::Pair{Int, Int} - if isa(template_argt, Const) - @assert isa(argt, Const) && argt.val === template_argt.val - return Pair{Int, Int}(offset, eqoffset) - elseif Base.issingletontype(template_argt) - @assert isa(template_argt, Type) && argt.instance === template_argt.instance - return Pair{Int, Int}(offset, eqoffset) - elseif Base.isprimitivetype(template_argt) - coeffs[offset+1] = argt - return Pair{Int, Int}(offset + 1, eqoffset) - elseif template_argt === equation - eq_mapping[eqoffset+1] = argt.id - return Pair{Int, Int}(offset, eqoffset + 1) - elseif isabstracttype(template_argt) || ismutabletype(template_argt) || (!isa(template_argt, DataType) && !isa(template_argt, PartialStruct)) - return Pair{Int, Int}(offset, eqoffset) - else - if !isa(template_argt, PartialStruct) && Base.datatype_fieldcount(template_argt) === nothing - return Pair{Int, Int}(offset, eqoffset) - end - template_fields = isa(template_argt, PartialStruct) ? template_argt.fields : collect(fieldtypes(template_argt)) - return process_template!(๐•ƒ, coeffs, eq_mapping, applied_scopes, Any[Compiler.getfield_tfunc(๐•ƒ, argt, Const(i)) for i = 1:length(template_fields)], template_fields, offset) +function fill_argument_map!(map::ArgumentMap, index::CompositeIndex, types::Vector{Any}) + for (i, type) in enumerate(types) + push!(index, i) + fill_argument_map!(map, index, type) + pop!(index) + end +end + +function fill_argument_map!(map::ArgumentMap, index::CompositeIndex, @nospecialize(type)) + if isprimitivetype(type) || isa(type, Incidence) + push!(map.variables, copy(index)) + elseif type === equation + push!(map.equations, copy(index)) + elseif isa(type, PartialStruct) || isstructtype(type) + fields = isa(type, PartialStruct) ? 
type.fields : collect(Any, fieldtypes(type)) + fill_argument_map!(map, index, fields) end end -function process_template!(๐•ƒ, coeffs, eq_mapping, applied_scopes, argtypes, template_argtypes, offset=0, eqoffset=0) - @assert length(argtypes) == length(template_argtypes) - for (i, template_arg) in enumerate(template_argtypes) - (offset, eqoffset) = process_template_arg!(๐•ƒ, coeffs, eq_mapping, applied_scopes, argtypes[i], template_arg, offset) +struct FlatteningState + compact::IncrementalCompact + settings::Settings + map::ArgumentMap + nvariables::Int + nequations::Int +end + +function FlatteningState(compact::IncrementalCompact, settings::Settings, map::ArgumentMap) + FlatteningState(compact, settings, deepcopy(map), length(map.variables), length(map.equations)) +end + +function next_variable!(state::FlatteningState) + popfirst!(state.map.variables) + return state.nvariables - length(state.map.variables) +end + +function next_equation!(state::FlatteningState) + popfirst!(state.map.equations) + return state.nequations - length(state.map.equations) +end + +function flatten_arguments!(state::FlatteningState) + argtypes = copy(state.compact.ir.argtypes) + empty!(state.compact.ir.argtypes) # will be recomputed during flattening + args = flatten_arguments!(state, argtypes) + if args !== nothing + @assert isempty(state.map.variables) + @assert isempty(state.map.equations) end - return Pair{Int, Int}(offset, eqoffset) + return args end -struct TransformedArg - ssa::Any - offset::Int - eqoffset::Int - TransformedArg(@nospecialize(arg), new_offset::Int, new_eqoffset::Int) = new(arg, new_offset, new_eqoffset) +function flatten_arguments!(state::FlatteningState, argtypes::Vector{Any}) + args = Any[] + for argt in argtypes + arg = flatten_argument!(state, argt) + arg === nothing && return nothing + push!(args, arg) + end + return args end -function flatten_argument!(compact::Compiler.IncrementalCompact, settings::Settings, @nospecialize(argt), offset::Int, eqoffset::Int, argtypes::Vector{Any})::TransformedArg +function flatten_argument!(state::FlatteningState, @nospecialize(argt)) @assert !isa(argt, Incidence) && !isa(argt, Eq) + (; compact, settings) = state if isa(argt, Const) - return TransformedArg(argt.val, offset, eqoffset) + return argt.val elseif Base.issingletontype(argt) - return TransformedArg(argt.instance, offset, eqoffset) - elseif Base.isprimitivetype(argt) - push!(argtypes, argt) - return TransformedArg(Argument(offset+1), offset+1, eqoffset) + return argt.instance + elseif isprimitivetype(argt) + push!(state.compact.ir.argtypes, argt) + return Argument(next_variable!(state)) elseif argt === equation + eq = next_equation!(state) line = compact[Compiler.OldSSAValue(1)][:line] - ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eqoffset+1)) - return TransformedArg(ssa, offset, eqoffset+1) + ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eq)) + return ssa + elseif argt <: Type + return argt.parameters[1] elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct)) line = compact[Compiler.OldSSAValue(1)][:line] ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{}) - return TransformedArg(ssa, -1, eqoffset) + return nothing else if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing line = compact[Compiler.OldSSAValue(1)][:line] ssa = 
@insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{}) - return TransformedArg(ssa, -1, eqoffset) + return nothing end - (args, _, offset) = flatten_arguments!(compact, settings, isa(argt, PartialStruct) ? argt.fields : collect(Any, fieldtypes(argt)), offset, eqoffset, argtypes) - offset == -1 && return TransformedArg(ssa, -1, eqoffset) + fields = isa(argt, PartialStruct) ? argt.fields : collect(Any, fieldtypes(argt)) + args = flatten_arguments!(state, fields) + args === nothing && return nothing this = Expr(:new, isa(argt, PartialStruct) ? argt.typ : argt, args...) line = compact[Compiler.OldSSAValue(1)][:line] ssa = @insert_instruction_here(compact, line, settings, this::argt) - return TransformedArg(ssa, offset, eqoffset) + return ssa end end -function flatten_arguments!(compact::Compiler.IncrementalCompact, settings::Settings, argtypes::Vector{Any}, offset::Int=0, eqoffset::Int=0, new_argtypes::Vector{Any} = Any[]) - args = Any[] - for argt in argtypes - (; ssa, offset, eqoffset) = flatten_argument!(compact, settings, argt, offset, eqoffset, new_argtypes) - offset == -1 && break - push!(args, ssa) +function flatten_arguments_for_callee!(compact::IncrementalCompact, map::ArgumentMap, argtypes, args, line, settings, ๐•ƒ = Compiler.fallback_lattice) + list = Any[] + this = nothing + last_index = CompositeIndex() + for index in map.variables + from = findfirst(j -> get(last_index, j, -1) !== index[j], eachindex(index))::Int + for i in from:length(index) + field = index[i] + if i == 1 + this = args[field] + else + thistype = argextype(this, compact) + fieldtype = Compiler.getfield_tfunc(๐•ƒ, thistype, Const(field)) + this = @insert_instruction_here(compact, line, settings, getfield(this, field)::fieldtype) + end + end + push!(list, this) + end + return list +end + +remove_variable_and_equation_annotations(argtypes) = Any[widenconst(T) for T in argtypes] + +function annotate_variables_and_equations(argtypes::Vector{Any}, map::ArgumentMap) + argtypes_annotated = Any[] + pstructs = Dict{CompositeIndex,PartialStruct}() + for (i, arg) in enumerate(argtypes) + if arg !== equation && arg !== Incidence && isstructtype(arg) && (any(==(i) โˆ˜ first, map.variables) || any(==(i) โˆ˜ first, map.equations)) + arg = init_partialstruct(arg) + pstructs[[i]] = arg + end + push!(argtypes_annotated, arg) + end + + function fields_for_index(index) + length(index) > 1 || return argtypes_annotated + # Find the parent `PartialStruct` that holds the variable field, + # creating any further `PartialStruct` going down if necessary. + i, base = find_base(pstructs, index) + local fields = base.fields + for j in @view index[(i + 1):(end - 1)] + pstruct = init_partialstruct(fields[j]) + fields[j] = pstruct + fields = pstruct.fields + end + return fields + end + + # Populate `PartialStruct` variable fields with an `Incidence` lattice element. + for (variable, index) in enumerate(map.variables) + fields = fields_for_index(index) + type = get_fieldtype(argtypes, index) + fields[index[end]] = Incidence(type, variable) + end + + # Do the same for equations with an `Eq` lattice element. 
+ for (equation, index) in enumerate(map.equations) + fields = fields_for_index(index) + fields[index[end]] = Eq(equation) + end + + return argtypes_annotated +end + +init_partialstruct(@nospecialize(T)) = PartialStruct(T, collect(Any, fieldtypes(T))) +init_partialstruct(pstruct::PartialStruct) = pstruct + +function find_base(dict::Dict{CompositeIndex}, index::CompositeIndex) + for i in reverse(eachindex(index)) + base = get(dict, @view(index[1:i]), nothing) + base !== nothing && return i, base end - return (args, new_argtypes, offset, eqoffset) end diff --git a/src/analysis/ipoincidence.jl b/src/analysis/ipoincidence.jl index 6804315..ce3610d 100644 --- a/src/analysis/ipoincidence.jl +++ b/src/analysis/ipoincidence.jl @@ -18,10 +18,10 @@ function compute_missing_coeff!(coeffs, (;callee_result, caller_var_to_diff, cal # First find the rootvar, and if we already have a coeff for it # apply the derivatives. ndiffs = 0 - calle_inv = invview(callee_result.var_to_diff) - while calle_inv[v] !== nothing && !isassigned(coeffs, v) + callee_inv = invview(callee_result.var_to_diff) + while callee_inv[v] !== nothing && !isassigned(coeffs, v) ndiffs += 1 - v = calle_inv[v] + v = callee_inv[v] end if !isassigned(coeffs, v) @@ -43,9 +43,9 @@ function compute_missing_coeff!(coeffs, (;callee_result, caller_var_to_diff, cal return nothing end -apply_linear_incidence(๐•ƒ, ret::Type, caller::CallerMappingState, mapping::CalleeMapping) = ret -apply_linear_incidence(๐•ƒ, ret::Const, caller::CallerMappingState, mapping::CalleeMapping) = ret -function apply_linear_incidence(๐•ƒ, ret::Incidence, caller::Union{CallerMappingState, Nothing}, mapping::CalleeMapping) +apply_linear_incidence!(mapping::CalleeMapping, ๐•ƒ, ret::Type, caller::CallerMappingState) = ret +apply_linear_incidence!(mapping::CalleeMapping, ๐•ƒ, ret::Const, caller::CallerMappingState) = ret +function apply_linear_incidence!(mapping::CalleeMapping, ๐•ƒ, ret::Incidence, caller::Union{CallerMappingState, Nothing}) # Substitute variables returned by the callee with the incidence defined by the caller. # The composition will be additive in the constant terms, and multiplicative for linear coefficients. 
caller_variables = mapping.var_coeffs @@ -141,7 +141,7 @@ function compose_additive_term(@nospecialize(a), @nospecialize(b), coeff) return Const(val) end -function apply_linear_incidence(๐•ƒ, ret::Eq, caller::CallerMappingState, mapping::CalleeMapping) +function apply_linear_incidence!(mapping::CalleeMapping, ๐•ƒ, ret::Eq, caller::CallerMappingState) eq_mapping = mapping.eqs[ret.id] if eq_mapping == 0 push!(caller.caller_eqclassification, Owned) @@ -151,21 +151,48 @@ function apply_linear_incidence(๐•ƒ, ret::Eq, caller::CallerMappingState, mappi return Eq(eq_mapping) end -function apply_linear_incidence(๐•ƒ, ret::PartialStruct, caller::CallerMappingState, mapping::CalleeMapping) - return PartialStruct(๐•ƒ, ret.typ, Any[apply_linear_incidence(๐•ƒ, f, caller, mapping) for f in ret.fields]) +function apply_linear_incidence!(mapping::CalleeMapping, ๐•ƒ, ret::PartialStruct, caller::CallerMappingState) + return PartialStruct(๐•ƒ, ret.typ, Any[apply_linear_incidence!(mapping, ๐•ƒ, f, caller) for f in ret.fields]) end -function CalleeMapping(๐•ƒ::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_ci::CodeInstance, callee_result::DAEIPOResult, template_argtypes) +function CalleeMapping(๐•ƒ::AbstractLattice, argtypes::Vector{Any}, callee_ci::CodeInstance, callee_result::DAEIPOResult) + caller_argtypes = Compiler.va_process_argtypes(๐•ƒ, argtypes, callee_ci.inferred.nargs, callee_ci.inferred.isva) + callee_argtypes = callee_ci.inferred.ir.argtypes + argmap = ArgumentMap(callee_argtypes) + nvars = length(callee_result.var_to_diff) + neqs = length(callee_result.total_incidence) + @assert length(argmap.variables) โ‰ค nvars + @assert length(argmap.equations) โ‰ค neqs + applied_scopes = Any[] - coeffs = Vector{Any}(undef, length(callee_result.var_to_diff)) - eq_mapping = fill(0, length(callee_result.total_incidence)) + coeffs = Vector{Any}(undef, nvars) + eq_mapping = fill(0, neqs) + mapping = CalleeMapping(coeffs, eq_mapping, applied_scopes) - va_argtypes = Compiler.va_process_argtypes(๐•ƒ, argtypes, callee_ci.inferred.nargs, callee_ci.inferred.isva) - process_template!(๐•ƒ, coeffs, eq_mapping, applied_scopes, va_argtypes, template_argtypes) + fill_callee_mapping!(mapping, argmap, caller_argtypes, ๐•ƒ) + return mapping +end - return CalleeMapping(coeffs, eq_mapping, applied_scopes) +function fill_callee_mapping!(mapping::CalleeMapping, argmap::ArgumentMap, argtypes::Vector{Any}, ๐•ƒ::AbstractLattice) + for (i, index) in enumerate(argmap.variables) + type = get_fieldtype(argtypes, index, ๐•ƒ) + mapping.var_coeffs[i] = type + end + for (i, index) in enumerate(argmap.equations) + eq = get_fieldtype(argtypes, index, ๐•ƒ)::Eq + mapping.eqs[i] = eq.id + end end +function get_fieldtype(argtypes::Vector{Any}, index::CompositeIndex, ๐•ƒ::AbstractLattice = Compiler.fallback_lattice) + @assert !isempty(index) + index = copy(index) + type = argtypes[popfirst!(index)] + while !isempty(index) + type = Compiler.getfield_tfunc(๐•ƒ, type, Const(popfirst!(index))) + end + return type +end struct MappingInfo <: Compiler.CallInfo info::Any diff --git a/src/analysis/lattice.jl b/src/analysis/lattice.jl index b5570c5..c9a8dd6 100644 --- a/src/analysis/lattice.jl +++ b/src/analysis/lattice.jl @@ -4,7 +4,7 @@ using SparseArrays ########################## EqStructureLattice #################################### """ - struct EqStructureLattice <: Compiler.AbstractLattice + struct EqStructureLattice <: AbstractLattice This lattice implements the `AbstractLattice` interface. It adjoins `Incidence` and `Eq`. 
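Note on the `ArgumentMap`/`CompositeIndex` scheme introduced in `flattening.jl` and consumed by `get_fieldtype` above: each entry is a path that first selects an argument and then descends through nested struct fields. Below is a minimal standalone sketch of that addressing idea; `Pin`, `Resistor`, and `get_field` are illustrative stand-ins, not DAECompiler definitions.

```julia
# Standalone sketch of the composite-index idea behind `ArgumentMap`:
# an index is a path, first into the argument list, then through nested
# struct fields. (Illustrative types only.)
const CompositeIndex = Vector{Int}

struct Pin
    v::Float64
    i::Float64
end

struct Resistor
    p::Pin
    n::Pin
    R::Float64
end

# Value-level analogue of `get_fieldtype`, which walks the same path on
# the type/lattice level via `Compiler.getfield_tfunc`.
function get_field(args, index::CompositeIndex)
    x = args[index[1]]
    for i in @view index[2:end]
        x = getfield(x, i)
    end
    return x
end

args = (Resistor(Pin(1.0, 0.1), Pin(0.0, -0.1), 100.0), 2.5)
get_field(args, [1, 2, 1])  # args[1].n.v == 0.0
get_field(args, [2])        # args[2]     == 2.5
```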
@@ -34,7 +34,7 @@ the taint of %phi depends not only on `%a` and `%b`, but also on the taint of the branch condition `%cond`. This is a common feature of taint analysis, but is somewhat unusual from the perspective of other Julia type lattices. """ -struct EqStructureLattice <: Compiler.AbstractLattice; end +struct EqStructureLattice <: AbstractLattice; end Compiler.widenlattice(::EqStructureLattice) = Compiler.ConstsLattice() Compiler.is_valid_lattice_norec(::EqStructureLattice, @nospecialize(v)) = isa(v, Incidence) || isa(v, Eq) || isa(v, PartialScope) || isa(v, PartialKeyValue) Compiler.has_extended_unionsplit(::EqStructureLattice) = true @@ -537,7 +537,7 @@ struct PartialKeyValue end PartialKeyValue(typ) = PartialKeyValue(typ, typ, IdDict{Any, Any}()) -function getkeyvalue_tfunc(๐•ƒ::Compiler.AbstractLattice, +function getkeyvalue_tfunc(๐•ƒ::AbstractLattice, @nospecialize(collection), @nospecialize(key)) isa(key, Const) || return Tuple{Any} if haskey(collection.vals, key.val) diff --git a/src/analysis/refiner.jl b/src/analysis/refiner.jl index be5b849..1aea8a2 100644 --- a/src/analysis/refiner.jl +++ b/src/analysis/refiner.jl @@ -15,7 +15,11 @@ struct StructuralRefiner <: Compiler.AbstractInterpreter eqclassification::Vector{VarEqClassification} end -struct StructureCache; end +struct StructureCache + optimized::Bool +end +StructureCache() = StructureCache(true) +StructureCache(settings::Settings) = StructureCache(!settings.skip_optimizations) Compiler.optimizer_lattice(interp::StructuralRefiner) = Compiler.PartialsLattice(EqStructureLattice()) Compiler.typeinf_lattice(interp::StructuralRefiner) = Compiler.PartialsLattice(EqStructureLattice()) @@ -24,7 +28,7 @@ Compiler.ipo_lattice(interp::StructuralRefiner) = Compiler.PartialsLattice(EqStr Compiler.InferenceParams(interp::StructuralRefiner) = Compiler.InferenceParams() Compiler.OptimizationParams(interp::StructuralRefiner) = Compiler.OptimizationParams() Compiler.get_inference_world(interp::StructuralRefiner) = interp.world -Compiler.cache_owner(::StructuralRefiner) = StructureCache() +Compiler.cache_owner(interp::StructuralRefiner) = StructureCache(interp.settings) # This is the main logic. We visit an :invoke instruction and either apply the known transfer function for one of our # DAECompiler intrinsics or lookup the structural incidence matrix in the cache, applying it as appropriate. 
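The `StructureCache` owner is now keyed on whether optimizations were skipped, so optimized and unoptimized structural-analysis results live in separate cache partitions. A small standalone illustration of the equality check used by `find_matching_ci(ci -> ci.owner == StructureCache(settings), ...)`; the struct is re-declared here purely for the example:

```julia
# Sketch: two analysis runs reuse a cached CodeInstance only when the
# owner tag (including the optimized/unoptimized bit) compares equal.
struct StructureCache
    optimized::Bool
end

StructureCache(true) == StructureCache(true)   # true  -> cached result is reused
StructureCache(true) == StructureCache(false)  # false -> skip_optimizations runs get their own entry
```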
@@ -60,9 +64,9 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache() end argtypes = Compiler.collect_argtypes(interp, stmt.args, Compiler.StatementState(nothing, false), irsv)[2:end] - mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result, callee_codeinst.inferred.ir.argtypes) - new_rt = apply_linear_incidence(Compiler.optimizer_lattice(interp), callee_result.extended_rt, - CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, interp.eqclassification, interp.eqkinds), mapping) + mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result) + new_rt = apply_linear_incidence!(mapping, Compiler.optimizer_lattice(interp), callee_result.extended_rt, + CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, interp.eqclassification, interp.eqkinds)) # Remember this mapping, both for performance of not having to recompute it # and because we may have assigned caller variables to internal variables diff --git a/src/analysis/structural.jl b/src/analysis/structural.jl index 644f551..12b93d7 100644 --- a/src/analysis/structural.jl +++ b/src/analysis/structural.jl @@ -18,14 +18,14 @@ end function structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings) # Check if we have aleady done this work - if so return the cached result - result_ci = find_matching_ci(ci->ci.owner == StructureCache(), ci.def, world) + result_ci = find_matching_ci(ci->ci.owner == StructureCache(settings), ci.def, world) if result_ci !== nothing return result_ci.inferred end result = _structural_analysis!(ci, world, settings) # TODO: The world bounds might have been narrowed - cache_dae_ci!(ci, result, nothing, nothing, StructureCache()) + cache_dae_ci!(ci, result, nothing, nothing, StructureCache(settings)) return result end @@ -81,25 +81,34 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings warnings = BadDAECompilerInputException[] compact = IncrementalCompact(ir) - old_argtypes = copy(ir.argtypes) - empty!(ir.argtypes) - (arg_replacements, new_argtypes, nexternalargvars, nexternaleqs) = flatten_arguments!(compact, settings, old_argtypes, 0, 0, ir.argtypes) - if nexternalargvars == -1 - return UncompilableIPOResult(warnings, UnsupportedIRException("Unhandled argument types", Compiler.finish(compact))) + argmap = ArgumentMap(ir.argtypes) + nexternalargvars = length(argmap.variables) + nexternaleqs = length(argmap.equations) + if !settings.skip_optimizations + state = FlatteningState(compact, settings, argmap) + arg_replacements = flatten_arguments!(state) + if arg_replacements === nothing + return UncompilableIPOResult(warnings, UnsupportedIRException("Unhandled argument types", Compiler.finish(compact))) + end + argtypes = Any[Incidence(ir.argtypes[i], i) for i = 1:nexternalargvars] + else + argtypes = annotate_variables_and_equations(ir.argtypes, argmap) + arg_replacements = nothing end + for i = 1:nexternalargvars # TODO: Need to handle different var kinds for IPO + # TODO: Don't use `Argument` when we don't flatten, maybe something + # like an `ArgumentView` with composite indices from the `ArgumentMap`. 
add_variable!(Argument(i)) end for i = 1:nexternaleqs # Not technically an argument, but let's use it for now add_equation!(Argument(i)) end - argtypes = Any[Incidence(new_argtypes[i], i) for i = 1:nexternalargvars] # Allocate variable and equation numbers of any incoming arguments refiner = StructuralRefiner(world, settings, var_to_diff, varkinds, varclassification, eqkinds, eqclassification) - nexternalargvars = length(var_to_diff) # Go through the IR, annotating each intrinsic with an appropriate taint # source lattice element. @@ -107,10 +116,12 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings for ((old_idx, i), stmt) in compact urs = userefs(stmt) compact[SSAValue(i)] = nothing - for ur in urs - if isa(ur[], Argument) - repl = arg_replacements[ur[].n] - ur[] = repl + if arg_replacements !== nothing + for ur in urs + if isa(ur[], Argument) + repl = arg_replacements[ur[].n] + ur[] = repl + end end end stmt = urs[] @@ -238,10 +249,12 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings record_scope!(ir, warnings, names, scope, ScopeDictEntry(true, var_num)) end - # Delete - we've recorded this into our our side table, we don't need to - # keep it around in the IR - inst.args[3] = nothing - inst.args[4] = nothing + if !settings.skip_optimizations + # Delete - we've recorded this into our our side table, we don't need to + # keep it around in the IR + inst.args[3] = nothing + inst.args[4] = nothing + end end # Do the same for equations @@ -272,10 +285,12 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings record_scope!(ir, warnings, names, scope, ScopeDictEntry(false, eq_num)) end - # Delete - we've recorded this into our our side table, we don't need to - # keep it around in the IR - inst.args[3] = nothing - inst.args[4] = nothing + if !settings.skip_optimizations + # Delete - we've recorded this into our our side table, we don't need to + # keep it around in the IR + inst.args[3] = nothing + inst.args[4] = nothing + end end # Now record the association of (::equation)() calls with the equations that they originate from @@ -286,6 +301,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings handler_info = Compiler.compute_trycatch(ir) ncallees = 0 compact = IncrementalCompact(ir) + replacement_map = NonlinearReplacementMap() opaque_eligible = isempty(total_incidence) && all(==(External), varclassification) for ((old_idx, i), stmt) in compact stmt === nothing && continue @@ -302,7 +318,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings eqeq = argextype(stmt.args[2], compact) if !isa(eqeq, Eq) - return UncompilableIPOResult(warnings, UnsupportedIRException("Equation call at $ssa has unknown equation reference.", ir)) + return UncompilableIPOResult(warnings, UnsupportedIRException("Equation call at $(SSAValue(i)) has unknown equation reference.", ir)) end ieq = eqeq.id @@ -345,7 +361,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings end callee_argtypes = Any[argextype(stmt.args[i], compact) for i in 2:length(stmt.args)] - mapping = CalleeMapping(Compiler.optimizer_lattice(refiner), callee_argtypes, callee_codeinst, result, callee_codeinst.inferred.ir.argtypes) + mapping = CalleeMapping(Compiler.optimizer_lattice(refiner), callee_argtypes, callee_codeinst, result) inst[:info] = info = MappingInfo(info, result, mapping) end @@ -357,28 +373,39 @@ function _structural_analysis!(ci::CodeInstance, 
world::UInt, settings::Settings opaque_eligible = false end - # Rewrite to flattened ABI - compact[SSAValue(i)] = nothing - compact.result_idx -= 1 - new_args = _flatten_parameter!(Compiler.optimizer_lattice(refiner), compact, callee_codeinst.inferred.ir.argtypes, arg->stmt.args[arg+1], line, settings) - - new_call = insert_instruction_here!(compact, settings, @__SOURCE__, - NewInstruction(Expr(:invoke, (StructuralSSARef(compact.result_idx), callee_codeinst), new_args...), stmtype, info, line, stmtflags)) - compact.ssa_rename[compact.idx - 1] = new_call + if !settings.skip_optimizations + # Rewrite to flattened ABI + compact[SSAValue(i)] = nothing + compact.result_idx -= 1 + callee_argtypes = callee_codeinst.inferred.ir.argtypes + callee_argmap = ArgumentMap(callee_argtypes) + args = @view(stmt.args[2:end]) + ๐•ƒ = Compiler.optimizer_lattice(refiner) + new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings, ๐•ƒ) + new_call = insert_instruction_here!(compact, settings, @__SOURCE__, + NewInstruction(Expr(:invoke, (StructuralSSARef(compact.result_idx), callee_codeinst), new_args...), stmtype, info, line, stmtflags)) + compact.ssa_rename[compact.idx - 1] = new_call + ssa = StructuralSSARef(new_call.id) + else + ssa = StructuralSSARef(i) + end cms = CallerMappingState(result, refiner.var_to_diff, refiner.varclassification, refiner.varkinds, eqclassification, eqkinds) - err = add_internal_equations_to_structure!(refiner, cms, total_incidence, eq_callee_mapping, StructuralSSARef(new_call.id), - result, mapping) + err = add_internal_equations_to_structure!(refiner, cms, total_incidence, eq_callee_mapping, + ssa, result, mapping) if err !== true return UncompilableIPOResult(warnings, UnsupportedIRException(err, ir)) end + if !settings.skip_optimizations + add_replacements_from_callee!(replacement_map, result.replacement_map, mapping, cms, Compiler.typeinf_lattice(refiner)) + end end eqvars = EqVarState(var_to_diff, varclassification, varkinds, total_incidence, eqclassification, eqkinds, eq_callee_mapping) # Replace non linear return by a new variable and return that variable - if !opaque_eligible + if !opaque_eligible && !settings.skip_optimizations last_ssa = SSAValue(compact.result_idx - 1) ret_stmt_inst = compact[last_ssa] ret_stmt = ret_stmt_inst[:stmt] @@ -386,7 +413,12 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings line = ret_stmt_inst[:line] Compiler.delete_inst_here!(compact) - (new_ret, ultimate_rt) = rewrite_ipo_return!(Compiler.typeinf_lattice(refiner), compact, line, settings, ret_stmt.val, ultimate_rt, eqvars) + ๐•ƒ = Compiler.typeinf_lattice(refiner) + replacement_map.variable_counter = length(refiner.var_to_diff) + replacement_map.equation_counter = length(eqkinds) + plan_nonlinear_replacements!(replacement_map, ultimate_rt, ๐•ƒ) + + (new_ret, ultimate_rt) = rewrite_ipo_return!(๐•ƒ, compact, line, settings, ret_stmt.val, ultimate_rt, eqvars) insert_instruction_here!(compact, settings, @__SOURCE__, NewInstruction(ReturnNode(new_ret), ultimate_rt, Compiler.NoCallInfo(), line, Compiler.IR_FLAG_REFINED), reverse_affinity = true) elseif isa(ultimate_rt, Type) # If we don't have any internal variables (in which case we might have to to do a more aggressive rewrite), strengthen the incidence @@ -401,7 +433,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings var_to_diff = StateSelection.complete(var_to_diff) names = OrderedDict{Any, ScopeDictEntry}() - return 
DAEIPOResult(ir, opaque_eligible, ultimate_rt, argtypes, + return DAEIPOResult(ir, opaque_eligible, ultimate_rt, argtypes, argmap, nexternalargvars, nsysmscopes, nexternaleqs, @@ -409,12 +441,43 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, settings::Settings var_to_diff, varclassification, total_incidence, eqclassification, eq_callee_mapping, + replacement_map, names, varkinds, eqkinds, warnings) end +function plan_nonlinear_replacements!(map::NonlinearReplacementMap, @nospecialize(type), ๐•ƒ::AbstractLattice) + if isa(type, PartialStruct) + for field in eachindex(type.fields) + ftype = Compiler.getfield_tfunc(๐•ƒ, type, Const(field)) + plan_nonlinear_replacements!(map, ftype, ๐•ƒ) + end + elseif isa(type, Incidence) + nnz(type.row) > 0 || return + nnz(type.row) > 1 || first(nonzeros(type.row)) === nonlinear || return + add_variable_replacement!(map, type) + end +end + +function add_variable_replacement!(map::NonlinearReplacementMap, incidence::Incidence) + by = (map.variable_counter += 1) + equation = (map.equation_counter += 1) + push!(map.variables, VariableReplacement(incidence, by, equation)) +end + +function add_replacements_from_callee!(map::NonlinearReplacementMap, callee::NonlinearReplacementMap, callee_mapping::CalleeMapping, caller::CallerMappingState, ๐•ƒ::AbstractLattice) + for (; replaced, by, equation) in callee.variables + caller_replaced = apply_linear_incidence!(callee_mapping, ๐•ƒ, replaced, caller) + caller_by = idnum(callee_mapping.var_coeffs[by]) + caller_equation = callee_mapping.eqs[equation] + replacement = VariableReplacement(caller_replaced, caller_by, caller_equation) + push!(map.variables, replacement) + end + return map +end + function rewrite_ipo_return!(๐•ƒ, compact::IncrementalCompact, line, settings, ssa, ultimate_rt::Any, eqvars::EqVarState) if isa(ultimate_rt, Eq) return Pair{Any, Any}(ssa, ultimate_rt) @@ -426,7 +489,7 @@ function rewrite_ipo_return!(๐•ƒ, compact::IncrementalCompact, line, settings, for i = 1:length(ultimate_rt.fields) ssa_type = Compiler.getfield_tfunc(๐•ƒ, ultimate_rt, Const(i)) ssa_field = insert_instruction_here!(compact, settings, @__SOURCE__, - NewInstruction(Expr(:call, getfield, variable), ssa_type, Compiler.NoCallInfo(), line, Compiler.IR_FLAG_REFINED), reverse_affinity = true) + NewInstruction(Expr(:call, getfield, ssa, i), ssa_type, Compiler.NoCallInfo(), line, Compiler.IR_FLAG_REFINED), reverse_affinity = true) (new_field, new_type) = rewrite_ipo_return!(๐•ƒ, compact, line, settings, ssa_field, ssa_type, eqvars) push!(new_fields, new_field) @@ -499,7 +562,7 @@ function add_internal_equations_to_structure!(refiner::StructuralRefiner, cms::C # we're here is because it leaked an explicit reference. 
continue end - mapped_inc = apply_linear_incidence(Compiler.typeinf_lattice(refiner), callee_result.total_incidence[eq], cms, callee_mapping) + mapped_inc = apply_linear_incidence!(callee_mapping, Compiler.typeinf_lattice(refiner), callee_result.total_incidence[eq], cms) if isassigned(total_incidence, mapped_eq) total_incidence[mapped_eq] = tfunc(Val(Core.Intrinsics.add_float), total_incidence[mapped_eq], @@ -518,7 +581,7 @@ function add_internal_equations_to_structure!(refiner::StructuralRefiner, cms::C callee_result.eqclassification[ieq] === External && continue isassigned(callee_result.total_incidence, ieq) || continue inc = callee_result.total_incidence[ieq] - extinc = apply_linear_incidence(Compiler.typeinf_lattice(refiner), inc, cms, callee_mapping) + extinc = apply_linear_incidence!(callee_mapping, Compiler.typeinf_lattice(refiner), inc, cms) if !isa(extinc, Incidence) && !isa(extinc, Const) return "Failed to map internal incidence for equation $ieq (internal result $inc) - got $extinc while processing $thisssa" end diff --git a/src/interface.jl b/src/interface.jl index d884061..767cf9a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -30,7 +30,7 @@ function factory_gen(@nospecialize(fT), settings::Settings, world::UInt = Base.g end structure = make_structure_from_ipo(result) - tstate = TransformationState(result, structure, copy(result.total_incidence)) + tstate = TransformationState(result, structure) # Ensure that the system is a consistent DAE system err = StateSelection.check_consistency(tstate, nothing) @@ -41,20 +41,22 @@ function factory_gen(@nospecialize(fT), settings::Settings, world::UInt = Base.g end # Select differential and algebraic states - ret = top_level_state_selection!(tstate) - - if isa(ret, UncompilableIPOResult) - return Base.generated_body_to_codeinfo( - Expr(:lambda, Any[:var"#self", :settings, :f], Expr(:block, Expr(:return, Expr(:call, throw, ret.error)))), - @__MODULE__, false) - end - (diff_key, init_key) = ret - - if settings.mode in (DAE, DAENoInit, ODE, ODENoInit) - tearing_schedule!(tstate, ci, diff_key, world, settings) - end - if settings.mode in (InitUncompress, DAE, ODE) - tearing_schedule!(tstate, ci, init_key, world, settings) + if settings.skip_optimizations + diff_key = UnoptimizedKey() + else + ret = top_level_state_selection!(tstate) + if isa(ret, UncompilableIPOResult) + return Base.generated_body_to_codeinfo( + Expr(:lambda, Any[:var"#self", :settings, :f], Expr(:block, Expr(:return, Expr(:call, throw, ret.error)))), + @__MODULE__, false) + end + (diff_key, init_key) = ret + if settings.mode in (DAE, DAENoInit, ODE, ODENoInit) + tearing_schedule!(tstate, ci, diff_key, world, settings) + end + if settings.mode in (InitUncompress, DAE, ODE) + tearing_schedule!(tstate, ci, init_key, world, settings) + end end # Generate the IR implementation of `factory`, returning the DAEFunction/ODEFunction diff --git a/src/problem_interface.jl b/src/problem_interface.jl index f8ab5ab..6c4d1a5 100644 --- a/src/problem_interface.jl +++ b/src/problem_interface.jl @@ -25,8 +25,9 @@ function DAECProblem(f, init::Union{Vector, Tuple{Vararg{Pair}}}, tspan::Tuple{R force_inline_all=false, insert_stmt_debuginfo=false, insert_ssa_debuginfo=false, + skip_optimizations=false, kwargs...) 
- settings = Settings(; mode = DAENoInit, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) + settings = Settings(; mode = DAENoInit, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) DAECProblem(f, init, guesses, tspan, kwargs, settings, missing, nothing, nothing) end @@ -35,8 +36,9 @@ function DAECProblem(f, tspan::Tuple{Real, Real} = (0., 1.); force_inline_all=false, insert_stmt_debuginfo=false, insert_ssa_debuginfo=false, + skip_optimizations=false, kwargs...) - settings = Settings(; mode = DAE, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) + settings = Settings(; mode = DAE, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) DAECProblem(f, nothing, guesses, tspan, kwargs, settings, missing, nothing, nothing) end @@ -76,8 +78,9 @@ function ODECProblem(f, init::Union{Vector, Tuple{Vararg{Pair}}}, tspan::Tuple{R force_inline_all=false, insert_stmt_debuginfo=false, insert_ssa_debuginfo=false, + skip_optimizations=false, kwargs...) - settings = Settings(; mode = ODENoInit, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) + settings = Settings(; mode = ODENoInit, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) ODECProblem(f, init, guesses, tspan, kwargs, settings, missing, nothing) end @@ -86,8 +89,9 @@ function ODECProblem(f, tspan::Tuple{Real, Real} = (0., 1.); force_inline_all=false, insert_stmt_debuginfo=false, insert_ssa_debuginfo=false, + skip_optimizations=false, kwargs...) - settings = Settings(; mode = ODE, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) + settings = Settings(; mode = ODE, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) ODECProblem(f, nothing, guesses, tspan, kwargs, settings, missing, nothing) end diff --git a/src/reflection.jl b/src/reflection.jl index aa15b76..391c657 100644 --- a/src/reflection.jl +++ b/src/reflection.jl @@ -26,9 +26,9 @@ end code_ad_by_type(@nospecialize(tt::Type); kwargs...) = _code_ad_by_type(tt; kwargs...).inferred.ir -function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, matched = false, mode = DAE, force_inline_all = false, insert_stmt_debuginfo = false, insert_ssa_debuginfo = false, kwargs...) +function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_world_age(), result = false, matched = false, mode = DAE, force_inline_all = false, insert_stmt_debuginfo = false, insert_ssa_debuginfo = false, skip_optimizations = false, kwargs...) ci = _code_ad_by_type(tt; world, kwargs...) - settings = Settings(; mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) + settings = Settings(; mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) _result = structural_analysis!(ci, world, settings) isa(_result, UncompilableIPOResult) && throw(_result.error) !matched && return result ? 
_result : _result.ir @@ -36,7 +36,7 @@ function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_ structure = make_structure_from_ipo(result) - tstate = TransformationState(result, structure, copy(result.total_incidence)) + tstate = TransformationState(result, structure) err = StateSelection.check_consistency(tstate, nothing) err !== nothing && throw(err) diff --git a/src/settings.jl b/src/settings.jl index 62a0758..ab65159 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -14,5 +14,6 @@ struct Settings force_inline_all::Bool insert_stmt_debuginfo::Bool insert_ssa_debuginfo::Bool + skip_optimizations::Bool end -Settings(; mode::GenerationMode=DAE, force_inline_all::Bool=false, insert_stmt_debuginfo::Bool=false, insert_ssa_debuginfo::Bool=false) = Settings(mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo) +Settings(; mode::GenerationMode=DAE, force_inline_all::Bool=false, insert_stmt_debuginfo::Bool=false, insert_ssa_debuginfo::Bool=false, skip_optimizations::Bool = false) = Settings(mode, force_inline_all, insert_stmt_debuginfo, insert_ssa_debuginfo, skip_optimizations) diff --git a/src/transform/codegen/dae_factory.jl b/src/transform/codegen/dae_factory.jl index 1741c89..f53ed39 100644 --- a/src/transform/codegen/dae_factory.jl +++ b/src/transform/codegen/dae_factory.jl @@ -39,101 +39,51 @@ function make_daefunction(f, initf) DAEFunction(f; initialization_data = SciMLBase.OverrideInitData(NonlinearProblem((args...)->nothing, nothing, nothing), nothing, initf, nothing, nothing, Val{false}())) end -""" - dae_factory_gen(ci, key) - -Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction. -The resulting function is roughly: - -``` -function factory(settings, f) - # Run all parts of `f` that do not depend on state - state_invariant_pieces = f_state_invariant() - f! 
= %new_opaque_closure(f_rhs, state_invariant_pieces) - DAEFunction(f!), differential_vars +function continuous_variables(state::TransformationState) + filter(var -> varkind(state, var) == Intrinsics.Continuous, 1:length(state.result.var_to_diff)) end -``` - -""" -function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings, init_key::Union{TornCacheKey, Nothing}) - result = state.result - torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world) - torn_ir = torn_ci.inferred - - (;ir_sicm) = torn_ir - ir_factory = copy(ci.inferred.ir) - pushfirst!(ir_factory.argtypes, Settings) - pushfirst!(ir_factory.argtypes, typeof(factory)) - compact = IncrementalCompact(ir_factory) - - local line - if ir_sicm !== nothing - sicm_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world) - @assert sicm_ci !== nothing - - line = result.ir[SSAValue(1)][:line] - param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings) - sicm = @insert_instruction_here(compact, line, settings, invoke(param_list, sicm_ci)::Tuple) - else - sicm = () - end +const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64} - argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64} - - daef_ci = rhs_finish!(state, ci, key, world, settings, 1) - - # Create a small opaque closure to adapt from SciML ABI to our own internal - # ABI +function sciml_to_internal_abi!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, key::TornCacheKey, var_eq_matching, settings::Settings) + (; result, structure) = state numstates = zeros(Int, Int(LastEquationStateKind)) - - all_states = Int[] - for var = 1:length(result.var_to_diff) - varkind(state, var) == Intrinsics.Continuous || continue + for var in continuous_variables(state) kind = classify_var(result.var_to_diff, key, var) kind == nothing && continue numstates[kind] += 1 - (kind != AlgebraicDerivative) && push!(all_states, var) end - ir_oc = copy(ci.inferred.ir) - empty!(ir_oc.argtypes) - push!(ir_oc.argtypes, Tuple) - push!(ir_oc.argtypes, Vector{Float64}) - push!(ir_oc.argtypes, Vector{Float64}) - push!(ir_oc.argtypes, Vector{Float64}) - push!(ir_oc.argtypes, SciMLBase.NullParameters) - push!(ir_oc.argtypes, Float64) + empty!(ir.argtypes) + push!(ir.argtypes, Tuple) # opaque closure captures + append!(ir.argtypes, fieldtypes(SCIML_ABI)) - oc_compact = IncrementalCompact(ir_oc) + compact = IncrementalCompact(ir) # Zero the output - line = ir_oc[SSAValue(1)][:line] - @insert_instruction_here(oc_compact, line, settings, zero!(Argument(2))::VectorViewType) + line = ir[SSAValue(1)][:line] + @insert_instruction_here(compact, line, settings, zero!(Argument(2))::VectorViewType) # out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg nassgn = numstates[AssignedDiff] ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] - out_du_mm = @insert_instruction_here(oc_compact, line, settings, view(Argument(2), 1:nassgn)::VectorViewType) - out_eq = @insert_instruction_here(oc_compact, line, settings, view(Argument(2), (nassgn+1):ntotalstates)::VectorViewType) + out_du_mm = @insert_instruction_here(compact, line, settings, view(Argument(2), 1:nassgn)::VectorViewType) + out_eq = @insert_instruction_here(compact, line, settings, view(Argument(2), 
(nassgn+1):ntotalstates)::VectorViewType) - (in_du_assgn, in_du_unassgn) = sciml_dae_split_du!(oc_compact, line, settings, Argument(3), numstates) - (in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u!(oc_compact, line, settings, Argument(4), numstates) + (in_du_assgn, in_du_unassgn) = sciml_dae_split_du!(compact, line, settings, Argument(3), numstates) + (in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u!(compact, line, settings, Argument(4), numstates) # Call DAECompiler-generated RHS with internal ABI - oc_sicm = @insert_instruction_here(oc_compact, line, settings, getfield(Argument(1), 1)::Core.OpaqueClosure) + oc_sicm = @insert_instruction_here(compact, line, settings, getfield(Argument(1), 1)::Core.OpaqueClosure) # N.B: The ordering of arguments should match the ordering in the StateKind enum - @insert_instruction_here(oc_compact, line, settings, (:invoke)(daef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument(6))::Nothing) - - # TODO: We should not have to recompute this here - var_eq_matching = matching_for_key(state, key) - (slot_assignments, var_assignment, eq_assignment) = assign_slots(state, key, var_eq_matching) + @insert_instruction_here(compact, line, settings, (:invoke)(internal_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument(6))::Nothing) # Manually apply mass matrix and implicit equations between selected states - for v = 1:ndsts(state.structure.graph) - vdiff = state.structure.var_to_diff[v] + (_, var_assignment, _) = assign_slots(state, key, var_eq_matching) + for v = 1:ndsts(structure.graph) + vdiff = structure.var_to_diff[v] vdiff === nothing && continue if var_eq_matching[v] !== SelectedState() || var_eq_matching[vdiff] !== SelectedState() @@ -146,22 +96,85 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn @assert kind == AssignedDiff @assert dkind in (AssignedDiff, UnassignedDiff) - v_val = @insert_instruction_here(oc_compact, line, settings, getindex(dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot)::Any) - @insert_instruction_here(oc_compact, line, settings, setindex!(out_du_mm, v_val, slot)::Any) + v_val = @insert_instruction_here(compact, line, settings, getindex(dkind == AssignedDiff ? 
in_u_mm : in_u_unassgn, dslot)::Any) + @insert_instruction_here(compact, line, settings, setindex!(out_du_mm, v_val, slot)::Any) end - bc = @insert_instruction_here(oc_compact, line, settings, Base.Broadcast.broadcasted(-, out_du_mm, in_du_assgn)::Any) - @insert_instruction_here(oc_compact, line, settings, Base.Broadcast.materialize!(out_du_mm, bc)::Nothing) + bc = @insert_instruction_here(compact, line, settings, Base.Broadcast.broadcasted(-, out_du_mm, in_du_assgn)::Any) + @insert_instruction_here(compact, line, settings, Base.Broadcast.materialize!(out_du_mm, bc)::Nothing) # Return - @insert_instruction_here(oc_compact, line, settings, (return nothing)::Union{}) + @insert_instruction_here(compact, line, settings, (return nothing)::Union{}) + + ir = Compiler.finish(compact) + maybe_rewrite_debuginfo!(ir, settings) + resize!(ir.cfg.blocks, 1) + empty!(ir.cfg.blocks[1].succs) + Compiler.verify_ir(ir) + + @async @eval Main begin + interface_ir = $ir + end - ir_oc = Compiler.finish(oc_compact) - maybe_rewrite_debuginfo!(ir_oc, settings) - resize!(ir_oc.cfg.blocks, 1) - empty!(ir_oc.cfg.blocks[1].succs) - Compiler.verify_ir(ir_oc) - oc = Core.OpaqueClosure(ir_oc) + return Core.OpaqueClosure(ir; slotnames = [:captures, :out, :du, :u, :p, :t]) +end + +""" + dae_factory_gen(ci, key) + +Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction. +The resulting function is roughly: + +``` +function factory(settings, f) + # Run all parts of `f` that do not depend on state + state_invariant_pieces = f_state_invariant() + f! = %new_opaque_closure(f_rhs, state_invariant_pieces) + DAEFunction(f!), differential_vars +end +``` + +""" +function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Union{TornCacheKey, UnoptimizedKey}, world::UInt, settings::Settings, init_key::Union{TornCacheKey, Nothing}) + (; result, structure) = state + + ir_factory = copy(ci.inferred.ir) + pushfirst!(ir_factory.argtypes, Settings) + pushfirst!(ir_factory.argtypes, typeof(factory)) + compact = IncrementalCompact(ir_factory) + + # Create a small opaque closure to adapt from SciML ABI to our own internal ABI + argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64} + sicm = () + if settings.skip_optimizations + daef_ci = rhs_finish_noopt!(state, ci, key, world, settings; opaque_closure = true) + oc = sciml_to_internal_abi_noopt!(copy(ci.inferred.ir), state, daef_ci, settings) + else + # TODO: We should not have to recompute this here + var_eq_matching = matching_for_key(state, key) + + torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world) + torn_ir = torn_ci.inferred + + (; ir_sicm) = torn_ir + + local line + if ir_sicm !== nothing + sicm_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world) + @assert sicm_ci !== nothing + + line = result.ir[SSAValue(1)][:line] + callee_argtypes = ci.inferred.ir.argtypes + callee_argmap = ArgumentMap(callee_argtypes) + args = Argument.(2 .+ eachindex(callee_argtypes)) + new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings) + list = @insert_instruction_here(compact, line, settings, tuple(new_args...)::Tuple) + sicm = @insert_instruction_here(compact, line, settings, invoke(list, sicm_ci)::Tuple) + end + + daef_ci = rhs_finish!(state, ci, key, world, settings, 1) + oc = sciml_to_internal_abi!(copy(ci.inferred.ir), state, daef_ci, key, var_eq_matching, settings) + end line = 
result.ir[SSAValue(1)][:line] @@ -173,7 +186,16 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn new_oc = @insert_instruction_here(compact, line, settings, (:new_opaque_closure)(argt, Union{}, Nothing, true, oc_source_method, sicm)::Core.OpaqueClosure, true) - differential_states = Bool[v in key.diff_states for v in all_states] + if settings.skip_optimizations + differential_states = Bool[structure.var_to_diff[v] !== nothing for v in continuous_variables(state)] + else + all_states = filter(continuous_variables(state)) do var + kind = classify_var(result, key, var) + kind === nothing && return false + return kind !== AlgebraicDerivative + end + differential_states = Bool[v in key.diff_states for v in all_states] + end if init_key !== nothing initf = init_uncompress_gen!(compact, result, ci, init_key, key, world, settings) @@ -192,6 +214,6 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn empty!(ir_factory.cfg.blocks[1].succs) Compiler.verify_ir(ir_factory) - slotnames = [[:factory, :settings]; Symbol.(:arg, 1:(length(ir_factory.argtypes) - 2))] + slotnames = [:factory, :settings, :f] return ir_factory, slotnames end diff --git a/src/transform/codegen/init_factory.jl b/src/transform/codegen/init_factory.jl index 0b96a6a..996cb28 100644 --- a/src/transform/codegen/init_factory.jl +++ b/src/transform/codegen/init_factory.jl @@ -27,7 +27,11 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI @assert sicm_ci !== nothing line = result.ir[SSAValue(1)][:line] - param_list = flatten_parameter!(Compiler.fallback_lattice, compact, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings) + callee_argtypes = ci.inferred.ir.argtypes + callee_argmap = ArgumentMap(callee_argtypes) + args = Argument.(2 .+ eachindex(callee_argtypes)) + new_args = flatten_arguments_for_callee!(compact, callee_argmap, callee_argtypes, args, line, settings) + param_list = @insert_instruction_here(compact, line, settings, tuple(new_args...)::Tuple) sicm = @insert_instruction_here(compact, line, settings, invoke(param_list, sicm_ci)::Tuple) else sicm = () diff --git a/src/transform/codegen/init_uncompress.jl b/src/transform/codegen/init_uncompress.jl index 94acabc..cc36945 100644 --- a/src/transform/codegen/init_uncompress.jl +++ b/src/transform/codegen/init_uncompress.jl @@ -1,5 +1,5 @@ """ - struct RHSSpec + struct InitUncompressSpec Cache partition for the RHS """ @@ -11,7 +11,7 @@ end function gen_init_uncompress!(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, settings::Settings, ordinal::Int, indexT=Int) structure = make_structure_from_ipo(result) - tstate = TransformationState(result, structure, copy(result.total_incidence)) + tstate = TransformationState(result, structure) return gen_init_uncompress!(tstate, ci, init_key, diff_key, world, settings, ordinal, indexT) end @@ -44,8 +44,6 @@ function gen_init_uncompress!( cis = Vector{CodeInstance}() for (ir_ordinal, ir) in enumerate(torn.ir_seq) - ir = torn.ir_seq[ir_ordinal] - # Read in from the last level before any DAE or ODE-specific `ir_levels` # We assume this is named `tearing_schedule!` ir = copy(ir) diff --git a/src/transform/codegen/ode_factory.jl b/src/transform/codegen/ode_factory.jl index 450dd1c..f725db7 100644 --- a/src/transform/codegen/ode_factory.jl +++ b/src/transform/codegen/ode_factory.jl @@ -70,8 +70,12 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn 
@assert sicm_ci !== nothing line = result.ir[SSAValue(1)][:line] - param_list = flatten_parameter!(Compiler.fallback_lattice, returned_ic, ci.inferred.ir.argtypes[1:end], argn->Argument(2+argn), line, settings) - sicm_state = @insert_instruction_here(returned_ic, line, settings, (:call)(invoke, param_list, sicm_ci)::Tuple) + callee_argtypes = ci.inferred.ir.argtypes + callee_argmap = ArgumentMap(callee_argtypes) + args = Argument.(2 .+ eachindex(callee_argtypes)) + new_args = flatten_arguments_for_callee!(returned_ic, callee_argmap, callee_argtypes, args, line, settings) + param_list = @insert_instruction_here(returned_ic, line, settings, tuple(new_args...)::Tuple) + sicm_state = @insert_instruction_here(returned_ic, line, settings, invoke(param_list, sicm_ci)::Tuple) else sicm_state = () end diff --git a/src/transform/codegen/rhs.jl b/src/transform/codegen/rhs.jl index 0215d29..5195442 100644 --- a/src/transform/codegen/rhs.jl +++ b/src/transform/codegen/rhs.jl @@ -77,7 +77,7 @@ end function rhs_finish!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings, ordinal::Int, indexT=Int) structure = make_structure_from_ipo(result) - tstate = TransformationState(result, structure, copy(result.total_incidence)) + tstate = TransformationState(result, structure) return rhs_finish!(tstate, ci, key, world, settings, ordinal, indexT) end @@ -90,6 +90,7 @@ function rhs_finish!( ordinal::Int, indexT=Int) + @assert !settings.skip_optimizations (; result, structure) = state result_ci = find_matching_ci(ci->isa(ci.inferred, RHSSpec) && ci.inferred.key == key && ci.inferred.ordinal == ordinal, ci.def, world) if result_ci !== nothing @@ -109,8 +110,6 @@ function rhs_finish!( cis = Vector{CodeInstance}() for (ir_ordinal, ir) in enumerate(torn.ir_seq) - ir = torn.ir_seq[ir_ordinal] - # Read in from the last level before any DAE or ODE-specific `ir_levels` # We assume this is named `tearing_schedule!` ir = copy(ir) @@ -208,6 +207,7 @@ function rhs_finish!( @assert 1 <= Int(kind) <= Int(LastStateKind) which = Argument(arg_range[Int(kind)]) replace_call!(ir, SSAValue(i), Expr(:call, Base.getindex, which, slot), settings, @__SOURCE__) + inst[:type] = Float64 elseif is_known_invoke_or_call(stmt, InternalIntrinsics.contribution!, ir) eq = stmt.args[end-2]::Int kind = stmt.args[end-1]::EquationStateKind @@ -234,20 +234,28 @@ function rhs_finish!( end # Just before the end of the function - ir = Compiler.compact!(ir) + spec = RHSSpec(key, ir_ordinal) + daef_ci = rhs_finish_ir!(ir, ci, settings, spec, slotnames; single_block = true) + push!(cis, daef_ci) + end + + return cis[ordinal] +end + +function rhs_finish_ir!(ir::IRCode, ci::CodeInstance, settings::Settings, owner::Union{RHSSpec, UnoptimizedKey}, slotnames; single_block = false) + ir = Compiler.compact!(ir) + if single_block resize!(ir.cfg.blocks, 1) empty!(ir.cfg.blocks[1].succs) + end - widen_extra_info!(ir) - Compiler.verify_ir(ir) - src = ir_to_src(ir, settings; slotnames) - - abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64} - daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, RHSSpec(key, ir_ordinal)) - ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src) + widen_extra_info!(ir) + Compiler.verify_ir(ir) + src = ir_to_src(ir, settings; slotnames) - push!(cis, daef_ci) - end + abi = Tuple{ir.argtypes...} + daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, owner) + ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src) - return cis[ordinal] + return daef_ci end diff --git 
a/src/transform/common.jl b/src/transform/common.jl index 492d7aa..042a29e 100644 --- a/src/transform/common.jl +++ b/src/transform/common.jl @@ -116,7 +116,7 @@ function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, @nos ir[idx][:type] = Any ir[idx][:info] = Compiler.NoCallInfo() ir[idx][:flag] |= Compiler.IR_FLAG_REFINED - return new_call + return idx end function maybe_insert_debuginfo!(compact::IncrementalCompact, settings::Settings, source::LineNumberNode, previous = nothing, i = compact.result_idx) @@ -129,6 +129,7 @@ function maybe_insert_debuginfo!(debuginfo::DebugInfoStream, settings::Settings, end function insert_debuginfo!(debuginfo::DebugInfoStream, i::Integer, source::LineNumberNode, previous) + prev_edge_index = prev_edge_line = nothing if previous !== nothing && isa(previous, Tuple) prev_edge_index, prev_edge_line = previous[2], previous[3] end @@ -199,8 +200,8 @@ function replace_if_intrinsic!(compact, settings, ssa, du, u, p, t, var_assignme if var_assignment === nothing var_idx = 0 else - var_idx, in_du = var_assignment[var] - @assert !in_du || (du !== nothing) + kind, var_idx = var_assignment[var] + @assert kind !== AssignedDiff || du !== nothing end if iszero(var_idx) @@ -208,7 +209,7 @@ function replace_if_intrinsic!(compact, settings, ssa, du, u, p, t, var_assignme # but for some reason, wasn't deleted in any prior pass. inst[:inst] = GlobalRef(DAECompiler.Intrinsics, :_VARIABLE_UNASSIGNED) else - source = in_du ? du : u + source = kind === AssignedDiff ? du : u replace_call!(compact, ssa, Expr(:call, getindex, source, var_idx), settings, @__SOURCE__) end elseif is_known_invoke_or_call(stmt, sim_time, compact) diff --git a/src/transform/reconstruct.jl b/src/transform/reconstruct.jl new file mode 100644 index 0000000..72f10fd --- /dev/null +++ b/src/transform/reconstruct.jl @@ -0,0 +1,132 @@ +function expand_residuals(state::TransformationState, key::TornCacheKey, states, compressed, u, du, t) + (; result, structure) = state + expanded = Float64[] + i = 1 + var_eq_matching = matching_for_key(state, key) + for (eq, incidence) in enumerate(result.total_incidence) + uses_replacement_variable = any(x -> in(x.by + 1, rowvals(incidence.row)), result.replacement_map.variables) + if !is_const_plus_var_known_linear(incidence) || uses_replacement_variable + sign = infer_residual_sign(result, eq, var_eq_matching) + uses_replacement_variable && (sign *= -1) # XXX: this may be too simple + push!(expanded, sign * compressed[i]) + i += 1 + continue + end + + residual = 0.0 + for (coeff, var) in zip(nonzeros(incidence.row), rowvals(incidence.row)) + var -= 1 + if var == 0 + value = t + else + is_diff = is_differential_variable(structure, var) + source = ifelse(is_diff, du, u) + # XXX: that's probably incorrect but has done the correct thing so far + var === invview(var_eq_matching)[eq] && !is_diff && (i += 1) + state = states[var] + @assert state ≠ -1 "Reading from a state vector for a variable that has no corresponding state" + value = source[state] + end + residual += value * coeff + end + constant_term = incidence.typ.val::Float64 + push!(expanded, constant_term + residual) + end + return expanded +end + +function infer_residual_sign(result::DAEIPOResult, eq::Int, var_eq_matching) + # If a linear solved term appears with a positive coefficient, + # the residual will be taken as the negative of the value provided to `always!`.
+ # For example: ẋ₁ - x₁x₂ = 0 + # -ẋ₁ = -x₁x₂ + # ẋ₁ = -x₁x₂/-1 + # ẋ₁ = x₁x₂ + # 0 = x₁x₂ - ẋ₁ <-- residual + incidence = result.total_incidence[eq] + var = invview(var_eq_matching)[eq] + isa(var, Int) || return 1 + coeff = incidence.row[var + 1] + isa(coeff, Float64) || return -1 + return -sign(coeff) +end + +function is_differential_variable(structure::DAESystemStructure, var) + structure.var_to_diff[var] !== nothing && return false + return invview(structure.var_to_diff)[var] !== nothing && return true + @assert false +end + +function extract_removed_variables(state::TransformationState, key::TornCacheKey, torn::TornIR) + (; result, structure) = state + # TODO: handle multiple partitions + torn_ir = only(torn.ir_seq) + removed_vars = Int[] + for (i, inst) in enumerate(torn_ir.stmts) + stmt = inst[:stmt] + is_solved_variable(stmt) || continue + var = stmt.args[2]::Int + vint = invview(structure.var_to_diff)[var] + vint === nothing || key.diff_states === nothing || !in(vint, key.diff_states) || continue + push!(removed_vars, var) + end + return removed_vars +end + +""" + compute_residual_vectors(f, u, du; t = 1.0) + +Compute residual vectors using both the optimized and the unoptimized code generation paths. +For a consistent `u` and `du` pair (in particular with respect +to the equations defining state derivatives), the two residual vectors should be equal. +If they are not, this may indicate a bug in the code generation process and should be addressed. + +If a state derivative is used in more than one equation, `u` and `du` must +be provided such that the selected equation determining this derivative +holds; otherwise, residuals for equations involving the value of this state +derivative may differ between the unoptimized and optimized versions. +""" +function compute_residual_vectors(f, u, du; t = 1.0, mode=DAE, world=Base.tls_world_age()) + @assert mode === DAE # TODO: support ODEs + settings = Settings(; mode, insert_stmt_debuginfo = true) + tt = Base.signature_type(f, ()) + ci = _code_ad_by_type(tt; world) + result = @code_structure result=true mode=settings.mode insert_stmt_debuginfo=settings.insert_stmt_debuginfo world=world f() + structure = make_structure_from_ipo(result) + state = TransformationState(result, structure) + key, _ = top_level_state_selection!(state) + tearing_schedule!(state, ci, key, world, settings) + torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world) + torn_ir = torn_ci.inferred + + our_prob = DAECProblem(f, (1,) .=> 1.; settings.insert_stmt_debuginfo) + sciml_prob = DiffEqBase.get_concrete_problem(our_prob, true) + f_compressed! = sciml_prob.f.f + + our_prob = DAECProblem(f, (1,) .=> 1.; settings.insert_stmt_debuginfo, skip_optimizations = true) + sciml_prob = DiffEqBase.get_concrete_problem(our_prob, true) + f_original!
= sciml_prob.f.f + + residuals = zeros(length(u)) + p = SciMLBase.NullParameters() + states = map_variables_to_states(state) + removed_variables = extract_removed_variables(state, key, torn_ir) + removed_states = filter(≠(-1), states[removed_variables]) + compressed_states = filter(x -> !in(x, removed_states) && x ≠ -1, states) + state_compression = unique(compressed_states) + u_compressed = u[state_compression] + du_compressed = du[state_compression] + + n = length(state.result.eqkinds) + residuals_compressed = zeros(n) + f_compressed!(residuals_compressed, du_compressed, u_compressed, p, t) + f_original!(residuals, du, u, p, t) + + expanded = expand_residuals(state, key, states, residuals_compressed, u, du, t) + @assert issorted(result.replacement_map.variables, by = x -> x.equation) + for (; equation) in result.replacement_map.variables + insert!(residuals, equation, 0.0) + end + + return residuals, expanded +end diff --git a/src/transform/state_selection.jl b/src/transform/state_selection.jl index e35b1d6..70a9cd6 100644 --- a/src/transform/state_selection.jl +++ b/src/transform/state_selection.jl @@ -6,6 +6,8 @@ struct TransformationState <: StateSelection.TransformationState{DAEIPOResult} structure::DAESystemStructure total_incidence::Vector{Incidence} end +TransformationState(result::DAEIPOResult, structure::DAESystemStructure) = + TransformationState(result, structure, copy(result.total_incidence)) function StateSelection.linear_subsys_adjmat!(state::TransformationState) graph = state.structure.graph @@ -165,7 +167,7 @@ function Base.show(io::IO, (; callees)::CalleeInfo) end end -function top_level_state_selection!(tstate) +function top_level_state_selection!(tstate::TransformationState) (; result, structure) = tstate # For the top-level problem, all external vars are state-invariant, and we do no other fissioning diff --git a/src/transform/tearing/schedule.jl b/src/transform/tearing/schedule.jl index d9c3a46..e30a327 100644 --- a/src/transform/tearing/schedule.jl +++ b/src/transform/tearing/schedule.jl @@ -333,8 +333,10 @@ function compute_eq_schedule(key::TornCacheKey, total_incidence, result, mss::St for i = 1:length(callee_info.result.total_incidence) i in previously_scheduled_or_ignored && continue # We scheduled this previously i in this_callee_eqs && continue # We already scheduled this + # Skip equations that the callee defines but does not apply.
+ !isassigned(callee_info.result.total_incidence, i) && continue callee_incidence = callee_info.result.total_incidence[i] - incidence = apply_linear_incidence(nothing, callee_incidence, nothing, callee_info.mapping) + incidence = apply_linear_incidence!(callee_info.mapping, nothing, callee_incidence, nothing) if is_const_plus_var_known_linear(incidence) # No non-linear components - skip it push!(previously_scheduled_or_ignored, i) @@ -498,6 +500,8 @@ function invert_eq_callee_mapping(eq_callee_mapping) return callee_eq_mapping end +classify_var(structure::DAESystemStructure, key::TornCacheKey, var) = classify_var(structure.var_to_diff, key, var) +classify_var(result::DAEIPOResult, key::TornCacheKey, var) = classify_var(result.var_to_diff, key, var) function classify_var(var_to_diff, key::TornCacheKey, var) if var in key.alg_states vint = invview(var_to_diff)[var] @@ -652,7 +656,7 @@ end function tearing_schedule!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, settings::Settings) structure = make_structure_from_ipo(result) - tstate = TransformationState(result, structure, copy(result.total_incidence)) + tstate = TransformationState(result, structure) return tearing_schedule!(tstate, ci, key, world, settings) end @@ -839,7 +843,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To if !any(out->callee_eq in out[2], callee_var_schedule) display(mss) cstructure = make_structure_from_ipo(callee_result) - cvar_eq_matching = matching_for_key(callee_result, callee_key, cstructure) + cvar_eq_matching = matching_for_key(callee_result, callee_key) display(StateSelection.MatchedSystemStructure(callee_result, cstructure, cvar_eq_matching)) @sshow eq_orders @sshow callee_result.total_incidence[callee_eq] @@ -854,7 +858,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To vars_for_final = BitSet() # TODO: This is a very expensive way to compute this - we should be able to do this cheaper cstructure = make_structure_from_ipo(callee_result) - tstate = TransformationState(callee_result, cstructure, copy(callee_result.total_incidence)) + tstate = TransformationState(callee_result, cstructure) cvar_eq_matching = matching_for_key(tstate, callee_key) for callee_var in 1:length(cvar_eq_matching) if cvar_eq_matching[callee_var] !== unassigned @@ -976,7 +980,7 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To @sshow lin_var @sshow ordinal @sshow eq_order - display(result.ir) + @sshow result.ir error("Tried to schedule variable $(lin_var) that we do not have a solution to (but our scheduling should have ensured that we do)") end var_sols[lin_var] = CarriedSSAValue(ordinal, (@insert_instruction_here(compact1, line, settings, (:invoke)(nothing, Intrinsics.variable)::Incidence(lin_var)).id)) diff --git a/src/transform/unoptimized.jl b/src/transform/unoptimized.jl new file mode 100644 index 0000000..dc6addf --- /dev/null +++ b/src/transform/unoptimized.jl @@ -0,0 +1,182 @@ +struct UnoptimizedKey end + +function Base.StackTraces.show_custom_spec_sig(io::IO, owner::UnoptimizedKey, linfo::CodeInstance, frame::Base.StackTraces.StackFrame) + print(io, "Unoptimized transformed IR for ") + mi = Base.get_ci_mi(linfo) + return Base.StackTraces.show_spec_sig(io, mi.def, mi.specTypes) +end + +function rhs_finish_noopt!( + state::TransformationState, + ci::CodeInstance, + key::UnoptimizedKey, + world::UInt, + settings::Settings, + equation_to_residual_mapping = 1:length(state.structure.eq_to_diff), + 
variable_to_state_mapping = map_variables_to_states(state); + opaque_closure) + + (; result, structure) = state + result_ci = find_matching_ci(ci -> ci.owner === key, ci.def, world) + if result_ci !== nothing + return result_ci + end + + ir = copy(result.ir) + src = ci.inferred::AnalyzedSource + argrange = 1:src.nargs + # Original arguments. + slotnames = src.slotnames[argrange] + argtypes = remove_variable_and_equation_annotations(ir.argtypes) + if opaque_closure + slotnames[1] = :captures + argtypes[1] = Tuple + end + # Additional ABI arguments. + push!(slotnames, :out, :du, :u, :residuals, :states, :t) + push!(argtypes, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Int}, Vector{Int}, Float64) + + @assert length(slotnames) == length(argtypes) + append!(empty!(ir.argtypes), argtypes) + captures, args..., out, du, u, residuals, states, t = Argument.(eachindex(slotnames)) + + equations = Pair{Union{Argument, SSAValue}, Eq}[] + for arg in args + index = [arg.n] + i = findfirst(==(index), result.argmap.equations) + i !== nothing && push!(equations, arg => Eq(i)) + end + callee_to_caller_eq_map = invert_eq_callee_mapping(result.eq_callee_mapping) + compact = IncrementalCompact(ir) + + for ((old, i), _) in compact + ssaidx = SSAValue(i) + inst = compact[ssaidx] + stmt = inst[:stmt] + type = inst[:type] + line = inst[:line] + + if i == 1 + # @insert_instruction_here(compact, nothing, settings, println("Residuals: ", residuals)::Any) + # @insert_instruction_here(compact, nothing, settings, println("States: ", states)::Any) + end + + if is_known_invoke_or_call(stmt, Intrinsics.variable, compact) + var = idnum(type) + index = @insert_instruction_here(compact, line, settings, getindex(states, var)::Int) + value = @insert_instruction_here(compact, line, settings, getindex(u, index)::Float64) + replace_uses!(compact, (old, inst) => value) + # @insert_instruction_here(compact, line, settings, println("Variable (", var, "): ", value)::Float64) + elseif is_known_invoke(stmt, Intrinsics.ddt, compact) + var = idnum(type) + index = @insert_instruction_here(compact, line, settings, getindex(states, var)::Int) + value = @insert_instruction_here(compact, line, settings, getindex(du, index)::Float64) + replace_uses!(compact, (old, inst) => value) + # @insert_instruction_here(compact, line, settings, println("Variable derivative (", var, " := ", invview(structure.var_to_diff)[var], "โ€ฒ): ", value)::Any) + elseif is_known_invoke(stmt, Intrinsics.equation, compact) + # This is already done for each encountered `Eq` type. 
+ # push!(equations, ssaidx => type::Eq) + elseif is_equation_call(stmt, compact) + callee, value = stmt.args[2], stmt.args[3] + i = findfirst(x -> first(x) == callee, equations)::Int + eq = last(equations[i]) + index = @insert_instruction_here(compact, line, settings, getindex(residuals, eq.id)::Int) + previous = @insert_instruction_here(compact, line, settings, getindex(out, index)::Float64) + accumulated = @insert_instruction_here(compact, line, settings, +(previous, value)::Float64) + ret = @insert_instruction_here(compact, line, settings, setindex!(out, accumulated, index)::Any) + replace_uses!(compact, (old, inst) => ret) + # @insert_instruction_here(compact, line, settings, println("Residuals (index = ", index, ", value = ", value, "): ", residuals)::Any) + elseif is_known_invoke_or_call(stmt, Intrinsics.sim_time, compact) + inst[:stmt] = t + elseif is_known_invoke_or_call(stmt, Intrinsics.epsilon, compact) + inst[:stmt] = 0.0 + elseif isexpr(stmt, :invoke) + info = inst[:info]::MappingInfo + callee_ci, args = stmt.args[1]::CodeInstance, @view stmt.args[2:end] + callee_result = structural_analysis!(callee_ci, world, settings) + callee_structure = make_structure_from_ipo(callee_result) + callee_state = TransformationState(callee_result, callee_structure) + + caller_eqs = get(Vector{Int}, callee_to_caller_eq_map, StructuralSSARef(old)) + callee_residuals = equation_to_residual_mapping[caller_eqs] + caller_variables = map(info.mapping.var_coeffs) do coeff + isa(coeff, Incidence) || return -1 + nnz(coeff.row) == 1 || return -1 + idnum(coeff) + end + callee_states = [get(variable_to_state_mapping, i, -1) for i in caller_variables] + + callee_daef_ci = rhs_finish_noopt!(callee_state, callee_ci, UnoptimizedKey(), world, settings, callee_residuals, callee_states; opaque_closure = false) + call = @insert_instruction_here(compact, line, settings, (:invoke)(callee_daef_ci, args..., + out, + du, + u, + @insert_instruction_here(compact, line, settings, Base.vect(callee_residuals...)::Vector{Int}), + @insert_instruction_here(compact, line, settings, Base.vect(callee_states...)::Vector{Int}), + t)::type) + replace_uses!(compact, (old, inst) => call) + isa(type, Eq) && push!(equations, call => type) + end + + type = inst[:type] + isa(type, Eq) && push!(equations, ssaidx => type) + if isa(type, Incidence) || isa(type, Eq) + inst[:type] = widenconst(type) + end + end + + daef_ci = rhs_finish_ir!(Compiler.finish(compact), ci, settings, key, slotnames) + # @sshow daef_ci.inferred + return daef_ci +end + +function map_variables_to_states(state::TransformationState) + (; result, structure) = state + diff_to_var = invview(structure.var_to_diff) + states = Int[] + prev_state = 0 + for var in continuous_variables(state) + if any(repl -> repl.by == var, result.replacement_map.variables) + # This is a replacement variable, skip it. + push!(states, -1) + continue + end + ref = is_differential_variable(structure, var) ? 
diff_to_var[var] : var + state = @something(get(states, ref, nothing), prev_state += 1) + push!(states, state) + end + return states +end + +function replace_uses!(compact, ((old, inst), new)) + inst[:stmt] = nothing + compact.ssa_rename[old] = new +end + +function sciml_to_internal_abi_noopt!(ir::IRCode, state::TransformationState, internal_ci::CodeInstance, settings::Settings) + slotnames = [:captures, :out, :du, :u, :p, :t] + captures, out, du, u, p, t = Argument.(eachindex(slotnames)) + + empty!(ir.argtypes) + push!(ir.argtypes, Tuple) # opaque closure captures + append!(ir.argtypes, fieldtypes(SCIML_ABI)) + + compact = IncrementalCompact(ir) + line = ir[SSAValue(1)][:line] + + internal_oc = @insert_instruction_here(compact, line, settings, getfield(captures, 1)::Core.OpaqueClosure) + neqs = length(state.structure.eq_to_diff) + nvars = length(state.structure.var_to_diff) + residuals = @insert_instruction_here(compact, line, settings, getindex(Int, 1:neqs...)::Vector{Int}) + states = @insert_instruction_here(compact, line, settings, getindex(Int, map_variables_to_states(state)...)::Vector{Int}) + @insert_instruction_here(compact, line, settings, (:invoke)(internal_ci, internal_oc, out, du, u, residuals, states, t)::Nothing) + @insert_instruction_here(compact, line, settings, (return nothing)::Union{}) + + ir = Compiler.finish(compact) + maybe_rewrite_debuginfo!(ir, settings) + resize!(ir.cfg.blocks, 1) + empty!(ir.cfg.blocks[1].succs) + Compiler.verify_ir(ir) + + return Core.OpaqueClosure(ir; slotnames) +end diff --git a/src/utils.jl b/src/utils.jl index 03d0385..c5a095d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -162,15 +162,19 @@ end @sshow stmt @sshow length(ir.stmts) typeof(val) -Drop-in replacement for `@show`, but using `jl_safe_printf` to avoid task switches. +Drop-in replacement for `@show`, but using `Core.println` to avoid task switches. This directly prints to C stdout; `stdout` redirects won't have any effect. """ macro sshow(exs...) 
blk = Expr(:block) for ex in exs - push!(blk.args, :(Core.println($(sprint(Base.show_unquoted,ex)*" = "), - repr(begin local value = $(esc(ex)) end)))) + push!(blk.args, quote + value = $(esc(ex)) + Core.print($(sprint(Base.show_unquoted, ex))) + Core.print(" = ") + Core.println(sprint(print, value, context = :color => true)) + end) end isempty(exs) || push!(blk.args, :value) return blk diff --git a/test/benchmark.jl b/test/benchmark.jl new file mode 100644 index 0000000..a0110af --- /dev/null +++ b/test/benchmark.jl @@ -0,0 +1,27 @@ +module _benchmark + +using DAECompiler +using DAECompiler: compute_residual_vectors +using SciMLBase, Sundials +using Test + +include("../benchmark/thermalfluid.jl") + +@testset "Validation" begin + Benchmark{3}()() + + u = zeros(68) + du = zeros(68) + residuals, expanded_residuals = compute_residual_vectors(Benchmark{3}(), u, du) + @test length(residuals) == length(expanded_residuals) + @test_broken residuals ≈ expanded_residuals + # indices = findall(i -> residuals[i] ≉ expanded_residuals[i], eachindex(residuals)) + # residuals[indices] + # expanded_residuals[indices] +end + +let sol = solve(DAECProblem(Benchmark{3}(), [1:9;] .=> 0.), IDA()) + @test_broken sol.retcode == ReturnCode.Success +end + +end # module _benchmark diff --git a/test/runtests.jl b/test/runtests.jl index 954161c..e1f1caf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,8 +7,9 @@ include("ssrm.jl") include("regression.jl") include("errors.jl") include("invalidation.jl") +include("validation.jl") using Pkg Pkg.activate(joinpath(dirname(@__DIR__), "benchmark")) do - include("../benchmark/thermalfluid.jl") + include("benchmark.jl") end diff --git a/test/validation.jl b/test/validation.jl new file mode 100644 index 0000000..ec88925 --- /dev/null +++ b/test/validation.jl @@ -0,0 +1,239 @@ +module _validation + +using DAECompiler +using DAECompiler: refresh, compute_residual_vectors + +using Test +using DAECompiler +using DAECompiler.Intrinsics + +const *ᵢ = Core.Intrinsics.mul_float +const +ᵢ = Core.Intrinsics.add_float +const -ᵢ = Core.Intrinsics.sub_float + +@noinline function onecall!() + x = continuous() + always!(ddt(x) - x) +end + +function multiple_linear_equations!() + x₁ = continuous() # selected + x₂ = continuous() # selected + x₃ = continuous() # algebraic, optimized away + x₄ = continuous() # algebraic + always!(ddt(x₁) -ᵢ x₁ *ᵢ x₂) + always!(ddt(x₂) -ᵢ 3.0) + always!(x₃ -ᵢ x₁) # optimized away, not part of the DAE problem. + always!(x₄ *ᵢ x₄ -ᵢ ddt(x₁)) +end + +@noinline function sin!() + x = continuous() + always!(ddt(x) - sin(x)) +end + +@noinline function neg_sin!() + x = continuous() + always!(sin(x) - ddt(x)) +end + +function twocall!() + onecall!(); onecall!(); + return nothing +end + +function sin2!() + sin!(); sin!(); + return nothing +end + +@noinline new_equation() = always() +function external_equation!() + x = continuous() + eq = new_equation() + eq(ddt(x) - 1.0) +end + +@noinline function flattening_inner!((a, b)) + x = continuous() + always!(ddt(x) - (a + b)) +end +function flattening!() + x = continuous() + flattening_inner!((1.0, x)) + always!(ddt(x) - 2.0) +end + +@noinline apply_equation(eq, residual) = eq(residual) +function equation_argument!() + x = continuous() + equation = new_equation() + apply_equation(equation, x - ddt(x)) +end + +@noinline new_equation_and_variable() = (always(), continuous()) +@noinline apply_equation_on_ddtx_minus_one(eq, x) = apply_equation(eq, ddt(x) - 1.)
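# Illustrative worked example for the `multiple_linear_equations!` helper above,
# assuming the states are ordered x₁..x₄ as declared and `du` holds the matching
# derivatives. The four residuals are the values passed to `always!`:
#   r₁ = ẋ₁ - x₁*x₂,  r₂ = ẋ₂ - 3,  r₃ = x₃ - x₁,  r₄ = x₄*x₄ - ẋ₁
# With u = [3, 1, 100, 4] and du = [3, 0, 0, 0] this gives [0, -3, 97, 13],
# matching the expectations in the "Validation" testset below.
# A hand-written sketch of the same arithmetic (hypothetical helper, not used by the tests):
manual_linear_residuals(u, du) =
    [du[1] - u[1]*u[2], du[2] - 3.0, u[3] - u[1], u[4]*u[4] - du[1]]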
+function equation_used_multiple_times!() + (eq, x) = new_equation_and_variable() + apply_equation_on_ddtx_minus_one(eq, x) + apply_equation_on_ddtx_minus_one(eq, x) +end + +function equation_with_callable!() + x = continuous() + callable = @noinline Returns(x) + always!(ddt(callable()) - 3.0) +end + +@noinline apply_equation!(lhs, rhs) = always!(lhs - rhs) +function nonlinear_argument!() + x = continuous() + apply_equation!(ddt(x), sin(x)) +end + +struct WithParameter{N} end +@noinline (::WithParameter{N})(eq, x) where {N} = eq(ddt(x) - N) +function callable_with_type_parameter!() + eq, x = new_equation_and_variable() + WithParameter{3}()(eq, x) +end + +@noinline nonlinear_operation() = tanh(ddt(continuous())) +function nonlinear_replacement!() + result = nonlinear_operation() + always!(result - 0.5) +end + +@noinline function nested_nonlinear_operations(x, y) + z = continuous() + eq = always() + a = sin(y) + b = exp(ddt(z)) + c = cosh(ddt(x) + y + z) + (a, ((b, c), eq)) +end +function nonlinear_replacement_nested!() + x = continuous() + (sy, ((eż, cẋyz), eq)) = nested_nonlinear_operations(x, 0.5) + always!(sy + eż) + eq(cẋyz) +end + +@noinline nonlinear_operation(x) = sin(x) - ddt(continuous()) +function external_derivative_nonlinear!() + x = continuous() + always!(x - 0.5) + always!(nonlinear_operation(ddt(x))) +end + +@testset "Validation" begin + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(onecall!, u, du) + @test residuals ≈ [1.0] + @test residuals ≈ expanded_residuals + + u = [3.0, 1.0, 100.0, 4.0] + du = [3.0, 0.0, 0.0, 0.0] + residuals, expanded_residuals = compute_residual_vectors(multiple_linear_equations!, u, du) + @test residuals ≈ [0.0, -3.0, 97.0, 13.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(sin!, u, du) + @test residuals ≈ du .- sin.(u) + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(neg_sin!, u, du) + @test residuals ≈ sin.(u) .- du + @test residuals ≈ expanded_residuals + + # IPO + + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(() -> onecall!(), u, du) + @test residuals ≈ [1.0] + @test residuals ≈ expanded_residuals + + u = [2.0, 4.0] + du = [3.0, 7.0] + residuals, expanded_residuals = compute_residual_vectors(twocall!, u, du) + @test residuals ≈ [1.0, 3.0] + @test residuals ≈ expanded_residuals + + u = [2.0, 4.0] + du = [1.0, 1.0] + residuals, expanded_residuals = compute_residual_vectors(sin2!, u, du) + @test all(>(0), residuals) + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [4.0] + residuals, expanded_residuals = compute_residual_vectors(nonlinear_argument!, u, du) + @test residuals ≈ du .- sin.(u) + @test residuals ≈ expanded_residuals + + u = [0.0] + du = [2.0] + residuals, expanded_residuals = compute_residual_vectors(external_equation!, u, du) + @test residuals ≈ [1.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [1.0] + residuals, expanded_residuals = compute_residual_vectors(equation_argument!, u, du) + @test residuals ≈ [1.0] + @test residuals ≈ expanded_residuals + + u = [2.0, 4.0] + du = [1.0, 1.0] + residuals, expanded_residuals = compute_residual_vectors(flattening!, u, du) + @test residuals ≈ [-1.0, -2.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [4.0] + residuals, expanded_residuals =
compute_residual_vectors(equation_used_multiple_times!, u, du) + @test residuals ≈ [6.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [4.0] + residuals, expanded_residuals = compute_residual_vectors(equation_with_callable!, u, du) + @test residuals ≈ [1.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(callable_with_type_parameter!, u, du) + @test residuals ≈ [0.0] + @test residuals ≈ expanded_residuals + + u = [2.0] + du = [3.0] + residuals, expanded_residuals = compute_residual_vectors(nonlinear_replacement!, u, du) + @test residuals ≈ [0.49505475368673046, 0.0] + @test residuals ≈ expanded_residuals + + u = [2.0, 6.0] + du = [3.0, -1.0] + residuals, expanded_residuals = compute_residual_vectors(nonlinear_replacement_nested!, u, du) + @test residuals ≈ expanded_residuals + + u = [2.0, 4.0] + du = [3.0, 5.0] + # XXX: Fix GlobalRef handling in Diffractor's forward AD pass first. + # ERROR: UndefVarError: `pos` not defined in `Diffractor` + # XXX: To pass this test we'll need to map one of the callee variables to a state + # differential (`du`), while currently we only map to states and we assume the derivative index matches. + @test_skip begin + residuals, expanded_residuals = compute_residual_vectors(external_derivative_nonlinear!, u, du) + @test residuals ≈ expanded_residuals + end +end; + +end # module _validation
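For reference, a minimal sketch of how the `compute_residual_vectors` entry point introduced above is meant to be driven, mirroring the `onecall!` case from `test/validation.jl`; the expected values come from that testset, and the imports follow the same pattern as the test file:

```
using DAECompiler
using DAECompiler: compute_residual_vectors
using DAECompiler.Intrinsics

# ẋ - x = 0: one differential equation with a single state.
@noinline function onecall!()
    x = continuous()
    always!(ddt(x) - x)
end

# With u = [2.0] and du = [3.0] the only residual is du[1] - u[1] == 1.0, and the
# optimized and unoptimized code generation paths are expected to agree on it.
residuals, expanded = compute_residual_vectors(onecall!, [2.0], [3.0])
@assert residuals ≈ [1.0]
@assert residuals ≈ expanded
```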