Skip to content

Commit 924e8c4

Browse files
authored
Implement rewrite validator (#55)
* Misc refactors, insert internal stmt debuginfo * Add slotnames info * Refactor naming of locals for ODE codegen * Add comment * Apply DebugInfo at top-level * Remove `insert_ssa_debuginfo` setting * Marshall settings everywhere * Add maybe_insert_debuginfo * Disable crashing tests * Propagate source information for `insert_node_here!(...)` * Add source provenance to DAE/Init codegen * Use `__SOURCE__` macro for `replace_call!` * Update `replace_call!` with source information in index_lowering.jl * Default `insert_stmt_debuginfo` to `false` for reflection tools * Remove accidental code inclusion * Don't seek previous codeloc * Remove unused code * Reenable IPO tests * [DO NOT MERGE] Temporarily dev ConstructionBase for CI * Work around `invokelatest` issue * Only wrap `string` call in `try`/`catch` * Add `insert_ssa_debuginfo` setting * Minor fix * Refactor `insert_node_here` macro * Use insert_instruction! in more places * Rename `insert_node_here` macro to `insert_instruction` * `insert_instruction` -> `insert_instruction_here` Not to confuse it with the nonlocal insertion * Fixes * Undev ConstructionBase (and update) * [WIP] Add unoptimized compilation path * Use UnoptimizedKey singleton for caching unoptimized RHS * Refactor settings construction for DAE/ODE problems * Write `expand_residuals` * Use :call form for macrocall * Use :call for added macrocalls * Fix introduced codegen bug * [WIP] Add basic infrastructure for testing * Reconstruct variables from states * Separate optimized and unoptimized structural caches * Add incidence-copying constructor for TransformationState * Minor refactor `rhs_ir_finish!` * Skip flattening in unoptimized context, WIP on IPO support * Correctly compute state/equation indices for IPO This includes a refactor that considers mapping to states, instead of individual mappings into u/du vectors * Adjust residual sign based on solved variable coefficient * Rewrite and support flattening, improve external equation support * Fix flattening bug * Fix another flattening bug * Rename testing function * Support type constructors for IPO * Split benchmark definition and test * Remove process_template! * Fix handling of functors, replace flatten_parameter! * Fix bug in residual expansion * Mark `apply_linear_incidence` as mutating * Correctly handle nonlinear replacements * Add (broken) validation test for thermalfluid benchmark * Wrap test files in modules * Remove unnecessary import * Remove temp package from project * Remove `refresh()` in tests
1 parent 54c8ec0 commit 924e8c4

29 files changed

+1128
-302
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

benchmark/main.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Test
2+
using SciMLBase, Sundials
3+
4+
include("thermalfluid.jl")
5+
6+
Benchmark{3}()()
7+
@test isa(code_lowered(DAECompiler.factory, Tuple{Val{DAECompiler.Settings(mode=DAECompiler.DAENoInit)}, Benchmark{3}})[1], Core.CodeInfo)
8+
let sol = solve(DAECProblem(Benchmark{3}(), [1:9;] .=> 0.), IDA())
9+
@test_broken sol.retcode == ReturnCode.Success
10+
end

benchmark/thermalfluid.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ using DAECompiler
22
using DAECompiler.Intrinsics
33
using Polynomials: fit
44
using XSteam: my_pT, rho_pT, Cp_pT, tc_pT
5-
using SciMLBase, Sundials
6-
using Test
75

86
struct Sink; end
97
function (::Sink)(inlet)
@@ -282,9 +280,3 @@ function (::Benchmark{N})() where {N}
282280
PreinsulatedPipe{N}()(in[2], out[1])
283281
Sink()(out[2])
284282
end
285-
286-
Benchmark{3}()()
287-
@test isa(code_lowered(DAECompiler.factory, Tuple{Val{DAECompiler.Settings(mode=DAECompiler.DAENoInit)}, Benchmark{3}})[1], Core.CodeInfo)
288-
let sol = solve(DAECProblem(Benchmark{3}(), [1:9;] .=> 0.), IDA())
289-
@test_broken sol.retcode == ReturnCode.Success
290-
end

src/DAECompiler.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module DAECompiler
66
using Diffractor
77
using OrderedCollections
88
using Compiler
9-
using Compiler: IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst
9+
using Compiler: AbstractLattice, IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst
1010
using Core.IR
1111
using SciMLBase
1212
using AutoHashEquals
@@ -16,18 +16,21 @@ module DAECompiler
1616
include("settings.jl")
1717
include("utils.jl")
1818
include("intrinsics.jl")
19+
include("reflection.jl")
1920
include("analysis/utils.jl")
2021
include("analysis/lattice.jl")
2122
include("analysis/ADAnalyzer.jl")
2223
include("analysis/scopes.jl")
24+
include("analysis/flattening.jl")
2325
include("analysis/cache.jl")
2426
include("analysis/refiner.jl")
2527
include("analysis/ipoincidence.jl")
2628
include("analysis/structural.jl")
27-
include("analysis/flattening.jl")
2829
include("transform/state_selection.jl")
2930
include("transform/common.jl")
3031
include("transform/runtime.jl")
32+
include("transform/unoptimized.jl")
33+
include("transform/reconstruct.jl")
3134
include("transform/tearing/schedule.jl")
3235
include("transform/codegen/dae_factory.jl")
3336
include("transform/codegen/ode_factory.jl")
@@ -40,5 +43,4 @@ module DAECompiler
4043
include("analysis/consistency.jl")
4144
include("interface.jl")
4245
include("problem_interface.jl")
43-
include("reflection.jl")
4446
end

src/analysis/ADAnalyzer.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ end
6262

6363
struct AnalyzedSource
6464
ir::Compiler.IRCode
65+
slotnames::Vector{Any}
6566
inline_cost::Compiler.InlineCostType
6667
nargs::UInt
6768
isva::Bool
@@ -72,19 +73,26 @@ end
7273
Core.svec(edges..., interp.edges...)
7374
end
7475

76+
function get_slotnames(def::Method)
77+
names = split(def.slot_syms, '\0')
78+
return map(Symbol, names)
79+
end
80+
7581
@override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
7682
ir = result.src.optresult.ir
83+
slotnames = get_slotnames(result.linfo.def)
7784
params = Compiler.OptimizationParams(interp)
78-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
85+
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
7986
end
8087

8188
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
8289
if Compiler.result_is_constabi(interp, result)
8390
return nothing
8491
end
8592
ir = result.src.optresult.ir
93+
slotnames = get_slotnames(result.linfo.def)
8694
params = Compiler.OptimizationParams(interp)
87-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
95+
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
8896
end
8997

9098
function Compiler.retrieve_ir_for_inlining(ci::CodeInstance, result::AnalyzedSource)

src/analysis/cache.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ end
2121
var_schedule::Vector{Pair{BitSet, BitSet}}
2222
end
2323

24+
struct VariableReplacement
25+
replaced::Incidence
26+
by::Int
27+
equation::Int
28+
end
29+
30+
mutable struct NonlinearReplacementMap
31+
const variables::Vector{VariableReplacement}
32+
variable_counter::Int
33+
equation_counter::Int
34+
end
35+
NonlinearReplacementMap() = NonlinearReplacementMap(VariableReplacement[], 0, 0)
36+
2437
"""
2538
StructuralSSARef
2639
@@ -35,6 +48,7 @@ struct DAEIPOResult
3548
opaque_eligible::Bool
3649
extended_rt::Any
3750
argtypes
51+
argmap::ArgumentMap
3852
nexternalargvars::Int # total vars is length(var_to_diff)
3953
nsysmscopes::Int
4054
nexternaleqs::Int
@@ -44,6 +58,7 @@ struct DAEIPOResult
4458
total_incidence::Vector{Any}
4559
eqclassification::Vector{VarEqClassification}
4660
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{StructuralSSARef, Int}}}}
61+
replacement_map::NonlinearReplacementMap
4762
names::OrderedDict{Any, ScopeDictEntry} # TODO: OrderedIdDict
4863
varkinds::Vector{Union{Intrinsics.VarKind, Nothing}}
4964
eqkinds::Vector{Union{Intrinsics.EqKind, Nothing}}

0 commit comments

Comments
 (0)