diff --git a/.github/workflows/run-examples.yml b/.github/workflows/run-examples.yml index bcf278e347..42085043cf 100644 --- a/.github/workflows/run-examples.yml +++ b/.github/workflows/run-examples.yml @@ -249,6 +249,7 @@ jobs: AnalysisSystemTest2, AnalysisSystemTest3, AnalysisSystemTest4, + TVSystemTest, ] fail-fast: false diff --git a/build.mill b/build.mill index d6a63ca9b8..292e4fe076 100644 --- a/build.mill +++ b/build.mill @@ -20,6 +20,7 @@ import basilmill.BasilDocs import basilmill.BasilVersion import basilmill.ProfileModule import basilmill.Z3Module +//import basilmill.CVC5Module import os.Path @@ -49,8 +50,12 @@ object `package` extends ScalaModule with BasilDocs with BasilVersion with Scala val aslpOffline = mvn"io.github.uq-pac::lifter:0.1.0" val javaSmt = mvn"org.sosy-lab:java-smt:5.0.0" val javaSmtZ3 = mvn"org.sosy-lab:javasmt-solver-z3:4.14.0" + val javaSmtCVC5 = mvn"org.sosy-lab:javasmt-solver-cvc5:1.2.1-g8594a8e4dc" + val cats_collections = mvn"org.typelevel::cats-collections-core:0.9.10" + val cats_core = mvn"org.typelevel::cats-core:2.13.0" + val cats_kernel = mvn"org.typelevel::cats-kernel:2.13.0" - override def mvnDeps = Seq(scalactic, sourceCode, mainArgs, upickle, aslpOffline, javaSmt, javaSmtZ3) + override def mvnDeps = Seq(scalactic, sourceCode, mainArgs, upickle, aslpOffline, javaSmt, javaSmtZ3, javaSmtCVC5, cats_kernel, cats_core, cats_collections) override def repositoriesTask = Task.Anon { super.repositoriesTask() :+ MavenRepository( @@ -67,6 +72,7 @@ object `package` extends ScalaModule with BasilDocs with BasilVersion with Scala override def moduleDir = BuildCtx.workspaceRoot / "src" + override def sources = Task.Sources("main/scala") override def forkArgs = Task { @@ -223,6 +229,7 @@ object `package` extends ScalaModule with BasilDocs with BasilVersion with Scala } object z3 extends Z3Module +// object cvc5 extends CVC5Module def ctagsConfig = Task.Source { BuildCtx.workspaceRoot / "basilmill" / "scala.ctags" @@ -255,7 +262,9 @@ 
object `package` extends ScalaModule with BasilDocs with BasilVersion with Scala def runProfile(profileDest: String, args: String*) = Task.Command { println(s"Profiling: you may want to set\n sudo sysctl kernel.perf_event_paranoid=1\n sudo sysctl kernel.kptr_restrict=0\n") val prof = asyncProf.path() - os.call(("java", s"-agentpath:${prof}=start,event=cpu,file=${profileDest}", "-jar", assembly().path.toString, args), stdout = os.Inherit, cwd = BuildCtx.workspaceRoot) + val oargs : Seq[String] = forkArgs() + val realArgs : Seq[String] = oargs ++ Seq(s"-agentpath:${prof}=start,event=cpu,file=${profileDest}") + os.call(Seq("java") ++ realArgs ++ Seq( "-jar", assembly().path.toString) ++ args, stdout = os.Inherit, cwd = BuildCtx.workspaceRoot) } } diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index e729d39090..e683f92679 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -20,6 +20,10 @@ # Analyses - [Data Structure Analysis](development/dsa.md) Basil Memory Analysis +- [Translation Validation](development/tv/index.md) + - [Public API](development/tv/tv-api.md) + - [Implementation](development/tv/implementation.md) + - [Assumptions](development/tv/assumptions.md) # Development diff --git a/docs/src/development/profiling.md b/docs/src/development/profiling.md index 5d3f83a1f0..2fe74d60ac 100644 --- a/docs/src/development/profiling.md +++ b/docs/src/development/profiling.md @@ -34,4 +34,20 @@ sudo sysctl kernel.perf_event_paranoid=1 sudo sysctl kernel.kptr_restrict=0 ``` +## Identifying memory leaks + +Memory leaks can be identified using async profiler. + +E.g. build an assembly `./mill assembly` to get `out/assembly.dest/out.jar` + +Then use asprof with flags `-e alloc --total --live`. + +``` +./async-profiler-4.0-linux-x64/bin/asprof -e alloc -o flamegraph --live -f alloc.html --total -d 60 out.jar +``` + +This will show the total allocation count only for references that are still live, i.e. allocations with a reference still retained +somewhere. 
+ +[related blog article](https://web.archive.org/web/20240228031308/https://krzysztofslusarski.github.io/2022/11/27/async-live.html). diff --git a/docs/src/development/tv/assumptions.md b/docs/src/development/tv/assumptions.md new file mode 100644 index 0000000000..3cebc442a8 --- /dev/null +++ b/docs/src/development/tv/assumptions.md @@ -0,0 +1,113 @@ +# Lifter Assumptions + +A transform pass may want to drive a code-transform with a reasonable assumption +about the behaviour, that it cannot technically prove with analysis. + +One such example is the `AssumeCallPreserved` pass which modifies +procedure parameter lists on the assumption that the program +respects the arm64 calling convention. + +To allow this, we introduce an assertion to the program, encoding the +assumption that was made by the transform. When the program +is eventually verified in Boogie, we discharge this assumption, +proving that our assumption was sound. + +The translation validation pipeline proves that as long as +the assumption holds, the programs have the same behaviour. +Because the introduction of an assert statement changes the +program behaviour, we need to add the specification that +traces on which the assertion fails can be ignored. + +The procedure `_start` is an interesting example. 
+ +We can use the script in `scripts/soundnesslitmuscvc5wrapper.sh` to get the number of +program statements contributing to the unsat core of the program: + + +``` + +tvsmt/simplifyCFG-_start_4213248.smt2 +66% contributed 171/259 asserts + +tvsmt/Parameters-_start_4213248.smt2 +66% contributed 145/219 asserts + +tvsmt/AssumeCallPreserved-_start_4213248.smt2 +68% contributed 177/260 asserts + +tvsmt/DSA-_start_4213248.smt2 +53% contributed 169/316 asserts + +tvsmt/CopyProp-_start_4213248.smt2 +4% contributed 11/257 asserts + +tvsmt/GuardCleanup-_start_4213248.smt2 +7% contributed 10/138 asserts + +``` + +The procedure `_start` contains an unresolved indirect tail call, +so the pass AssumeCallPreserved injects assertions that don't hold (but TV passes as it's not TV's responsibility to prove them), +but then subsequent analyses that leverage this assumption are not sound and we get tv passing trivially due to +`assert eq(0x404a34:bv64, 0x404b64:bv64) { .comment = "R30 = R30_in" }; ~> assert false`. This is the expected behavior, a +program with an assert false is obviously a program that doesn't verify. + +This unsat core is showing roughly that the translation validation is passing vacuously because +the copyprop transform has derived false from the assertion we introduced. + +``` +cvc5 --dump-unsat-cores tvsmt/CopyProp-_start_4213248.smt2 +unsat +( +source5 +source47 +source54 +source56 +source51 +source57 +source53 +source58 +source69 +source70 +source48 +) +``` + +
+ full _start IL after copyprop pass + +``` + + (R0_in:bv64, R10_in:bv64, R11_in:bv64, R12_in:bv64, R13_in:bv64, R14_in:bv64, R15_in:bv64, R16_in:bv64, R17_in:bv64, R18_in:bv64, R1_in:bv64, R29_in:bv64, R2_in:bv64, R30_in:bv64, R31_in:bv64, R3_in:bv64, R4_in:bv64, R5_in:bv64, R6_in:bv64, R7_in:bv64, R8_in:bv64, R9_in:bv64, _PC_in:bv64) + -> (R0_out:bv64, R1_out:bv64, R2_out:bv64, R3_out:bv64, R4_out:bv64, R5_out:bv64, R6_out:bv64, R7_out:bv64, _PC_out:bv64) + { .name = "_start"; .address = 0x404a00 } +[ + block %_start_entry {.address = 0x404a00; .originalLabel = "ufjL9zmpTde18uF80OwPVQ=="} [ + var R5_2: bv64 := bvor(0x0:bv64, bvshl(R0_in:bv64, 0x0:bv64)); + var var1_4213376_bv64_1: bv64 := load le $mem R31_in:bv64 64; + var var2_4202816_bv64_1: bv64 := load le $mem bvadd(0x430000:bv64, 0x40:bv64) 64; + assert eq(0x404b64:bv64, 0x404b64:bv64) { .comment = "R30 = R30_in" }; + var (R0_4:bv64, R10_2:bv64, R11_2:bv64, R12_2:bv64, R13_2:bv64, R14_2:bv64, R15_2:bv64, R16_4:bv64, R17_3:bv64, R18_2:bv64, R1_3:bv64, R29_3:bv64, R2_3:bv64, R3_3:bv64, R4_3:bv64, R5_3:bv64, R6_3:bv64, R7_2:bv64, R8_2:bv64, R9_2:bv64) + := call @__libc_start_main (0x404a34:bv64, R10_in:bv64, R11_in:bv64, R12_in:bv64, R13_in:bv64, R14_in:bv64, R15_in:bv64, 0x430040:bv64, var2_4202816_bv64_1:bv64, R18_in:bv64, var1_4213376_bv64_1:bv64, 0x0:bv64, bvadd(R31_in:bv64, 0x8:bv64), 0x404b64:bv64, 0x0:bv64, 0x0:bv64, R5_2:bv64, R31_in:bv64, R7_in:bv64, R8_in:bv64, R9_in:bv64); + goto(%phi_5); + ]; + block %_start_10 {.address = 0x404a30; .originalLabel = "eH7LoljnQS6XhgPVv4Qipg=="} [ + assert eq(0x404a34:bv64, 0x404b64:bv64) { .comment = "R30 = R30_in" }; + var var2_4203488_bv64_1: bv64 := load le $mem bvadd(0x430000:bv64, 0x190:bv64) 64; + assert eq(0x404a34:bv64, 0x404a34:bv64) { .comment = "R30 = R30_in" }; + var (R0_6:bv64, R10_4:bv64, R11_4:bv64, R12_4:bv64, R13_4:bv64, R14_4:bv64, R15_4:bv64, R16_7:bv64, R17_5:bv64, R18_4:bv64, R1_5:bv64, R29_5:bv64, R2_5:bv64, R3_5:bv64, R4_5:bv64, R5_5:bv64, 
R6_5:bv64, R7_4:bv64, R8_4:bv64, R9_4:bv64) + := call @abort (R0_4:bv64, R10_2:bv64, R11_2:bv64, R12_2:bv64, R13_2:bv64, R14_2:bv64, R15_2:bv64, 0x430190:bv64, var2_4203488_bv64_1:bv64, R18_2:bv64, R1_3:bv64, R29_3:bv64, R2_3:bv64, 0x404a34:bv64, R3_3:bv64, R4_3:bv64, R5_3:bv64, R6_3:bv64, R7_2:bv64, R8_2:bv64, R9_2:bv64); + goto(%phi_6); + ]; + block %phi_5 {.originalLabel = "eH7LoljnQS6XhgPVv4Qipg==, ufjL9zmpTde18uF80OwPVQ=="} [ + goto(%_start_10); + ]; + block %phi_6 {.originalLabel = "eH7LoljnQS6XhgPVv4Qipg=="} [ + goto(%_start_return); + ]; + block %_start_return [ + return (R0_6:bv64, R1_5:bv64, R2_5:bv64, R3_5:bv64, R4_5:bv64, R5_5:bv64, R6_5:bv64, R7_4:bv64, _PC_in:bv64); + ] +]; +``` +
diff --git a/docs/src/development/tv/implementation.md b/docs/src/development/tv/implementation.md new file mode 100644 index 0000000000..89e7e9bcfe --- /dev/null +++ b/docs/src/development/tv/implementation.md @@ -0,0 +1,170 @@ +# Translation Validation Implementation + +At the highest level the translation validation takes two programs and an invariant linking them, +constructs a product program describing the simultaneous execution of the two programs, +and verifies this program satisfies the invariant. + +We initially describe the translation pipeline and structures used throughout this process. + +## Phases + +### Cut-Transition System + +- `TransitionSystem.scala` +- Transforms a Basil IR program to an equivalent acyclic Basil IR program + +A transition system describes one acyclic aggregate program step. This step breaks a single +Basil IR program into a program which represents a single acyclic step at a time. +This effectively fans out every loop in the program into a single loop which is equivalent +to the original program. + +#### Cut transform: + + 1. Create a new entry and exit + 2. Use program entry as a cut, link it to the new entry and guard with a specific PC value `ENTRY` + 3. Use program exit as a cut, link it to the new exit and set with a specific PC value `RETURN` + 4. Identify each loop header as a cut, set `PC := Loop$i` and redirect through exit, add edge + from entry to header guarded by a pc value `Loop$i` + +### Monadic Local Side-Effect Form + +- `SSADAG.scala` and `Ackermann.scala` + +This translates the Basil IR program to a program containing three statement types: + +1. (Simultaneous) assignment `(a, b, c) := (d, e, f)` (Scala type `SimulAssign`) +2. Side effect calls : `(TRACE, a, b, c) := EffName (TRACE, d, e, f)` (Scala type `SideEffectStatement`). + This is an uninterpreted function call with multiple/tuple return. +3. 
Assumes / Guards + +Note the `TRACE` boolean-typed variable here which represents the state passed through the program. +This is where the "monadic" terminology comes from. A boolean type is sufficient here as it only +needs to represent the truth value of the equivalence between the source and target trace. + +Think of this `TRACE` as an oracle representing the entire universe, i.e. we assume +the precondition `TRACE_source == TRACE_target`. This is assuming the programs execute at the +same time in the same universe state; thus it captures external IO, assuming +both programs will always receive identical external inputs if they are invariantly +in the same state. + +A frame analysis is used to identify the interprocedural effects of calls. This transform +pulls these side effects (memory access, global variable access) into the parameter list +of the side-effect statement. + +### SSA Form + +This performs a naive SSA transform (not considering loops) on the Monadic program. + +It introduces reachability predicates (`${blockName}_done`) for the end of every block. +This predicate is the conjunction of + +1. the disjunction of the reachability of its predecessors and +2. conjunction of all assume statements in the block. + +Note the phi nodes have a slightly odd structure so they fit in the existing Basil IR. +In the below code, the assume at the start of block `l3` represents the phi node +joining `l1` and `l2`. + + +```c +block l1 [ + r0_0 := 1; // was r0 := 1 + goto l3; +]; +block l2 [ + r0_1 := 2; // was r0 := 2 + goto l3; +]; +block l3 [ + assume (l1_done ==> r0_3 == r0_0 && l2_done ==> r0_3 == r0_1); + ret; +]; +``` + +This transform returns a function which renames an un-indexed expression +to one in terms of the ssa-indexed variables defined at a given block. 
+ +### Ackermannisation + +- This is an invariant inference pass performed on the SSA-form program + +This is a transform which soundly performs the reasoning about the correspondence of +side effects in the product program ahead of verification-time. + +At a high level, assume we have side-effect statements in the source and target program below: + +``` +// source program: +(source_TRACE_1, source_R0_1) := source_Load_mem_64 (source_TRACE, source_R1); +// target program: +(target_TRACE_1, target_R0_1) := target_Load_mem_64 (target_TRACE, target_R1); +``` + +Analogous to the congruence rule of uninterpreted functions we have the axiom: + +``` +\forall ts, tt, r0s, r0t :: tt == ts && r0s == r0t + ==> source_Load_mem_64(ts, r0s) == target_Load_mem_64(tt, r0t) +``` + +I.e. these loads have the same effect as long as they are the same address and occur +in the same state. + +We would want to instantiate this axiom whenever we have two corresponding +source and target loads, but really we only care about those that +are already likely to line up. Instead of letting the SMT solver +decide when to instantiate this axiom we use the control-flow graph, +and the requirement that transforms must preserve the order and number +of side-effects to instantiate exactly the instances of the axiom that +the verification will need. + +This is done by walking the source and target CFGs in lockstep, +identifying matching side-effects and adding the +body of the axiom as an assertion to the verification condition. + +- After this is performed all `SideEffectStatement` are removed from the program. + +### Passified Form + +Since we have SSA form the semantics of assignment are unnecessary, we replace +every assignment with an `Assume` stating the equality of assignees. + +We now have a program consisting only of `Assume` statements. 
+ +### SMT + +- `TranslationValidate.scala` + +- Infer invariant component at each cut and rename for the SSA renaming at the corresponding cut + - Rename free variables for ssa indexes for synth entry precondition and emit assertion + - Rename free variables for ssa indexes for synth exit and emit negated assertion +- Add every assume from the passified program to the SMT query +- Add the initial invariant to the SMT query, add the negated exit-invariant to the SMT query. + +This is built with `JavaSMT` and Basil's internal SMT builder. + + +## Debugging + +### Validation Failure + +When immediate verification is enabled, and `sat` is returned, the validator emits an `.il` file and a +CFG containing a fake representation of the passified product +program. It attempts to annotate the CFG with the model, however note that it often +incorrectly relates source variables to target variables (due to mis-alignment of blocks, assigns, SSA-indexing), +so this cannot be taken as given. + +### Unsoundness + +A litmus-test for the soundness of the verification is to generate the unsat core for the dumped SMT query. +If the verification is substantive, the unsat core should contain the entire transition system: +assertions named `source$number` and `tgt$number`. + + +# Split Optimisation + +For large procedures we break down the proof based on the entry cut. Because we always have the precondition +that we start in the same entry cut, for each possible entry we select the edge corresponding to that +entry and remove all other outgoing edges from the entry point. This assumption is then fully propagated +through by a dead code elimination. (In fact we remove the edge before the SSA pass so the flow is removed from the program). 
+ diff --git a/docs/src/development/tv/index.md b/docs/src/development/tv/index.md new file mode 100644 index 0000000000..34147c1968 --- /dev/null +++ b/docs/src/development/tv/index.md @@ -0,0 +1,45 @@ +# Translation Validation + +Translation validation is a system for verifying a transform pass preserves the +behaviour of a Basil IR program. + +The translation validator aims to prove that a given run of a transform is trace-preserving. +This property is sufficient to show that security verification on the transformed program +is sufficient for the property to hold on the original program. + +The requirement is reasonably strong and includes: + +- every possible path through loops corresponds between programs +- the set of observable variables (globals and in/out parameters) in the resulting program take the same + values on all traces as they do in corresponding traces in the original program +- The exact sequence of side-effects in the source and target program is identical. + +For our purposes, side effects include: + +- memory stores and loads +- procedure calls +- indirect calls +- control-flow branches (those marked `checkSecurity = true`) + +At a high level the translation validation works similarly to regular program verification. +This verification is applied to a 'product program' which combines the program prior to the transform (the *target* program) +and the program after the transform (the *source* program). +This verification is with respect to a specification stating that the source program has the same behaviour +as the target program. We automatically construct this specification using information +provided by the transform. This is sometimes called a certificate. +We often call this specification an `invariant`, since it closely corresponds to a loop invariant. + +> **Note on terminology:** +> +> We term the "*source*" program as the "higher level" or more +> abstract program that is the *result* of a lifting transform pass. 
+> +> The "*target*" program refers to the original program, the *input* to the lifting +> transform pass. +> +> This is the reverse of how inputs and outputs are typically termed in compiler +> validation work. + +The translation validation is performed per-procedure, with a simple analysis being used +to infer live variables and procedure frames which summarises interprocedural effects +in the procedure scope. diff --git a/docs/src/development/tv/methodology.md b/docs/src/development/tv/methodology.md new file mode 100644 index 0000000000..5b1e5b3687 --- /dev/null +++ b/docs/src/development/tv/methodology.md @@ -0,0 +1 @@ +# Translation Validation Methodology diff --git a/docs/src/development/tv/tv-api.md b/docs/src/development/tv/tv-api.md new file mode 100644 index 0000000000..952377a924 --- /dev/null +++ b/docs/src/development/tv/tv-api.md @@ -0,0 +1,123 @@ +# Translation Validation API + +Translation validation is performed by the method `getValidationSMT` on the `TranslationValidator` class. + +We now explain its signature in full, it is provided in full below for context. + +```scala + + +case class TVJob( + outputPath: Option[String], + verify: Option[util.SMT.Solver] = None, + results: List[TVResult] = List(), + debugDumpAlways: Boolean = false, + /* minimum number of statements in source and target combined to trigger case analysis */ + splitLargeProceduresThreshold: Option[Int] = Some(60) +) + +object TranslationValidator: + def forTransform[T]( + transformName: String, + transform: Program => T, + invariant: T => InvariantDescription = (_: T) => InvariantDescription() + ): ((Program, TVJob) => TVJob) +``` + +This returns an anonymous function which runs the provided transform on a program, passes its result to the invariant function +to produce a description of the transform, then runs the translation validation, returning a copy of the `TVJob` +including the additional validation results. 
+ +The invariants are specified with the `InvariantDescription` type, which can be derived from the output of the transform functor. + +```scala + +/** + * Describes the mapping from a variable in one program to an expression in the other, at a specific block and procedure. + */ +type TransformDataRelationFun = (ProcID, Option[BlockID]) => (Variable | Memory) => Seq[Expr] + +case class InvariantDescription( + /** The way live variables at each cut in the source program relate to equivalent expressions or variables in the target. + * + * NOTE: !!! The first returned value of this is also used to map procedure call arguments in the source + * program to the equivalent arguments in the target program. + * */ + renamingSrcTgt: TransformDataRelationFun = (_, _) => e => Seq(e), + + /** + * Describes how live variables at a cut in the target program relate to equivalent variables in the source. + * + */ + renamingTgtSrc: TransformDataRelationFun = (_, _) => _ => Seq(), + + /** + * Set of values of [ir.Assert.label] for assertions introduced in this pass, which should + * be ignored as far as translation validation is concerned. + */ + introducedAsserts: Set[String] = Set() +) { +``` + +### Shape-Preserving Transforms + +However, simple transforms---those which are *shape-preserving* and perform only +limited code motion---can use the default relation between source and target programs. + +> **Shape preserving:** +> A transform which does not change the 'shape' of the program state: variables and +> memory are not renamed, added or removed and are used in the same way + +This default invariant states that every variable in the source corresponds to the same +variable in the target. 
+ +An example of this is the identity transform, which makes no changes: + +```scala +def nop(p: Program) = { + // execute the transform and write validation queries to folder tvsmt + TranslationValidator.forTransform("NOP", p => p)(p, TVJob(Some("tvsmt"))) +} +``` + +This validates with the default invariant `() => InvariantDescription` + +Even dead-code elimination can be handled with the default invariant. + +### Non-Shape-Preserving + +For sophisticated transforms more information may be provided to the validation framework +to generate the verification invariant through the parameters `renamingSrcTgt` and `flowFacts`. + +1. source-to-target variable relations; to relate possibly renamed variables (`renamingSrcTgt` parameter) + - !! This is also used to figure out how to match procedure call parameters between the target and source programs. +2. source-to-target variable relations; used in copyprop to pass definitions across cuts for variables live in the + target program but not live in the source program. + +## Introducing Assumptions via Later-Verified Assertions + +see also: [lifter-assertions](assumptions.md) + +For transforms that introduce assumptions by adding assertions to be discharged later, the +framework allows these to be temporarily ignored during validation. + +The verification then verifies the translation, but ignores traces differing only by the failure +of an assertion introduced by the transform. Since these traces will later be verified to not exist, +we can soundly ignore them in the translation validation. 
+ +```scala +TranslationValidator.forTransform( + "AssumeCallPreserved", + transforms.CalleePreservedParam.transform, + asserts => InvariantDescription(introducedAsserts = asserts.toSet) +) +``` + +# Soundness of Translation Validation + +The specified assertion becomes an inductive loop invariant, so if it is not strong enough +the usual failure mode is for it not to verify, either due to not being inductive +or not being strong enough to establish the post-condition (equivalence of the procedure's post-state). + +However, since we currently do not implement specification validation it is possible to inject +`false` as a specification and this will vacuously verify. diff --git a/examples/conds/conds.c b/examples/conds/conds.c index faf2218294..6b95015dd0 100644 --- a/examples/conds/conds.c +++ b/examples/conds/conds.c @@ -1,85 +1,17 @@ #include -int x = 0; -volatile int r = 0; -volatile unsigned z; -volatile unsigned y; +volatile long int x = 0; +volatile long int r = 0; +volatile long int z; +volatile long int y; int main(int argc, char **argv) { - x = argc; - y = argc; - z = argc; - if (x < 0) { - x = r; - } - - if (x > 0) { - x = r; - } - - if (x < 5) { - x += r; - } - - if (x <= 8) { - x += r; - } - - - if (x >= 100) { - x += r; - } - - - if (x > 1000) { - x += r; - } - - - if (x < r) { - x += r * 10; - } - if (x <= r + 100) { - x += r * r; - } - if (x > r + 1000) { - x += 2 * r + r; - } - if (x >= r + 2000) { - x += r + r; - } - - if (y < 0) { - y += 1; - } - - if (y <= 0) { - y += 2; - } - - if (y <= 1000) { - y += 1; - } - - if (y >= -1) { - y += 1; - } - - if (y >= z) { - y += z; - } - - if (y <= z) { - y += z; - } - if (y < z) { - y += z; - } - if (y > z) { - y += z; + if (x > 4) { + puts("hello"); } + puts("world"); } diff --git a/examples/conds/conds.gts b/examples/conds/conds.gts index 39e8ad2a51..c458059fb5 100644 Binary files a/examples/conds/conds.gts and b/examples/conds/conds.gts differ diff --git a/examples/conds/conds.relf 
b/examples/conds/conds.relf index ee6b37db4a..c9b7c2659b 100644 --- a/examples/conds/conds.relf +++ b/examples/conds/conds.relf @@ -1,26 +1,28 @@ -Relocation section '.rela.dyn' at offset 0x530 contains 3 entries: +Relocation section '.rela.dyn' at offset 0x590 contains 3 entries: Offset Info Type Symbol's Value Symbol's Name + Addend 000000000041ffd0 0000000200000401 R_AARCH64_GLOB_DAT 0000000000000000 _ITM_deregisterTMCloneTable + 0 000000000041ffd8 0000000300000401 R_AARCH64_GLOB_DAT 0000000000000000 __gmon_start__ + 0 -000000000041ffe0 0000000500000401 R_AARCH64_GLOB_DAT 0000000000000000 _ITM_registerTMCloneTable + 0 +000000000041ffe0 0000000600000401 R_AARCH64_GLOB_DAT 0000000000000000 _ITM_registerTMCloneTable + 0 -Relocation section '.rela.plt' at offset 0x578 contains 3 entries: +Relocation section '.rela.plt' at offset 0x5d8 contains 4 entries: Offset Info Type Symbol's Value Symbol's Name + Addend 0000000000420000 0000000100000402 R_AARCH64_JUMP_SLOT 0000000000000000 __libc_start_main@GLIBC_2.34 + 0 0000000000420008 0000000300000402 R_AARCH64_JUMP_SLOT 0000000000000000 __gmon_start__ + 0 0000000000420010 0000000400000402 R_AARCH64_JUMP_SLOT 0000000000000000 abort@GLIBC_2.17 + 0 +0000000000420018 0000000500000402 R_AARCH64_JUMP_SLOT 0000000000000000 puts@GLIBC_2.17 + 0 -Symbol table '.dynsym' contains 6 entries: +Symbol table '.dynsym' contains 7 entries: Num: Value Size Type Bind Vis Ndx Name 0: 0000000000000000 0 NOTYPE LOCAL DEFAULT UND 1: 0000000000000000 0 FUNC GLOBAL DEFAULT UND __libc_start_main@GLIBC_2.34 (2) 2: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_deregisterTMCloneTable 3: 0000000000000000 0 NOTYPE WEAK DEFAULT UND __gmon_start__ 4: 0000000000000000 0 FUNC GLOBAL DEFAULT UND abort@GLIBC_2.17 (3) - 5: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_registerTMCloneTable + 5: 0000000000000000 0 FUNC GLOBAL DEFAULT UND puts@GLIBC_2.17 (3) + 6: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_registerTMCloneTable -Symbol table '.symtab' 
contains 95 entries: +Symbol table '.symtab' contains 97 entries: Num: Value Size Type Bind Vis Ndx Name 0: 0000000000000000 0 NOTYPE LOCAL DEFAULT UND 1: 0000000000400238 0 SECTION LOCAL DEFAULT 1 .interp @@ -28,92 +30,94 @@ Symbol table '.symtab' contains 95 entries: 3: 00000000004002c8 0 SECTION LOCAL DEFAULT 3 .hash 4: 00000000004002f8 0 SECTION LOCAL DEFAULT 4 .gnu.hash 5: 0000000000400318 0 SECTION LOCAL DEFAULT 5 .dynsym - 6: 00000000004003a8 0 SECTION LOCAL DEFAULT 6 .dynstr - 7: 00000000004004f2 0 SECTION LOCAL DEFAULT 7 .gnu.version - 8: 0000000000400500 0 SECTION LOCAL DEFAULT 8 .gnu.version_r - 9: 0000000000400530 0 SECTION LOCAL DEFAULT 9 .rela.dyn - 10: 0000000000400578 0 SECTION LOCAL DEFAULT 10 .rela.plt - 11: 00000000004005c0 0 SECTION LOCAL DEFAULT 11 .init - 12: 00000000004005e0 0 SECTION LOCAL DEFAULT 12 .plt - 13: 0000000000400640 0 SECTION LOCAL DEFAULT 13 .text - 14: 0000000000400c04 0 SECTION LOCAL DEFAULT 14 .fini - 15: 0000000000400c18 0 SECTION LOCAL DEFAULT 15 .rodata - 16: 0000000000400c1c 0 SECTION LOCAL DEFAULT 16 .eh_frame_hdr - 17: 0000000000400c60 0 SECTION LOCAL DEFAULT 17 .eh_frame + 6: 00000000004003c0 0 SECTION LOCAL DEFAULT 6 .dynstr + 7: 000000000040054e 0 SECTION LOCAL DEFAULT 7 .gnu.version + 8: 0000000000400560 0 SECTION LOCAL DEFAULT 8 .gnu.version_r + 9: 0000000000400590 0 SECTION LOCAL DEFAULT 9 .rela.dyn + 10: 00000000004005d8 0 SECTION LOCAL DEFAULT 10 .rela.plt + 11: 0000000000400638 0 SECTION LOCAL DEFAULT 11 .init + 12: 0000000000400650 0 SECTION LOCAL DEFAULT 12 .plt + 13: 00000000004006c0 0 SECTION LOCAL DEFAULT 13 .text + 14: 000000000040080c 0 SECTION LOCAL DEFAULT 14 .fini + 15: 0000000000400820 0 SECTION LOCAL DEFAULT 15 .rodata + 16: 0000000000400838 0 SECTION LOCAL DEFAULT 16 .eh_frame_hdr + 17: 0000000000400880 0 SECTION LOCAL DEFAULT 17 .eh_frame 18: 000000000041fdc8 0 SECTION LOCAL DEFAULT 18 .init_array 19: 000000000041fdd0 0 SECTION LOCAL DEFAULT 19 .fini_array 20: 000000000041fdd8 0 SECTION LOCAL 
DEFAULT 20 .dynamic 21: 000000000041ffc8 0 SECTION LOCAL DEFAULT 21 .got 22: 000000000041ffe8 0 SECTION LOCAL DEFAULT 22 .got.plt - 23: 0000000000420018 0 SECTION LOCAL DEFAULT 23 .data - 24: 0000000000420028 0 SECTION LOCAL DEFAULT 24 .bss + 23: 0000000000420020 0 SECTION LOCAL DEFAULT 23 .data + 24: 0000000000420030 0 SECTION LOCAL DEFAULT 24 .bss 25: 0000000000000000 0 SECTION LOCAL DEFAULT 25 .comment 26: 0000000000000000 0 FILE LOCAL DEFAULT ABS crt1.o 27: 00000000004002a8 0 NOTYPE LOCAL DEFAULT 2 $d 28: 00000000004002a8 32 OBJECT LOCAL DEFAULT 2 __abi_tag - 29: 0000000000400640 0 NOTYPE LOCAL DEFAULT 13 $x - 30: 0000000000400674 0 NOTYPE LOCAL DEFAULT 13 __wrap_main - 31: 0000000000400c74 0 NOTYPE LOCAL DEFAULT 17 $d - 32: 0000000000400c18 0 NOTYPE LOCAL DEFAULT 15 $d - 33: 0000000000400680 0 NOTYPE LOCAL DEFAULT 13 $x - 34: 0000000000400c88 0 NOTYPE LOCAL DEFAULT 17 $d + 29: 00000000004006c0 0 NOTYPE LOCAL DEFAULT 13 $x + 30: 00000000004006f4 0 NOTYPE LOCAL DEFAULT 13 __wrap_main + 31: 0000000000400894 0 NOTYPE LOCAL DEFAULT 17 $d + 32: 0000000000400820 0 NOTYPE LOCAL DEFAULT 15 $d + 33: 0000000000400700 0 NOTYPE LOCAL DEFAULT 13 $x + 34: 00000000004008a8 0 NOTYPE LOCAL DEFAULT 17 $d 35: 0000000000000000 0 FILE LOCAL DEFAULT ABS crti.o - 36: 0000000000400684 0 NOTYPE LOCAL DEFAULT 13 $x - 37: 0000000000400684 20 FUNC LOCAL DEFAULT 13 call_weak_fn - 38: 00000000004005c0 0 NOTYPE LOCAL DEFAULT 11 $x - 39: 0000000000400c04 0 NOTYPE LOCAL DEFAULT 14 $x + 36: 0000000000400704 0 NOTYPE LOCAL DEFAULT 13 $x + 37: 0000000000400704 20 FUNC LOCAL DEFAULT 13 call_weak_fn + 38: 0000000000400638 0 NOTYPE LOCAL DEFAULT 11 $x + 39: 000000000040080c 0 NOTYPE LOCAL DEFAULT 14 $x 40: 0000000000000000 0 FILE LOCAL DEFAULT ABS crtn.o - 41: 00000000004005d0 0 NOTYPE LOCAL DEFAULT 11 $x - 42: 0000000000400c10 0 NOTYPE LOCAL DEFAULT 14 $x + 41: 0000000000400648 0 NOTYPE LOCAL DEFAULT 11 $x + 42: 0000000000400818 0 NOTYPE LOCAL DEFAULT 14 $x 43: 0000000000000000 0 FILE LOCAL DEFAULT 
ABS crtbegin.o - 44: 00000000004006a0 0 NOTYPE LOCAL DEFAULT 13 $x - 45: 00000000004006a0 0 FUNC LOCAL DEFAULT 13 deregister_tm_clones - 46: 00000000004006d0 0 FUNC LOCAL DEFAULT 13 register_tm_clones - 47: 0000000000420020 0 NOTYPE LOCAL DEFAULT 23 $d - 48: 0000000000400710 0 FUNC LOCAL DEFAULT 13 __do_global_dtors_aux - 49: 0000000000420028 1 OBJECT LOCAL DEFAULT 24 completed.0 + 44: 0000000000400720 0 NOTYPE LOCAL DEFAULT 13 $x + 45: 0000000000400720 0 FUNC LOCAL DEFAULT 13 deregister_tm_clones + 46: 0000000000400750 0 FUNC LOCAL DEFAULT 13 register_tm_clones + 47: 0000000000420028 0 NOTYPE LOCAL DEFAULT 23 $d + 48: 000000000040078c 0 FUNC LOCAL DEFAULT 13 __do_global_dtors_aux + 49: 0000000000420030 1 OBJECT LOCAL DEFAULT 24 completed.0 50: 000000000041fdd0 0 NOTYPE LOCAL DEFAULT 19 $d 51: 000000000041fdd0 0 OBJECT LOCAL DEFAULT 19 __do_global_dtors_aux_fini_array_entry - 52: 0000000000400740 0 FUNC LOCAL DEFAULT 13 frame_dummy + 52: 00000000004007c0 0 FUNC LOCAL DEFAULT 13 frame_dummy 53: 000000000041fdc8 0 NOTYPE LOCAL DEFAULT 18 $d 54: 000000000041fdc8 0 OBJECT LOCAL DEFAULT 18 __frame_dummy_init_array_entry - 55: 0000000000400ca0 0 NOTYPE LOCAL DEFAULT 17 $d - 56: 0000000000420028 0 NOTYPE LOCAL DEFAULT 24 $d + 55: 00000000004008c0 0 NOTYPE LOCAL DEFAULT 17 $d + 56: 0000000000420030 0 NOTYPE LOCAL DEFAULT 24 $d 57: 0000000000000000 0 FILE LOCAL DEFAULT ABS conds.c - 58: 000000000042002c 0 NOTYPE LOCAL DEFAULT 24 $d - 59: 0000000000400744 0 NOTYPE LOCAL DEFAULT 13 $x - 60: 0000000000400d00 0 NOTYPE LOCAL DEFAULT 17 $d - 61: 0000000000000000 0 FILE LOCAL DEFAULT ABS crtend.o - 62: 0000000000400d1c 0 NOTYPE LOCAL DEFAULT 17 $d - 63: 0000000000400d1c 0 OBJECT LOCAL DEFAULT 17 __FRAME_END__ - 64: 0000000000000000 0 FILE LOCAL DEFAULT ABS - 65: 000000000041fdd8 0 OBJECT LOCAL DEFAULT 20 _DYNAMIC - 66: 0000000000400c1c 0 NOTYPE LOCAL DEFAULT 16 __GNU_EH_FRAME_HDR - 67: 000000000041ffc8 0 OBJECT LOCAL DEFAULT 21 _GLOBAL_OFFSET_TABLE_ - 68: 00000000004005e0 0 NOTYPE 
LOCAL DEFAULT 12 $x - 69: 0000000000000000 0 FUNC GLOBAL DEFAULT UND __libc_start_main@GLIBC_2.34 - 70: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_deregisterTMCloneTable - 71: 0000000000420018 0 NOTYPE WEAK DEFAULT 23 data_start - 72: 0000000000420028 0 NOTYPE GLOBAL DEFAULT 24 __bss_start__ - 73: 0000000000420040 0 NOTYPE GLOBAL DEFAULT 24 _bss_end__ - 74: 0000000000420030 4 OBJECT GLOBAL DEFAULT 24 r - 75: 0000000000420028 0 NOTYPE GLOBAL DEFAULT 23 _edata - 76: 0000000000420034 4 OBJECT GLOBAL DEFAULT 24 z - 77: 000000000042002c 4 OBJECT GLOBAL DEFAULT 24 x - 78: 0000000000400c04 0 FUNC GLOBAL HIDDEN 14 _fini - 79: 0000000000420040 0 NOTYPE GLOBAL DEFAULT 24 __bss_end__ - 80: 0000000000420018 0 NOTYPE GLOBAL DEFAULT 23 __data_start - 81: 0000000000000000 0 NOTYPE WEAK DEFAULT UND __gmon_start__ - 82: 0000000000420020 0 OBJECT GLOBAL HIDDEN 23 __dso_handle - 83: 0000000000000000 0 FUNC GLOBAL DEFAULT UND abort@GLIBC_2.17 - 84: 0000000000400c18 4 OBJECT GLOBAL DEFAULT 15 _IO_stdin_used - 85: 0000000000420040 0 NOTYPE GLOBAL DEFAULT 24 _end - 86: 0000000000400680 4 FUNC GLOBAL HIDDEN 13 _dl_relocate_static_pie - 87: 0000000000400640 60 FUNC GLOBAL DEFAULT 13 _start - 88: 0000000000420040 0 NOTYPE GLOBAL DEFAULT 24 __end__ - 89: 0000000000420038 4 OBJECT GLOBAL DEFAULT 24 y - 90: 0000000000420028 0 NOTYPE GLOBAL DEFAULT 24 __bss_start - 91: 0000000000400744 1216 FUNC GLOBAL DEFAULT 13 main - 92: 0000000000420028 0 OBJECT GLOBAL HIDDEN 23 __TMC_END__ - 93: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_registerTMCloneTable - 94: 00000000004005c0 0 FUNC GLOBAL HIDDEN 11 _init + 58: 0000000000420038 0 NOTYPE LOCAL DEFAULT 24 $d + 59: 0000000000400828 0 NOTYPE LOCAL DEFAULT 15 $d + 60: 00000000004007c4 0 NOTYPE LOCAL DEFAULT 13 $x + 61: 0000000000400920 0 NOTYPE LOCAL DEFAULT 17 $d + 62: 0000000000000000 0 FILE LOCAL DEFAULT ABS crtend.o + 63: 0000000000400940 0 NOTYPE LOCAL DEFAULT 17 $d + 64: 0000000000400940 0 OBJECT LOCAL DEFAULT 17 __FRAME_END__ + 65: 
0000000000000000 0 FILE LOCAL DEFAULT ABS + 66: 000000000041fdd8 0 OBJECT LOCAL DEFAULT 20 _DYNAMIC + 67: 0000000000400838 0 NOTYPE LOCAL DEFAULT 16 __GNU_EH_FRAME_HDR + 68: 000000000041ffc8 0 OBJECT LOCAL DEFAULT 21 _GLOBAL_OFFSET_TABLE_ + 69: 0000000000400650 0 NOTYPE LOCAL DEFAULT 12 $x + 70: 0000000000000000 0 FUNC GLOBAL DEFAULT UND __libc_start_main@GLIBC_2.34 + 71: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_deregisterTMCloneTable + 72: 0000000000420020 0 NOTYPE WEAK DEFAULT 23 data_start + 73: 0000000000420030 0 NOTYPE GLOBAL DEFAULT 24 __bss_start__ + 74: 0000000000420058 0 NOTYPE GLOBAL DEFAULT 24 _bss_end__ + 75: 0000000000420040 8 OBJECT GLOBAL DEFAULT 24 r + 76: 0000000000420030 0 NOTYPE GLOBAL DEFAULT 23 _edata + 77: 0000000000420048 8 OBJECT GLOBAL DEFAULT 24 z + 78: 0000000000420038 8 OBJECT GLOBAL DEFAULT 24 x + 79: 000000000040080c 0 FUNC GLOBAL HIDDEN 14 _fini + 80: 0000000000420058 0 NOTYPE GLOBAL DEFAULT 24 __bss_end__ + 81: 0000000000420020 0 NOTYPE GLOBAL DEFAULT 23 __data_start + 82: 0000000000000000 0 NOTYPE WEAK DEFAULT UND __gmon_start__ + 83: 0000000000420028 0 OBJECT GLOBAL HIDDEN 23 __dso_handle + 84: 0000000000000000 0 FUNC GLOBAL DEFAULT UND abort@GLIBC_2.17 + 85: 0000000000400820 4 OBJECT GLOBAL DEFAULT 15 _IO_stdin_used + 86: 0000000000000000 0 FUNC GLOBAL DEFAULT UND puts@GLIBC_2.17 + 87: 0000000000420058 0 NOTYPE GLOBAL DEFAULT 24 _end + 88: 0000000000400700 4 FUNC GLOBAL HIDDEN 13 _dl_relocate_static_pie + 89: 00000000004006c0 60 FUNC GLOBAL DEFAULT 13 _start + 90: 0000000000420058 0 NOTYPE GLOBAL DEFAULT 24 __end__ + 91: 0000000000420050 8 OBJECT GLOBAL DEFAULT 24 y + 92: 0000000000420030 0 NOTYPE GLOBAL DEFAULT 24 __bss_start + 93: 00000000004007c4 72 FUNC GLOBAL DEFAULT 13 main + 94: 0000000000420030 0 OBJECT GLOBAL HIDDEN 23 __TMC_END__ + 95: 0000000000000000 0 NOTYPE WEAK DEFAULT UND _ITM_registerTMCloneTable + 96: 0000000000400638 0 FUNC GLOBAL HIDDEN 11 _init diff --git a/scripts/runall.sh b/scripts/runall.sh new file 
mode 100644 index 0000000000..47fd5823ee --- /dev/null +++ b/scripts/runall.sh @@ -0,0 +1,11 @@ +#!/bin/bash + + +ls ../tvbins/*/*.gtirb | xargs -I % -n 1 -P 8 bash scripts/run.sh % + +python3 scripts/torow.py header +for i in output*; +do + cat $i | python3 scripts/torow.py $i +done + diff --git a/scripts/soundnesslitmuscvc5wrapper.sh b/scripts/soundnesslitmuscvc5wrapper.sh new file mode 100755 index 0000000000..da18261e61 --- /dev/null +++ b/scripts/soundnesslitmuscvc5wrapper.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +timelimit=3000 + +fname=$1 +echo $fname +in_query=$(grep -E '(source[0-9]+|tgt[0-9]+)' -o $fname | wc -l) +cvcresult=$(cvc5 --dump-unsat-cores $fname) +in_unsatcore=$(echo $cvcresult | grep -E '(source[0-9]+|tgt[0-9]+)' -o | wc -l) +exitcode=$? +if [[ exitcode ]] then + printf "%d%% contributed \t%d/%d asserts\n" $(echo "scale=0; 100 * $in_unsatcore / $in_query" | bc) $in_unsatcore $in_query +else + echo "timeout" +fi + +echo "" + + diff --git a/scripts/torow.py b/scripts/torow.py new file mode 100644 index 0000000000..3b6e82c582 --- /dev/null +++ b/scripts/torow.py @@ -0,0 +1,40 @@ +import fileinput +import sys + +headers = ["fname", "before-stmt-count", + "before-guard-total-count", + "before-guard-simple-count", + "before-guard-complex-unique-count", + "before-guard-complex-VarsTooMany-count", + "before-guard-complex-HasFlagRegisters-count", + "after-guard-total-count", + "after-guard-simple-count", + "after-guard-complex-unique-count", + "after-guard-complex-OpsTooMany-count", + "after-guard-complex-VarsTooMany-count", + "after-stmt-count"] + + +if (sys.argv[1] == "header"): + print(", ".join(headers)) + exit(0) + +rs = {} + +for i in headers: + rs[i] = "0" + +rs["fname"] = sys.argv[1] + + +for line in fileinput.input(encoding='utf-8', errors='ignore', files=('-')): + # print(line, end="") + if ('tv-eval-marker' in line): + line = line.split(":")[1].strip() + rc = line.split("=") + rs[rc[0].strip()] = rc[1].strip() + + + + +print(", ".join(rs[i] for i 
in headers)) diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala index ad1030635b..a4ca623325 100644 --- a/src/main/scala/Main.scala +++ b/src/main/scala/Main.scala @@ -19,6 +19,7 @@ import util.{ PCTrackingOption, ProcRelyVersion, RunUtils, + SimplifyMode, StaticAnalysisConfig, writeToFile } @@ -207,6 +208,12 @@ object Main { generateRelyGuarantees: Flag, @arg(name = "simplify", doc = "Partial evaluate / simplify BASIL IR before output (implies --parameter-form)") simplify: Flag, + @arg(name = "simplify-tv", doc = "Simplify pass with translation validation, takes smt file output directory") + tvSimp: Option[String], + @arg(name = "simplify-tv-verify", doc = "Simplify with translation validation immediately call z3") + tvSimpVerify: Flag, + @arg(name = "simplify-tv-dryrun", doc = "Skip all tv work after invariant generation") + tvDryRun: Flag, @arg( name = "pc", doc = "Program counter mode, supports GTIRB only. (options: none | keep | assert) (default: none)" @@ -451,6 +458,13 @@ object Main { util.assertion.disableAssertions = true } + val simplifyMode = (conf.simplify.value, conf.tvSimp, conf.tvSimpVerify.value) match { + case (_, d, true) => SimplifyMode.ValidatedSimplify(Some(util.SMT.Solver.Z3), d, dryRun = conf.tvDryRun.value) + case (_, Some(d), _) => SimplifyMode.ValidatedSimplify(None, Some(d), dryRun = conf.tvDryRun.value) + case (true, None, _) => SimplifyMode.Simplify + case _ => SimplifyMode.Disabled + } + val q = BASILConfig( loading = loadingInputs.copy( dumpIL = conf.dumpIL, @@ -462,7 +476,7 @@ object Main { gtirbLiftOffline = conf.liftOffline.value ), runInterpret = conf.interpret.value, - simplify = conf.simplify.value, + simplify = simplifyMode, validateSimp = conf.validateSimplify.value, summariseProcedures = conf.summariseProcedures.value, generateLoopInvariants = conf.generateLoopInvariants.value, diff --git a/src/main/scala/analysis/InterLiveVarsAnalysis.scala b/src/main/scala/analysis/InterLiveVarsAnalysis.scala index 
d1c794c767..9caeef7dc3 100644 --- a/src/main/scala/analysis/InterLiveVarsAnalysis.scala +++ b/src/main/scala/analysis/InterLiveVarsAnalysis.scala @@ -18,6 +18,8 @@ import ir.{ Variable } +import scala.collection.immutable.ListSet + /** Micro-transfer-functions for LiveVar analysis * This analysis works by inlining function calls - instead of just mapping parameters and returns, all live variables * (registers) are propagated to and from callee functions. The result of which variables are alive at each point in @@ -136,6 +138,56 @@ trait LiveVarsAnalysisFunctions(inline: Boolean, addExternals: Boolean = true) } } -class InterLiveVarsAnalysis(program: Program, ignoreExternals: Boolean = false) - extends BackwardIDESolver[Variable, TwoElement, TwoElementLattice](program), +class InterLiveVarsAnalysis(program: Program, ignoreExternals: Boolean = false, entry: Option[Procedure] = None) + extends BackwardIDESolver[Variable, TwoElement, TwoElementLattice](program, entry), LiveVarsAnalysisFunctions(true, !ignoreExternals) + +def interLiveVarsAnalysis( + program: Program, + ignoreExternals: Boolean = false +): Map[CFGPosition, Map[Variable, TwoElement]] = { + + var procs = ListSet.from(program.procedures) + var starts = List[Procedure](program.mainProcedure) + + // while { + // val entries = + // procs.toList.filter(p => p.incomingCalls().size == 0 && p.entryBlock.isDefined && p.returnBlock.isDefined) + // if (entries.nonEmpty) { + // starts = entries.head :: starts + // val done = entries.head.reachableFrom + // procs = procs -- done + // procs.nonEmpty + // } else { + // Logger.warn(s"Live vars :: no program entry candidates remaining") + // false + // } + // } do {} + + // val reachable = starts.toSet.flatMap(_.reachableFrom) + // if ( + // !(reachable.contains( + // program.mainProcedure + // )) && procs.nonEmpty && program.mainProcedure.entryBlock.isDefined && program.mainProcedure.returnBlock.isDefined + // ) { + // Logger.warn( + // s"mainProcedure has predecessors 
but is not reachable from an entry-candidate, using it as an entry candidate." + // ) + // val remaining = program.procedures.toSet -- program.mainProcedure.reachableFrom + // starts = List(program.mainProcedure) + // procs = procs -- program.mainProcedure.reachableFrom + // } else if (!reachable.contains(program.mainProcedure)) { + // Logger.warn(s"mainProcedure ${program.mainProcedure.name} is a stub and is not reachable from any entry point") + // } + + // if (procs.nonEmpty) { + // Logger.error(s"Code unreachable for liveness analysis: ${procs.toList}") + // } + + var r = Map[CFGPosition, Map[Variable, TwoElement]]() + for (p <- starts) { + r = r ++ InterLiveVarsAnalysis(program, ignoreExternals, Some(p)).analyze() + } + r + +} diff --git a/src/main/scala/analysis/IrreducibleLoops.scala b/src/main/scala/analysis/IrreducibleLoops.scala index 0776b3aa27..d6be06abe8 100644 --- a/src/main/scala/analysis/IrreducibleLoops.scala +++ b/src/main/scala/analysis/IrreducibleLoops.scala @@ -52,6 +52,7 @@ object LoopDetector { case class State( // Header -> Loop loops: Map[Block, Loop] = Map(), + loops_o: List[Loop] = List(), headers: Set[Block] = Set(), // Algorithm helpers @@ -65,7 +66,8 @@ object LoopDetector { def identifiedLoops: Iterable[Loop] = loops.values def reducibleTransformIR(): State = { - this.copy(loops = LoopTransform.llvm_transform(loops.values).map(l => l.header -> l).toMap) + val nr = LoopTransform.llvm_transform(loops.values).map(l => l.header -> l).toMap + this.copy(loops = nr, loops_o = nr.values.toList) } def updateIrWithLoops() = { @@ -121,7 +123,7 @@ object LoopDetector { } if (!st.loops.contains(edge.to)) { - st = st.copy(loops = st.loops.updated(edge.to, newLoop)) + st = st.copy(loops = st.loops.updated(edge.to, newLoop), loops_o = newLoop :: st.loops_o) } st = tag_lhead(st, edge.from, edge.to) diff --git a/src/main/scala/analysis/ProcFrames.scala b/src/main/scala/analysis/ProcFrames.scala new file mode 100644 index 0000000000..bb881b31dd --- 
/dev/null +++ b/src/main/scala/analysis/ProcFrames.scala @@ -0,0 +1,82 @@ +package analysis +import ir.* +import ir.cilvisitor.* + +import transforms.* + +object ProcFrames { + + case class Frame( + modifiedGlobalVars: Set[GlobalVar] = Set(), + modifiedMem: Set[Memory] = Set(), + readGlobalVars: Set[GlobalVar] = Set(), + readMem: Set[Memory] = Set() + ) { + /* coarse variable-level frame for procedures + * + * - globals modified + * - globals captured + * + */ + + def union(o: Frame) = { + Frame( + modifiedGlobalVars ++ o.modifiedGlobalVars, + modifiedMem ++ o.modifiedMem, + readGlobalVars ++ o.readGlobalVars, + readMem ++ o.readMem + ) + } + } + + private class LocalModified(val v: Procedure => Frame, var summary: Frame) extends CILVisitor { + override def vmem(m: Memory) = { + summary = summary.copy(readMem = summary.readMem + m) + SkipChildren() + } + + override def vrvar(v: Variable) = v match { + case v: GlobalVar => + summary = summary.copy(readGlobalVars = summary.readGlobalVars + v) + SkipChildren() + case _ => + SkipChildren() + } + + override def vlvar(v: Variable) = v match { + case v: GlobalVar => + summary = summary.copy(modifiedGlobalVars = summary.modifiedGlobalVars + v) + SkipChildren() + case _ => + SkipChildren() + } + + override def vstmt(s: Statement) = { + s match { + case MemoryStore(m, _, _, _, _, _) => { + summary = summary.copy(modifiedMem = summary.modifiedMem + m) + } + case MemoryAssign(m: GlobalVar, _, _) => { + summary = summary.copy(modifiedGlobalVars = summary.modifiedGlobalVars + m) + } + case d: DirectCall => { + summary = summary.union(v(d.target)) + } + case _ => () + } + DoChildren() + } + } + + private def transferProcedure(getSt: Procedure => Frame, s: Frame, p: Procedure): Frame = { + val v = LocalModified(getSt, s) + visit_proc(v, p) + v.summary + } + + def inferProcFrames(p: Program): Map[Procedure, Frame] = { + val solver = BottomUpCallgraphWorklistSolver(transferProcedure, x => Frame(Set(), Set(), Set(), Set())) + 
solver.solve(p) + } + +} diff --git a/src/main/scala/analysis/solvers/IDESolver.scala b/src/main/scala/analysis/solvers/IDESolver.scala index a711da8919..104fa90bf4 100644 --- a/src/main/scala/analysis/solvers/IDESolver.scala +++ b/src/main/scala/analysis/solvers/IDESolver.scala @@ -245,21 +245,14 @@ abstract class IDESolver[ } def analyze(): Map[CFGPosition, Map[D, T]] = { - if ( - program.mainProcedure.blocks.nonEmpty && program.mainProcedure.returnBlock.isDefined && program.mainProcedure.entryBlock.isDefined - ) { - val phase1 = Phase1() - val phase2 = Phase2(phase1) - phase2.restructure(phase2.analyze()) - } else { - Logger.warn(s"Disabling IDE solver tests due to external main procedure: ${program.mainProcedure.name}") - Map() - } + val phase1 = Phase1() + val phase2 = Phase2(phase1) + phase2.restructure(phase2.analyze()) } } -abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) - extends IDESolver[Procedure, Return, DirectCall, Command, D, T, L](program, program.mainProcedure), +abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program, entry: Option[Procedure] = None) + extends IDESolver[Procedure, Return, DirectCall, Command, D, T, L](program, entry.getOrElse(program.mainProcedure)), ForwardIDEAnalysis[D, T, L], IRInterproceduralForwardDependencies { @@ -302,10 +295,12 @@ abstract class ForwardIDESolver[D, T, L <: Lattice[T]](program: Program) InterProcIRCursor.succ(exit).filter(_.isInstanceOf[Command]).map(_.asInstanceOf[Command]) } -abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program) +abstract class BackwardIDESolver[D, T, L <: Lattice[T]](program: Program, entry: Option[Procedure] = None) extends IDESolver[Return, Procedure, Command, DirectCall, D, T, L]( - program, - IRWalk.lastInProc(program.mainProcedure).getOrElse(program.mainProcedure) + program, { + val e = entry.getOrElse(program.mainProcedure) + IRWalk.lastInProc(e).getOrElse(e) + } ), BackwardIDEAnalysis[D, T, L], 
IRInterproceduralBackwardDependencies { diff --git a/src/main/scala/cfg_visualiser/DotTools.scala b/src/main/scala/cfg_visualiser/DotTools.scala index d9a98020bd..6b0fb583db 100644 --- a/src/main/scala/cfg_visualiser/DotTools.scala +++ b/src/main/scala/cfg_visualiser/DotTools.scala @@ -54,7 +54,7 @@ class DotNode(val id: String, val label: String, highlight: Boolean = false) ext def equals(other: DotNode): Boolean = toDotString.equals(other.toDotString) - def hl = if (highlight) then "style=filled, fillcolor=\"orangered\", " else "" + def hl = if (highlight) then "style=filled, fillcolor=\"aliceblue\", " else "" def toDotString: String = s"\"$id\"" + s"[${hl}label=\"" + escape(wrap(label, 100)) + "\", shape=\"box\", fontname=\"Mono\", fontsize=\"5\"]" diff --git a/src/main/scala/ir/Expr.scala b/src/main/scala/ir/Expr.scala index cd8278bdc8..0ed30e1fe1 100644 --- a/src/main/scala/ir/Expr.scala +++ b/src/main/scala/ir/Expr.scala @@ -187,6 +187,22 @@ sealed trait BVUnOp(op: String) extends UnOp { case object BVNOT extends BVUnOp("not") case object BVNEG extends BVUnOp("neg") +def boolAnd(exps: Iterable[Expr]) = + val l = exps.toList + l.size match { + case 0 => TrueLiteral + case 1 => l.head + case _ => AssocExpr(BoolAND, l) + } + +def boolOr(exps: Iterable[Expr]) = + val l = exps.toList + l.size match { + case 0 => FalseLiteral + case 1 => l.head + case _ => AssocExpr(BoolOR, l) + } + case class AssocExpr(op: BoolBinOp, args: List[Expr]) extends Expr with CachedHashCode { require(args.size >= 2, "AssocExpr requires at least two operands") override def getType: IRType = BoolType @@ -381,6 +397,7 @@ sealed trait Variable extends Expr { object Variable { implicit def ordering[V <: Variable]: Ordering[V] = Ordering.by(_.name) + implicit def catsOrdering[V <: Variable]: cats.kernel.Order[V] = cats.kernel.Order.by(_.name) } object Register { @@ -409,7 +426,7 @@ case class LocalVar(varName: String, override val irType: IRType, val index: Int } object LocalVar { - def 
unapply(l: LocalVar): Some[(String, IRType, Int)] = Some((l.name, l.irType, l.index)) + def unapply(l: LocalVar): Some[(String, IRType, Int)] = Some((l.varName, l.irType, l.index)) /** * Construct a LocalVar by infering its index from the provided name corresponding to [[LocalVar.name]]. diff --git a/src/main/scala/ir/Program.scala b/src/main/scala/ir/Program.scala index c9649e80a9..070d27f7a2 100644 --- a/src/main/scala/ir/Program.scala +++ b/src/main/scala/ir/Program.scala @@ -289,7 +289,8 @@ class Procedure private ( var requires: List[BExpr], var ensures: List[BExpr], var requiresExpr: List[Expr], - var ensuresExpr: List[Expr] + var ensuresExpr: List[Expr], + var loopInvariants: Map[Block, specification.LoopInvariant] ) extends Iterable[CFGPosition] with DeepEquality { @@ -327,7 +328,8 @@ class Procedure private ( List.from(requires), List.from(ensures), List(), - List() + List(), + Map() ) } @@ -625,6 +627,10 @@ class Block private ( this(label, IntrusiveList().addAll(statements), jump, mutable.LinkedHashSet.empty, Metadata(None, address)) } + def forwardIteratorFrom = { + ILForwardIterator(Seq(this), IntraProcIRCursor) + } + def address = meta.address override def deepEquals(b: Object): Boolean = b match { diff --git a/src/main/scala/ir/Statement.scala b/src/main/scala/ir/Statement.scala index 603986564c..65958e1d49 100644 --- a/src/main/scala/ir/Statement.scala +++ b/src/main/scala/ir/Statement.scala @@ -198,12 +198,16 @@ object MemoryLoad { * class's field in all methods of the subclass. 
*/ class NOP(var label: Option[String] = None) extends Statement with Command { + def cloneStatement(): NOP = { + NOP(label) + } override def toString: String = s"NOP $labelStr" override def deepEquals(o: Object) = o match { case NOP(x) => x == label case _ => false } } + object NOP { def unapply(x: NOP) = Some(x.label) } diff --git a/src/main/scala/ir/cilvisitor/CILVisitor.scala b/src/main/scala/ir/cilvisitor/CILVisitor.scala index 371bc3dda4..41b3b79fb3 100644 --- a/src/main/scala/ir/cilvisitor/CILVisitor.scala +++ b/src/main/scala/ir/cilvisitor/CILVisitor.scala @@ -202,8 +202,12 @@ class CILVisitorImpl(val v: CILVisitor) { r match { case Nil => b.statements.remove(s) case n :: tl => - b.statements.replace(s, n) - b.statements.insertAllAfter(Some(n), tl) + if (n ne s) { + b.statements.replace(s, n) + b.statements.insertAllAfter(Some(n), tl) + } else { + b.statements.insertAllAfter(Some(n), tl) + } } } b.replaceJump(visit_jump(b.jump)) @@ -216,7 +220,7 @@ class CILVisitorImpl(val v: CILVisitor) { def visit_proc(p: Procedure): List[Procedure] = { def continue(p: Procedure) = { v.enter_scope(p.formalInParam) - for (b <- p.blocks) { + for (b <- p.blocks.toList) { p.replaceBlock(b, visit_block(b)) } v.leave_scope() @@ -258,4 +262,5 @@ def visit_stmt(v: CILVisitor, e: Statement): List[Statement] = CILVisitorImpl(v) def visit_jump(v: CILVisitor, e: Jump): Jump = CILVisitorImpl(v).visit_jump(e) def visit_expr(v: CILVisitor, e: Expr): Expr = CILVisitorImpl(v).visit_expr(e) def visit_rvar(v: CILVisitor, e: Variable): Variable = CILVisitorImpl(v).visit_rvar(e) +def visit_lvar(v: CILVisitor, e: Variable): Variable = CILVisitorImpl(v).visit_lvar(e) def visit_mem(v: CILVisitor, e: Memory): Memory = CILVisitorImpl(v).visit_mem(e) diff --git a/src/main/scala/ir/dsl/DSL.scala b/src/main/scala/ir/dsl/DSL.scala index cf23326f45..e4507a004b 100644 --- a/src/main/scala/ir/dsl/DSL.scala +++ b/src/main/scala/ir/dsl/DSL.scala @@ -72,7 +72,7 @@ def cloneStatement(x: NonCallStatement): 
NonCallStatement = case MemoryAssign(a, b, c) => MemoryAssign(a, b, c) case MemoryStore(a, b, c, d, e, f) => MemoryStore(a, b, c, d, e, f) case MemoryLoad(a, b, c, d, e, f) => MemoryLoad(a, b, c, d, e, f) - case NOP(l) => NOP(l) + case n: NOP => n.cloneStatement() case Assert(a, b, c) => Assert(a, b, c) case Assume(a, b, c, d) => Assume(a, b, c, d) case a: SimulAssign => SimulAssign(a.assignments, a.label) diff --git a/src/main/scala/ir/dsl/IRToDSL.scala b/src/main/scala/ir/dsl/IRToDSL.scala index c0e6da78f6..a3d0234d72 100644 --- a/src/main/scala/ir/dsl/IRToDSL.scala +++ b/src/main/scala/ir/dsl/IRToDSL.scala @@ -54,6 +54,10 @@ object IRToDSL { x.address ) + def cloneSingleProcedure(x: Procedure) = + val proc = convertProcedure(x) + EventuallyProgram(proc, Seq()).resolve + def convertProgram(x: Program) = val others = x.procedures.filter(_ != x.mainProcedure).map(convertProcedure) EventuallyProgram(convertProcedure(x.mainProcedure), others.to(ArraySeq), x.initialMemory.values) diff --git a/src/main/scala/ir/eval/SimplifyExpr.scala b/src/main/scala/ir/eval/SimplifyExpr.scala index 5f9e09d41f..f8bbb98154 100644 --- a/src/main/scala/ir/eval/SimplifyExpr.scala +++ b/src/main/scala/ir/eval/SimplifyExpr.scala @@ -43,12 +43,12 @@ class SimpExpr(simplifier: Simplifier) extends CILVisitor with Simplifier { var changedAnything = false var count = 0 - def simplify(e: Expr) = { + def simplify(e: Expr)(implicit line: sourcecode.Line, file: sourcecode.FileName, name: sourcecode.Name) = { val (n, changed) = simplifier(e) changedAnything = changedAnything || changed - if (changed) logSimp(e, n) + if (changed) logSimp(e, n)(line, file, name) (n, changed) } @@ -704,17 +704,21 @@ def simplifyCmpInequalities(e: Expr): (Expr, Boolean) = { logSimp(e, UnaryExpr(BoolNOT, BinaryExpr(BVUGT, x1, UnaryExpr(BVNOT, BinaryExpr(BVADD, y1, z1))))) } - case BinaryExpr( - EQ, - extended @ ZeroExtend(exts, orig @ BinaryExpr(o1, x1, z1)), - BinaryExpr(o2, compar @ ZeroExtend(ext2, BinaryExpr(o4, x2, 
y2)), z2) - ) - if exts == ext2 && size(x1).get >= 8 && (o1 == o2) && o2 == o4 && o1 == BVADD - && simplifyCond(BinaryExpr(o1, ZeroExtend(exts, x1), ZeroExtend(exts, z1))) - == simplifyCond(BinaryExpr(BVADD, ZeroExtend(exts, x2), (BinaryExpr(BVADD, ZeroExtend(exts, y2), z2)))) => { - // C not Set - logSimp(e, UnaryExpr(BoolNOT, BinaryExpr(BVUGT, x1, UnaryExpr(BVNOT, z1)))) - } + // broken + // (declare-const Var2 (_ BitVec 64)) + // (assert (! (not (= (or (not (= (bvnot (bool2bv1 (= (concat (_ bv0 64) (bvadd Var2 (_ bv18446744073705224440 64))) (bvadd (concat (_ bv0 64) (bvadd Var2 (_ bv8 64))) (_ bv18446744073705224432 128))))) (_ bv1 1))) (= (bvadd Var2 (_ bv18446744073705224440 64)) (_ bv0 64))) (bvule Var2 (_ bv4327176 64)))) :named simp.209SimplifyExpr.scala..62)) + // (check-sat) + // case BinaryExpr( + // EQ, + // extended @ ZeroExtend(exts, orig @ BinaryExpr(o1, x1, z1)), + // BinaryExpr(o2, compar @ ZeroExtend(ext2, BinaryExpr(o4, x2, y2)), z2) + // ) + // if exts == ext2 && size(x1).get >= 8 && (o1 == o2) && o2 == o4 && o1 == BVADD + // && simplifyCond(BinaryExpr(o1, ZeroExtend(exts, x1), ZeroExtend(exts, z1))) + // == simplifyCond(BinaryExpr(BVADD, ZeroExtend(exts, x2), (BinaryExpr(BVADD, ZeroExtend(exts, y2), z2)))) => { + // // C not Set + // logSimp(e, UnaryExpr(BoolNOT, BinaryExpr(BVUGT, x1, UnaryExpr(BVNOT, z1)))) + // } case BinaryExpr( EQ, @@ -732,30 +736,34 @@ def simplifyCmpInequalities(e: Expr): (Expr, Boolean) = { logSimp(e, UnaryExpr(BoolNOT, BinaryExpr(BVUGE, x1, y1))) } - case BinaryExpr( - EQ, - ZeroExtend(exts, orig @ BinaryExpr(BVADD, x1, y1: BitVecLiteral)), - BinaryExpr(BVADD, ZeroExtend(ext1, BinaryExpr(BVADD, x2, y3neg: BitVecLiteral)), y4neg: BitVecLiteral) - ) - if size(x1).get >= 8 - && exts == ext1 - && simplifyCond(UnaryExpr(BVNEG, y1)) - == simplifyCond( - BinaryExpr(BVADD, UnaryExpr(BVNEG, y3neg), UnaryExpr(BVNEG, Extract(size(y4neg).get - exts, 0, y4neg))) - ) - && simplifyCond(ZeroExtend(exts, Extract(size(y4neg).get - exts, 
0, y4neg))) == y4neg - && { - val l = simplifyCond(BinaryExpr(BVSUB, UnaryExpr(BVNEG, y1), UnaryExpr(BVNEG, (y3neg)))) - val r = simplifyCond(UnaryExpr(BVNEG, Extract(size(y4neg).get - exts, 0, y4neg))) - l == r - } - && x1 == x2 => { - // somehow we get three-way inequality - logSimp( - e, - BinaryExpr(BoolAND, BinaryExpr(BVULT, x1, UnaryExpr(BVNEG, y1)), BinaryExpr(BVUGE, x1, UnaryExpr(BVNEG, y3neg))) - ) - } + // fails verification + // (assert (! (not (= (= (concat (_ bv0 64) (bvadd Var2 (_ bv18446744073705224440 64))) (bvadd (concat (_ bv0 64) (bvadd Var2 (_ bv8 64))) (_ bv18446744073705224432 128))) (and (bvult Var2 (bvneg (_ bv18446744073705224440 64))) (bvuge Var2 (bvneg (_ bv8 64)))))) :named simp.197SimplifyExpr.scala..762)) + // (check-sat) + // (echo "simp.197SimplifyExpr.scala..762 :: boolnot(eq(eq(zero_extend(64, bvadd(Var2:bv64, 0xffffffffffbdf8f8:bv64)), bvadd(zero_extend(64, bvadd(Var2:bv64, 0x8:bv64)), 0xffffffffffbdf8f0:bv128)), booland(bvult(Var2:bv64, bvneg(0xffffffffffbdf8f8:bv64)), bvuge(Var2:bv64, bvneg(0x8:bv64)))))") + // case BinaryExpr( + // EQ, + // ZeroExtend(exts, orig @ BinaryExpr(BVADD, x1, y1: BitVecLiteral)), + // BinaryExpr(BVADD, ZeroExtend(ext1, BinaryExpr(BVADD, x2, y3neg: BitVecLiteral)), y4neg: BitVecLiteral) + // ) + // if size(x1).get >= 8 + // && exts == ext1 + // && simplifyCond(UnaryExpr(BVNEG, y1)) + // == simplifyCond( + // BinaryExpr(BVADD, UnaryExpr(BVNEG, y3neg), UnaryExpr(BVNEG, Extract(size(y4neg).get - exts, 0, y4neg))) + // ) + // && simplifyCond(ZeroExtend(exts, Extract(size(y4neg).get - exts, 0, y4neg))) == y4neg + // && { + // val l = simplifyCond(BinaryExpr(BVSUB, UnaryExpr(BVNEG, y1), UnaryExpr(BVNEG, (y3neg)))) + // val r = simplifyCond(UnaryExpr(BVNEG, Extract(size(y4neg).get - exts, 0, y4neg))) + // l == r + // } + // && x1 == x2 => { + // // somehow we get three-way inequality + // logSimp( + // e, + // BinaryExpr(BoolAND, BinaryExpr(BVULT, x1, UnaryExpr(BVNEG, y1)), BinaryExpr(BVUGE, x1, 
UnaryExpr(BVNEG, y3neg))) + // ) + // } /* generic comparison simplification */ // redundant inequality diff --git a/src/main/scala/ir/invariant/ReadUninitialised.scala b/src/main/scala/ir/invariant/ReadUninitialised.scala new file mode 100644 index 0000000000..c54058a6ae --- /dev/null +++ b/src/main/scala/ir/invariant/ReadUninitialised.scala @@ -0,0 +1,74 @@ +package ir.invariant + +import ir.* +import util.Logger + +import scala.collection.mutable + +class ReadUninitialised() { + + var init = Set[Variable]() + var readUninit = List[(Command, Set[Variable])]() + + final def check(a: Command) = { + val free = freeVarsPos(a).filter(_.isInstanceOf[LocalVar]) -- init + if (free.size > 0) { + readUninit = (a, free) :: readUninit + } + } + + final def readUninitialised(b: Iterable[Statement]): Boolean = { + val i = readUninit.size + b.foreach { + case a: Assign => { + check(a) + init = init ++ a.assignees + } + case o => { + check(o) + } + } + i != readUninit.size + } + + final def readUninitialised(b: Block): Boolean = { + val i = readUninit.size + readUninitialised(b.statements) + check(b.jump) + i != readUninit.size + } + + final def getResult(): Option[String] = { + if (readUninit.size > 0) { + val msg = readUninit + .map { case (s, vars) => + s" ${vars.mkString(", ")} uninitialised in statement $s in block ${s.parent}" + } + .mkString("\n") + Some(msg) + } else { + None + } + } + + final def readUninitialised(p: Procedure): Boolean = { + init = init ++ p.formalInParam + + ir.transforms.reversePostOrder(p) + + val worklist = mutable.PriorityQueue[Block]()(Ordering.by(_.rpoOrder)) + worklist.addAll(p.blocks) + + while (worklist.nonEmpty) { + val b = worklist.dequeue() + readUninitialised(b) + } + getResult().map(e => Logger.error(p.name + "\n" + e)).isDefined + } + +} + +def readUninitialised(p: Program): Boolean = { + val r = p.procedures.map(p => ReadUninitialised().readUninitialised(p)) + r.forall(x => !x) +} diff --git a/src/main/scala/ir/transforms/AbsInt.scala 
b/src/main/scala/ir/transforms/AbsInt.scala index 203f53a030..c73620a88a 100644 --- a/src/main/scala/ir/transforms/AbsInt.scala +++ b/src/main/scala/ir/transforms/AbsInt.scala @@ -317,7 +317,7 @@ class interprocSummaryFixpointSolver[SummaryAbsVal, LocalAbsVal, A <: AbstractDo */ class BottomUpCallgraphWorklistSolver[L](transferProcedure: (Procedure => L, L, Procedure) => L, init: Procedure => L) { - def solve(p: Program) = { + def solve(p: Program): Map[Procedure, L] = { var old_summaries = Map[Procedure, L]() var summaries = Map[Procedure, L]() diff --git a/src/main/scala/ir/transforms/CalleePreservedParam.scala b/src/main/scala/ir/transforms/CalleePreservedParam.scala index c1b4e7ba2a..3fcf5ca6d9 100644 --- a/src/main/scala/ir/transforms/CalleePreservedParam.scala +++ b/src/main/scala/ir/transforms/CalleePreservedParam.scala @@ -6,6 +6,8 @@ import cilvisitor.* object CalleePreservedParam { + var counter = util.Counter() + /** * Asusming single-return and parameter form */ @@ -24,14 +26,18 @@ object CalleePreservedParam { } } - def transform(p: Program) = { + // returns labels of injected assertions + def transform(p: Program): List[String] = { val v = MakePreserved() visit_prog(v, p) debugAssert(invariant.correctCalls(p)) + v.addedAsserts } class MakePreserved extends CILVisitor { + var addedAsserts = List[String]() + override def vproc(p: Procedure) = { for (param <- p.formalOutParam.filter(isPreservedParam)) { @@ -57,9 +63,11 @@ object CalleePreservedParam { } case o => ??? 
} - b.statements.append( - Assert(BinaryExpr(EQ, input, out), Some(s"${param.name.stripSuffix("_out")} preserved across calls")) - ) + var l = Some(s"callerpreserved${counter.next()}") + val assert = + Assert(BinaryExpr(EQ, input, out), Some(s"${param.name.stripSuffix("_out")} preserved across calls"), l) + b.statements.append(assert) + addedAsserts = l.get :: addedAsserts }) } SkipChildren() diff --git a/src/main/scala/ir/transforms/CountGuardStatements.scala b/src/main/scala/ir/transforms/CountGuardStatements.scala new file mode 100644 index 0000000000..5ddfcd5d2d --- /dev/null +++ b/src/main/scala/ir/transforms/CountGuardStatements.scala @@ -0,0 +1,81 @@ +package ir.transforms + +import ir.* +import ir.cilvisitor.* +import util.tvEvalLogger + +import scala.util.chaining.scalaUtilChainingOps + +object CountGuardStatements { + + enum GuardComplexity { + case HasFlagRegisters(regs: Set[GlobalVar]) + case OpsTooMany(ops: List[BinOp | UnOp]) + case OpsTooComplex(ops: List[BoolToBV1.type]) + case VarsTooMany(vars: Set[Variable]) + } + + def listOpsInExpr(e: Expr): List[BinOp | UnOp] = e match { + case AssocExpr(op, es) => op +: es.flatMap(listOpsInExpr) + case BinaryExpr(op, x, y) => op +: listOpsInExpr(x) ++: listOpsInExpr(y) + case UnaryExpr(op, x) => op +: listOpsInExpr(x) + case _ => List() + } + + def classifyGuard(s: Assume): List[GuardComplexity] = { + val vars = s.body.variables + + val flags = vars.collect { case reg @ Register(_, 1) => + reg + } + + val ops = listOpsInExpr(s.body) + + val complexOps = ops.collect { case x @ BoolToBV1 => + BoolToBV1 + } + + (Option.when(flags.nonEmpty) { GuardComplexity.HasFlagRegisters(flags) } + + // allowable number of operations is `2 * vars + 1` to allow for + // one possibly-negated binary operation per variable. plus one top-level + // operation like ==. 
+ ++ Option.when(ops.size > 2 * vars.size + 1) { GuardComplexity.OpsTooMany(ops) } + + ++ Option.when(complexOps.nonEmpty) { GuardComplexity.OpsTooComplex(complexOps) } + + ++ Option.when(vars.size > 2) { GuardComplexity.VarsTooMany(vars) }).toList + } +} + +class CountGuardStatements extends CILVisitor { + import CountGuardStatements.* + + var guards: List[(Assume, List[GuardComplexity])] = Nil + + override def vstmt(s: Statement) = { + s match { + case ass: Assume if ass.checkSecurity => + guards = (ass -> classifyGuard(ass)) +: guards + case _ => () + } + SkipChildren() + } + + def reportToLog(label: String, log: String => Unit = tvEvalLogger.debug(_)) = { + val entries = guards.flatMap { (k, vs) => vs.map(k -> _) } + val groupedByViolation = entries.groupMapReduce(_._2.productPrefix)(_._1.pipe(List(_)))(_ ++ _) + + log(s"tv-eval-marker: $label-guard-total-count=" + guards.size) + log(s"tv-eval-marker: $label-guard-simple-count=" + guards.count(_._2.isEmpty)) + log(s"tv-eval-marker: $label-guard-complex-unique-count=" + guards.count(_._2.nonEmpty)) + + groupedByViolation.foreach { (violation, guards) => + log(s"tv-eval-marker: $label-guard-complex-$violation-count=" + guards.size) + } + + if (label == "after") { + println(guards.filter(_._2.nonEmpty)) + } + } +} diff --git a/src/main/scala/ir/transforms/CountStatements.scala b/src/main/scala/ir/transforms/CountStatements.scala new file mode 100644 index 0000000000..981be50437 --- /dev/null +++ b/src/main/scala/ir/transforms/CountStatements.scala @@ -0,0 +1,13 @@ +package ir.transforms + +import ir.* +import ir.cilvisitor.* + +class CountStatements extends CILVisitor { + var count = 0 + + override def vstmt(s: Statement) = { + count += 1 + SkipChildren() + } +} diff --git a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala index a205b97c39..949b29a692 100644 --- a/src/main/scala/ir/transforms/DynamicSingleAssignment.scala +++ 
b/src/main/scala/ir/transforms/DynamicSingleAssignment.scala @@ -107,10 +107,12 @@ class OnePassDSA( } - def applyTransform(p: Program): Unit = { - for (proc <- p.procedures) { - applyTransform(proc) - } + def applyTransform(p: Program): Map[String, result] = { + p.procedures + .map(proc => { + proc.name -> applyTransform(proc) + }) + .toMap } def createBlockBetween(b1: Block, b2: Block, label: String = "phi"): Block = { @@ -280,7 +282,22 @@ class OnePassDSA( fixSuccessors(_st, count, liveBefore, liveAfter, block) } - def applyTransform(p: Procedure): Unit = { + type result = Map[String, (Map[Variable, Variable], Map[Variable, Variable])] + + def stToResult(_st: mutable.Map[Block, BlockState]): result = { + _st.toMap.map((bl: Block, bs: BlockState) => { + bl.label -> ( + bs.renamesBefore.toMap.map { case (v, idx) => + (v, visit_rvar(StmtRenamer(Map(), Map(v -> idx)), v)) + }, + bs.renamesAfter.toMap.map { case (v, idx) => + (v, visit_rvar(StmtRenamer(Map(), Map(v -> idx)), v)) + } + ) + }) + } + + def applyTransform(p: Procedure): result = { val _st = mutable.Map[Block, BlockState]() // ensure order is defined ir.transforms.reversePostOrder(p) @@ -299,7 +316,7 @@ class OnePassDSA( val worklist = mutable.PriorityQueue[Block]()(Ordering.by(b => b.rpoOrder)) worklist.addAll(p.blocks) var seen = Set[Block]() - val count = mutable.Map[Variable, Int]().withDefaultValue(0) + val count = mutable.Map[Variable, Int]().withDefaultValue(p.ssaCount) while (worklist.nonEmpty) { while (worklist.nonEmpty) { @@ -335,6 +352,7 @@ class OnePassDSA( reversePostOrder(p) + stToResult(_st) } } diff --git a/src/main/scala/ir/transforms/Inline.scala b/src/main/scala/ir/transforms/Inline.scala index 3cb11c3271..9421ec601e 100644 --- a/src/main/scala/ir/transforms/Inline.scala +++ b/src/main/scala/ir/transforms/Inline.scala @@ -28,11 +28,7 @@ def renameBlock(s: String): String = { class VarRenamer(proc: Procedure) extends CILVisitor { def doRename(v: Variable): Variable = v match { - case l: 
LocalVar if l.name.endsWith("_in") => { - val name = l.name.stripSuffix("_in") - proc.getFreshSSAVar(name, l.getType) - } - case l: LocalVar if l.index != 0 => + case l: LocalVar => proc.getFreshSSAVar(l.varName, l.getType) case _ => v } diff --git a/src/main/scala/ir/transforms/PCTracking.scala b/src/main/scala/ir/transforms/PCTracking.scala index bf05778b59..3cc800f115 100644 --- a/src/main/scala/ir/transforms/PCTracking.scala +++ b/src/main/scala/ir/transforms/PCTracking.scala @@ -40,6 +40,10 @@ object PCTracking { val pcRequires = BinaryExpr(ir.EQ, pcVar, addrVar) val pcEnsures = BinaryExpr(ir.EQ, pcVar, OldExpr(r30Var)) + proc.entryBlock.foreach(b => { + b.statements.prepend(LocalAssign(pcVar, addrVar)) + b.statements.prepend(Assert(BinaryExpr(EQ, pcVar, addrVar))) + }) proc.requiresExpr = pcRequires +: proc.requiresExpr proc.ensuresExpr = pcEnsures +: proc.ensuresExpr diff --git a/src/main/scala/ir/transforms/ProcedureParameters.scala b/src/main/scala/ir/transforms/ProcedureParameters.scala index a4c7453f14..a73be331eb 100644 --- a/src/main/scala/ir/transforms/ProcedureParameters.scala +++ b/src/main/scala/ir/transforms/ProcedureParameters.scala @@ -4,7 +4,7 @@ import ir.cilvisitor.* import ir.{CallGraph, *} import specification.Specification import translating.PrettyPrinter -import util.{DebugDumpIRLogger, Logger} +import util.DebugDumpIRLogger import java.io.File import scala.collection.{immutable, mutable} @@ -112,20 +112,16 @@ object DefinedOnAllPaths { } def liftProcedureCallAbstraction(ctx: ir.IRContext): ir.IRContext = { + val ns = liftProcedureCallAbstraction(ctx.program, Some(ctx.specification)).get + ctx.copy(specification = ns) +} - transforms.clearParams(ctx.program) +def liftProcedureCallAbstraction(program: Program, spec: Option[Specification]): Option[Specification] = { - val mainNonEmpty = ctx.program.mainProcedure.blocks.nonEmpty - val mainHasReturn = ctx.program.mainProcedure.returnBlock.isDefined - val mainHasEntry = 
ctx.program.mainProcedure.entryBlock.isDefined + transforms.clearParams(program) - val liveVars = if (mainNonEmpty && mainHasEntry && mainHasReturn) { - analysis.InterLiveVarsAnalysis(ctx.program).analyze() - } else { - Logger.error(s"Empty live vars $mainNonEmpty $mainHasReturn $mainHasEntry") - Map.empty - } - transforms.applyRPO(ctx.program) + val liveVars = analysis.interLiveVarsAnalysis(program) + transforms.applyRPO(program) val liveLab = () => liveVars.collect { case (b: Block, r) => @@ -144,27 +140,28 @@ def liftProcedureCallAbstraction(ctx: ir.IRContext): ir.IRContext = { DebugDumpIRLogger.writeToFile( File(s"live-vars.il"), - PrettyPrinter.pp_prog_with_analysis_results(liveLab(), Map(), ctx.program, x => x) + PrettyPrinter.pp_prog_with_analysis_results(liveLab(), Map(), program, x => x) ) - val params = inOutParams(ctx.program, liveVars) + val params = inOutParams(program, liveVars) // functions for which we don't know their behaviour and assume they modify all registers - val external = ctx.externalFunctions.map(_.name) ++ ctx.program.collect { + val external = program.collect { case b: Procedure if b.blocks.isEmpty => b.name - } + }.toSet val formalParams = SetFormalParams(params, external) - visit_prog(formalParams, ctx.program) + visit_prog(formalParams, program) val actualParams = SetActualParams(formalParams.mappingInparam, formalParams.mappingOutparam, external) - visit_prog(actualParams, ctx.program) + visit_prog(actualParams, program) - while (removeDeadInParams(ctx.program)) {} + while (removeDeadInParams(program)) {} - ctx.program.procedures.foreach(SpecFixer.updateInlineSpec(formalParams.mappingInparam, formalParams.mappingOutparam)) - ctx.copy(specification = - SpecFixer.specToProcForm(ctx.specification, formalParams.mappingInparam, formalParams.mappingOutparam) - ) + program.procedures.foreach(SpecFixer.updateInlineSpec(formalParams.mappingInparam, formalParams.mappingOutparam)) + + spec.map(s => { + SpecFixer.specToProcForm(s, 
formalParams.mappingInparam, formalParams.mappingOutparam) + }) } def clearParams(p: Program) = { diff --git a/src/main/scala/ir/transforms/ReplaceReturn.scala b/src/main/scala/ir/transforms/ReplaceReturn.scala index 5b1aafd09e..77f060cdb8 100644 --- a/src/main/scala/ir/transforms/ReplaceReturn.scala +++ b/src/main/scala/ir/transforms/ReplaceReturn.scala @@ -13,6 +13,19 @@ class ReplaceReturns(insertR30InvariantAssertion: Procedure => Boolean = _ => tr private val R30procedures: mutable.Set[Procedure] = mutable.Set() + var tailCalls = Map[Procedure, Block]() + + def tailCallBlock(p: Procedure) = { + tailCalls.get(p) match { + case Some(b) => b + case None => { + val b = p.entryBlock.get.createBlockAfter("tailcallheader") + tailCalls = tailCalls + (p -> b) + b + } + } + } + /** Assumes IR with 1 call per block which appears as the last statement. */ override def vstmt(j: Statement): VisitAction[List[Statement]] = { @@ -55,8 +68,8 @@ class ReplaceReturns(insertR30InvariantAssertion: Procedure => Boolean = _ => tr // If we can't find one case _: Unreachable => if (d.target == procedure) { - // recursive tail call - d.parent.replaceJump(GoTo(procedure.entryBlock.get)) + // recursive tail call (entryBlock not allowed to have predecessors) + d.parent.replaceJump(GoTo(tailCallBlock(procedure))) } else { // non-recursive tail call d.parent.replaceJump(Return()) diff --git a/src/main/scala/ir/transforms/Simp.scala b/src/main/scala/ir/transforms/Simp.scala index df0cfbee2e..acc8bb2167 100644 --- a/src/main/scala/ir/transforms/Simp.scala +++ b/src/main/scala/ir/transforms/Simp.scala @@ -140,7 +140,9 @@ class IntraLiveVarsDomain extends PowerSetDomain[Variable] { case a: Assume => s ++ a.body.variables case a: Assert => s ++ a.body.variables case i: IndirectCall => s + i.target - case c: DirectCall => (s -- c.outParams.map(_._2)) ++ c.actualParams.flatMap(_._2.variables) + case c: DirectCall => { + s -- c.outParams.map(_._2) ++ c.actualParams.flatMap(_._2.variables) + } case 
g: GoTo => s case r: Return => s ++ r.outParams.flatMap(_._2.variables) case r: Unreachable => s @@ -667,6 +669,8 @@ class GuardVisitor(validate: Boolean = false) extends CILVisitor { DoChildren() } + var replaced = Map[Variable, Expr]() + def substitute(pos: Command)(v: Variable): Option[Expr] = { if (goodSubst(v)) { val res = defs.get(v).getOrElse(Set()) @@ -689,6 +693,7 @@ class GuardVisitor(validate: Boolean = false) extends CILVisitor { if (validate) { debugAssert(propOK(rhs)) } + replaced = replaced + (v -> rhs) Some(rhs) } case o => { @@ -723,12 +728,7 @@ def simplifyCFG(p: Procedure) = { removeEmptyBlocks(p) } -def copypropTransform( - p: Procedure, - procFrames: Map[Procedure, Set[Memory]], - funcEntries: Map[BigInt, Procedure], - constRead: (BigInt, Int) => Option[BitVecLiteral] -) = { +def copypropTransform(p: Procedure, procFrames: Map[Procedure, Set[Memory]]) = { val t = util.PerformanceTimer(s"simplify ${p.name} (${p.blocks.size} blocks)") // SimplifyLogger.info(s"${p.name} ExprComplexity ${ExprComplexity()(p)}") // val result = solver.solveProc(p, true).withDefaultValue(dom.bot) @@ -852,7 +852,11 @@ def coalesceBlocks(proc: Procedure): Boolean = { val stmts = b.statements.map(b.statements.remove).toList nextBlock.statements.prependAll(stmts) // leave empty block b and cleanup with removeEmptyBlocks - } else if (b.jump.isInstanceOf[Unreachable] && b.statements.isEmpty && b.prevBlocks.size == 1) { + } else if ( + b.jump.isInstanceOf[ + Unreachable + ] && b.statements.isEmpty && b.prevBlocks.size == 1 && b.prevBlocks.head.nextBlocks.size == 1 + ) { b.prevBlocks.head.replaceJump(Unreachable()) b.parent.removeBlocks(b) } @@ -1054,7 +1058,7 @@ def cleanupBlocks(p: Program) = { } } -def doCopyPropTransform(p: Program, rela: Map[BigInt, BigInt]) = { +def doCopyPropTransform(p: Program) = { applyRPO(p) @@ -1068,8 +1072,6 @@ def doCopyPropTransform(p: Program, rela: Map[BigInt, BigInt]) = { val procFrames = getProcFrame.solveInterproc(p) - val addrToProc = 
p.procedures.toSeq.flatMap(p => p.address.map(addr => addr -> p).toSeq).toMap - def read(addr: BigInt, size: Int): Option[BitVecLiteral] = { val rodata = p.initialMemory.filter((_, s) => s.readOnly) rodata.maxBefore(addr + 1) match { @@ -1097,7 +1099,7 @@ def doCopyPropTransform(p: Program, rela: Map[BigInt, BigInt]) = { { SimplifyLogger .debug(s"CopyProp Transform ${p.name} (${p.blocks.size} blocks, expr complexity ${ExprComplexity()(p)})") - copypropTransform(p, procFrames, addrToProc, read) + copypropTransform(p, procFrames) } ) @@ -1125,9 +1127,9 @@ def doCopyPropTransform(p: Program, rela: Map[BigInt, BigInt]) = { } -def copyPropParamFixedPoint(p: Program, rela: Map[BigInt, BigInt]): Int = { +def copyPropParamFixedPoint(p: Program): Int = { SimplifyLogger.info(s"Simplify:: Copyprop iteration 0") - doCopyPropTransform(p, rela) + doCopyPropTransform(p) var inlinedOutParams: Map[Procedure, Set[Variable]] = removeInvariantOutParameters(p) var changed = inlinedOutParams.nonEmpty var iterations = 1 @@ -1136,7 +1138,7 @@ def copyPropParamFixedPoint(p: Program, rela: Map[BigInt, BigInt]): Int = { changed = false SimplifyLogger.info(s"Simplify:: Copyprop iteration $iterations") transforms.removeTriviallyDeadBranches(p) - doCopyPropTransform(p, rela) + doCopyPropTransform(p) val extraInlined = removeInvariantOutParameters(p, inlinedOutParams) inlinedOutParams = extraInlined.foldLeft(inlinedOutParams)((acc, v) => acc + (v._1 -> (acc.getOrElse(v._1, Set[Variable]()) ++ v._2)) @@ -1244,48 +1246,52 @@ object OffsetProp { // None, Some(Lit) -> Lit type Value = (Option[Variable], Option[BitVecLiteral]) - def joinValue(l: Value, r: Value) = { - (l, r) match { - case ((None, None), _) => (None, None) - case (_, (None, None)) => (None, None) - case (l, r) if l != r => (None, None) - case (l, r) => l - } - } - class CopyProp() { val st = mutable.Map[Variable, Value]() var giveUp = false val lastUpdate = mutable.Map[Block, Int]() var stSequenceNo = 1 - def findOff(v: Variable, c: 
BitVecLiteral): BitVecLiteral | Variable | BinaryExpr = find(v) match { - case lc: BitVecLiteral => ir.eval.BitVectorEval.smt_bvadd(lc, c) - case lv: Variable => BinaryExpr(BVADD, lv, c) - case BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) => - BinaryExpr(BVADD, l, ir.eval.BitVectorEval.smt_bvadd(r, c)) - case _ => throw Exception("Unexpected expression structure created by find() at some point") - } + def eval(c: BitVecLiteral)(v: BitVecLiteral | Variable | BinaryExpr): BitVecLiteral | Variable | BinaryExpr = + v match { + case lc: BitVecLiteral => ir.eval.BitVectorEval.smt_bvadd(lc, c) + case lv: Variable => BinaryExpr(BVADD, lv, c) + case BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) => + BinaryExpr(BVADD, l, ir.eval.BitVectorEval.smt_bvadd(r, c)) + case _ => throw Exception("Unexpected expression structure created by find() at some point") + } - def find(v: Variable): BitVecLiteral | Variable | BinaryExpr = { - st.get(v) match { - case None => v - case Some((None, None)) => v - case Some((None, Some(c))) => c - case Some((Some(v), None)) => find(v) - case Some((Some(v), Some(c))) => findOff(v, c) + def findOff(v: Variable, c: BitVecLiteral, fuel: Int = 10000): BitVecLiteral | Variable | BinaryExpr = + find(v, fuel) match { + case lc: BitVecLiteral => ir.eval.BitVectorEval.smt_bvadd(lc, c) + case lv: Variable => BinaryExpr(BVADD, lv, c) + case BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) => + BinaryExpr(BVADD, l, ir.eval.BitVectorEval.smt_bvadd(r, c)) + case _ => throw Exception("Unexpected expression structure created by find() at some point") } - } - def joinState(lhs: Variable, rhs: Expr) = { - specJoinState(lhs, rhs) match { - case Some((l, r)) => { - if (st.contains(l) && st(l) != r) { - stSequenceNo += 1 + def find(v: Variable, fuel: Int = 10000): BitVecLiteral | Variable | BinaryExpr = { + if (fuel == 0) { + var chain = List(v) + for (i <- 0 to 10) { + chain = st.get(chain.head) match { + case Some((Some(v: Variable), _)) => v :: chain + case o 
=> + chain } - st(l) = r } - case _ => () + + update(v, (None, None)) + SimplifyLogger.error( + s"Ran out of fuel recursively resolving copyprop (at $v): probable cycle. Next lookups are: $chain" + ) + } + st.get(v) match { + case None => v + case Some((None, None)) => v + case Some((None, Some(c))) => c + case Some((Some(v), None)) => find(v, fuel - 1) + case Some((Some(v), Some(c))) => findOff(v, c, fuel - 1) } } @@ -1302,8 +1308,11 @@ object OffsetProp { } } - def clob(v: Variable) = { - st(v) = (None, None) + def update(v: Variable, r: Value) = { + if (!st.get(v).exists(_ == r)) { + stSequenceNo += 1 + st(v) = r + } } def transfer(s: Statement) = s match { @@ -1316,11 +1325,11 @@ object OffsetProp { case (l: Variable, _) => Seq(l -> (None, None)) } .foreach { case (l, r) => - st(l) = r + update(l, r) } case a: Assign => { // memoryload and DirectCall - a.assignees.foreach(clob) + a.assignees.foreach(v => update(v, (None, None))) } case _: MemoryStore => () case _: NOP => () @@ -1365,7 +1374,7 @@ object OffsetProp { class SubstExprs(subst: Map[Variable, Expr]) extends CILVisitor { override def vexpr(e: Expr) = { - Substitute(subst.get)(e) match { + Substitute(subst.get, false)(e) match { case Some(n) => ChangeTo(n) case _ => SkipChildren() } @@ -1660,12 +1669,7 @@ object CopyProp { c.keys.filter(isMemVar).foreach(v => clobberFull(c, v)) } - def DSACopyProp( - p: Procedure, - procFrames: Map[Procedure, Set[Memory]], - funcEntries: Map[BigInt, Procedure], - constRead: (BigInt, Int) => Option[BitVecLiteral] - ) = { + def DSACopyProp(p: Procedure, procFrames: Map[Procedure, Set[Memory]]) = { val updated = false val state = mutable.HashMap[Variable, PropState]() var poisoned = false // we have an indirect call @@ -1749,27 +1753,6 @@ object CopyProp { // need a reaching-defs to get inout args (just assume register name matches?) 
// this reduce we have to clobber with the indirect call this round poisoned = true - val r = for { - (addr, deps) <- canPropTo(c, x.target) - addr <- addr match { - case b: BitVecLiteral => Some(b.value) - case _ => None - } - proc <- funcEntries.get(addr) - } yield (proc, deps) - - r match { - case Some(target, deps) => { - SimplifyLogger.info("Resolved indirect call") - } - case None => { - for ((i, v) <- c) { - v.clobbered = true - } - poisoned = true - } - } - } case _ => () } diff --git a/src/main/scala/ir/transforms/SimplifyPipeline.scala b/src/main/scala/ir/transforms/SimplifyPipeline.scala index 0a23d884b9..ca8cbe97c0 100644 --- a/src/main/scala/ir/transforms/SimplifyPipeline.scala +++ b/src/main/scala/ir/transforms/SimplifyPipeline.scala @@ -56,6 +56,7 @@ def doSimplify(ctx: IRContext, config: Option[StaticAnalysisConfig]): Unit = { transforms.OnePassDSA().applyTransform(program) + assert(ir.invariant.readUninitialised(ctx.program)) // fixme: this used to be a plain function but now we have to supply an analysis manager! transforms.inlinePLTLaunchpad(ctx, AnalysisManager(ctx.program)) @@ -113,7 +114,8 @@ def doSimplify(ctx: IRContext, config: Option[StaticAnalysisConfig]): Unit = { } } Logger.info("Copyprop Start") - transforms.copyPropParamFixedPoint(program, ctx.globalOffsets) + transforms.copyPropParamFixedPoint(program) + assert(ir.invariant.readUninitialised(ctx.program)) transforms.fixupGuards(program) transforms.removeDuplicateGuard(program) diff --git a/src/main/scala/ir/transforms/StackPreservationSpecification.scala b/src/main/scala/ir/transforms/StackPreservationSpecification.scala new file mode 100644 index 0000000000..05bc610450 --- /dev/null +++ b/src/main/scala/ir/transforms/StackPreservationSpecification.scala @@ -0,0 +1,133 @@ +package ir.transforms +import ir.* + +/* + * Generate specification to that ensures the caller's stack is preserved. 
+ * + * Assume fully simplified IR as we coarsely attempt to infer the maximum size of a stack + * allocation syntactically. + * + * Attempts to interprocedurally compute a bound on the stack allocation for each procedure. + * We impose a tight stack limit to widen on recursive calls, and give up on emitting + * a spec in this case. + * + */ + +enum StackAlloc { + case Max(local: BigInt, call: BigInt) + case Top + + def join(o: StackAlloc) = { + (this, o) match { + case (Top, _) => Top + case (_, Top) => Top + case (Max(local, call), Max(local2, call2)) => Max(local.max(local2), call.max(call2)) + + } + } + + def bot = Max(0, 0) + def top = Top + + def asCall = this match { + case Top => Top + case Max(local, call) => Max(0, local + call) + } +} + +def getMaxAllocation(readProc: Procedure => StackAlloc, init: StackAlloc, p: Procedure) = { + val SP = LocalVar("R31_in", BitVecType(64)) + + // widening + val stackLimit = 10000 + + p.blocks.toSeq + .flatMap(_.statements) + .foldLeft(init)((accIn, b) => + // wideniing + // + val acc = accIn.asCall match { + case StackAlloc.Top => StackAlloc.Top + case StackAlloc.Max(l, _) if l > stackLimit => StackAlloc.Top + case _ => accIn + } + + b match { + case d: DirectCall => { + + val p = d.actualParams.get(SP) match { + case Some(BinaryExpr(BVADD, SP, b: BitVecLiteral)) if ir.eval.BitVectorEval.isNegative(b) => + StackAlloc.Max(ir.eval.BitVectorEval.smt_bvneg(b).value, 0) + case _ => StackAlloc.Max(0, 0) + } + + acc.join(p).join(readProc(d.target).asCall) + } + case MemoryStore(m, BinaryExpr(BVADD, SP, off @ BitVecLiteral(v, sz)), _, _, _, _) + if ir.eval.BitVectorEval.isNegative(off) => { + acc.join(StackAlloc.Max(ir.eval.BitVectorEval.smt_bvneg(off).value, 0)) + } + case MemoryStore(m, BinaryExpr(BVSUB, SP, off @ BitVecLiteral(v, sz)), _, _, _, _) + if !ir.eval.BitVectorEval.isNegative(off) => { + acc.join(StackAlloc.Max(off.value, 0)) + } + case _ => acc + } + ) +} + +def callGraphSolve(p: Program) = { + val solver = 
ir.transforms.BottomUpCallgraphWorklistSolver[StackAlloc]( + transferProcedure = getMaxAllocation, + init = proc => StackAlloc.Max(0, 0) + ) + solver.solve(p) +} + +def genStackAllocationSpec(p: Program) = { + + val stack = StackMemory("stack", 64, 8) + val SP = LocalVar("R31_in", BitVecType(64)) + + val stackAllocs = callGraphSolve(p) + + val staticStackAllocs = stackAllocs.toSeq.collect { case (k, StackAlloc.Max(local, call)) => + (k, local + call) + } + + for ((proc, maxStackNum) <- staticStackAllocs) { + + val maxStack = BitVecLiteral(maxStackNum, 64) + + if (proc.blocks.isEmpty) { + val ensures = BinaryExpr(EQ, stack, OldExpr(stack)) + proc.ensuresExpr = ensures :: proc.ensuresExpr + } else { + if (maxStack.value > 0) { + // no integer overflow on allocation + val requires1 = BinaryExpr(BVSGE, SP, BinaryExpr(BVSUB, SP, maxStack)) + // val requires2 = BinaryExpr(BVSGT, BinaryExpr(BVSUB, SP, maxStack), BitVecLiteral(0, 64)) + // val requires3 = BinaryExpr(BVSGE, SP, BitVecLiteral(0, 64)) + proc.requiresExpr = requires1 :: proc.requiresExpr + } + + val ensures = { + import boogie.* + val i = BVariable("i", BitVecBType(64), Scope.Local) + val SP = BVariable("R31_in", BitVecBType(64), Scope.Local) + ForAll( + List(i), + BinaryBExpr( + BoolIMPLIES, + BinaryBExpr(BVSGT, i, SP), + BinaryBExpr(EQ, MapAccess(stack.toBoogie, i), Old(MapAccess(stack.toBoogie, i))) + ), + List(MapAccess(stack.toBoogie, i)) + ) + } + + proc.ensures = ensures :: proc.ensures + } + + } +} diff --git a/src/main/scala/ir/transforms/validate/Ackermann.scala b/src/main/scala/ir/transforms/validate/Ackermann.scala new file mode 100644 index 0000000000..f44ae1fc17 --- /dev/null +++ b/src/main/scala/ir/transforms/validate/Ackermann.scala @@ -0,0 +1,405 @@ +package ir.transforms.validate + +import ir.* +import util.functional.memoised +import util.tvLogger + +import scala.collection.mutable + +case class Field(name: String) +type EffCallFormalParam = Variable | Memory | Field + +/** +* Maps a input 
or output dependnecy of a call to a variable representing it in the TV +* +* formal -> actual option +* +* Allow providing an actual parameter that is always used +* +*/ +case class CallParamMapping( + // Option[Variable] and Option[Expr] represent in and out parameters that are invariant for all calls, i.e. + // should be in terms of global variables + lhs: List[(EffCallFormalParam, Option[Variable])], // out params + rhs: List[(EffCallFormalParam, Option[Expr])] // in params +) + +sealed trait InvTerm { + def toPred(renameSource: Expr => Expr, renameTarget: Expr => Expr): Expr +} + +case class TargetTerm(e: Expr) extends InvTerm { + def toPred(renameSource: Expr => Expr, renameTarget: Expr => Expr) = + renameTarget(e) +} + +/** + * Encodes the equality of source with target expression for a translation validation + * invariant + * + * Expected to refer to source and target variables prior to renaming being applied. + */ +case class CompatArg(source: Expr, target: Expr) extends InvTerm { + + require(!source.isInstanceOf[Memory] && !target.isInstanceOf[Memory]) + + /* + * Generate source == target expression renamed + */ + def toPred(renameSource: Expr => Expr, renameTarget: Expr => Expr) = + BinaryExpr(EQ, renameSource(source), renameTarget(target)) + + def map(srcFunc: Expr => Expr, tgtFunc: Expr => Expr) = { + CompatArg(srcFunc(source), tgtFunc(target)) + } +} + +case class SideEffectStatement( + stmt: Statement, + name: String, + var lhs: List[(EffCallFormalParam, Variable)], + var rhs: List[(EffCallFormalParam, Expr)] +) extends NOP { + override def toString = s"SideEffectStatement($name, $lhs, $rhs)" + // lhs := name(rhs) + // rhs is mapping formal -> actual + // for globals the formal param is the captured global var or memory + + override def cloneStatement() = { + SideEffectStatement(stmt, name, lhs, rhs) + } +} + +object SideEffectStatementOfStatement { + def traceVar(m: Memory) = { + GlobalVar(s"TRACE_MEM_${m.name}_${m.addressSize}_${m.valueSize}", 
BoolType) + } + + def globalTraceVar = { + GlobalVar(s"TRACE", BoolType) + } + + def param(v: Variable | Memory): (Variable | Memory, Variable) = v match { + case g: GlobalVar => (g -> g) + case g: LocalVar => (g -> g) + case m: Memory => (m -> traceVar(m)) + } +} + +class SideEffectStatementOfStatement(callParams: Map[String, CallParamMapping]) { + import SideEffectStatementOfStatement.* + + def endianHint(e: Endian) = e match { + case Endian.LittleEndian => "le" + case Endian.BigEndian => "be" + } + + def typeHint(t: IRType) = t match { + case IntType => "int" + case BitVecType(sz) => s"bv$sz" + case BoolType => "bool" + case _ => ??? + } + + // source -> target + // variable rewriting applied by parameter analysis + // + // axiom: + // \land ps: formal param in source, pt : param or global in target st (paramInvariant(ps) = pt) + // renamed(src, p) = renamed(tgt, pt) + // + // \forall g: global \in tgt . \exists p or global \in src st. paramInvariant(p) = g + // + // paramInvariant(m: Memory) = m \forall m: Memory + // + // \land ps: formal param in source, pt : param or global in target st, + // renamed(src, p) = renamed(tgt, pt) + // + // + + // need to have already lifted globals to locals + val traceOut = Field("trace") -> globalTraceVar + + /** + * Unified monadic side-effect signature + * + * (name, lhs named params, rhs named params) + */ + def unapply(e: Statement): Option[SideEffectStatement] = e match { + case a @ Assume(e, _, _, true) => { + // Some(SideEffectStatement(a, s"Leak_${typeHint(e.getType)}", List(traceOut), List(traceOut, (Field("arg") -> e)))) + None + } + case call @ DirectCall(tgt, lhs, rhs, _) => + val params = callParams(tgt.name) + + val realLHS = lhs.toMap + val realRHS = rhs.toMap + + val external = tgt.isExternal.contains(true) || tgt.blocks.isEmpty + + val lhsParams = params.lhs.map { + case (formal, Some(actual)) => formal -> actual + case (formal: LocalVar, None) => + formal -> realLHS + .get(formal) + .getOrElse(throw 
Exception(s"Unable to instantiate call: $formal :: $call :: $params")) + case e => throw Exception(s"Unexpected param arrangement $e") + } + val rhsParams = params.rhs.map { + case (formal, Some(actual)) => formal -> actual + case (formal: LocalVar, None) => + formal -> realRHS + .get(formal) + .getOrElse(throw Exception(s"Unable to instantiate call: $formal :: $call :: $params")) + case e => throw Exception(s"Unexpected param arrangement $e") + } + + Some(SideEffectStatement(e, s"Call_${tgt.name}", lhsParams, rhsParams)) + case MemoryLoad(lhs, memory, addr, endian, size, _) => + val args = List(traceOut, Field("addr") -> addr) + val rets = List(Field("out") -> lhs, traceOut) + Some( + SideEffectStatement( + e, + s"Load_${endianHint(endian)}_${typeHint(addr.getType)}_${typeHint(lhs.getType)}", + rets, + args + ) + ) + case MemoryStore(memory, addr, value, endian, size, _) => + val args = List(traceOut, Field("addr") -> addr) + val rets = List(traceOut) + Some( + SideEffectStatement( + e, + s"Store_${endianHint(endian)}_${typeHint(addr.getType)}_${typeHint(value.getType)}", + rets, + args + ) + ) + case MemoryAssign(lhs, rhs, _) => + Some( + SideEffectStatement( + e, + s"MemoryAssign_${typeHint(lhs.getType)}_${typeHint(rhs.getType)}", + List(traceOut, Field("out") -> lhs), + List(traceOut, Field("arg") -> rhs) + ) + ) + case IndirectCall(arg, _) => + // kind of want these gone :( + // doesn't capture interprocedural effects but we don't resolve indirect calls so it doesn't matter + Some( + SideEffectStatement( + e, + s"IndirectCall_${typeHint(arg.getType)}", + List(traceOut, Field("target") -> arg), + List(traceOut) + ) + ) + case _ => None + } + +} + +object Ackermann { + + case class AckInv(name: String, lhs: List[CompatArg], rhs: List[CompatArg]) { + def toPredicate(renameSource: Expr => Expr, renameTarget: Expr => Expr) = { + val args = boolAnd(rhs.map { case CompatArg(s, t) => + BinaryExpr(EQ, renameSource(s), renameTarget(t)) + }) + val implicant = 
boolAnd(lhs.map { case CompatArg(s, t) => + BinaryExpr(EQ, renameSource(s), renameTarget(t)) + }) + BinaryExpr(BoolIMPLIES, args, implicant) + } + } + + enum InstFailureReason { + case NameMismatch(msg: String) + case ParamMismatch(msg: String) + } + + /** + * Check compatibility of two side effects and emit the lists (lhs, rhs) such that + * + * `(\forall (si, ti) \in lhs si == ti) ==> (\forall (so, to) \in rhs . so == to)` + * + */ + def instantiateAxiomInstance( + renaming: TransformDataRelationFun + )(source: SideEffectStatement, target: SideEffectStatement): Either[InstFailureReason, AckInv] = { + // source has higher level, has params, target does not have params + + import InstFailureReason.* + + def applyRename(a: EffCallFormalParam): EffCallFormalParam = a match { + case a: Field => a + case a: (Memory | Variable) => + renaming(source.parent.parent.name, None)(a).toList match { + case (n: EffCallFormalParam) :: Nil => n + case (n: EffCallFormalParam) :: tl => n + case n :: Nil => + tvLogger.warn( + s"Transform description fun rewrite formal parameter $a to $n, which I can't fit back into the formal parameter type Variable | Memory | Field, ignoring" + ) + a + case Nil => a + case _ => ??? 
+ } + } + + for { + name <- (source, target) match { + case (l, r) if l.name == r.name => Right(l.name) + case (l, r) => Left(NameMismatch(s"Name incompat: ${l.name}, ${r.name}")) + } + targetArgs = target.rhs.toMap + args <- source.rhs.foldLeft(Right(List()): Either[InstFailureReason, List[CompatArg]]) { + case (agg, (formal, actual)) => + agg.flatMap(agg => { + targetArgs.get(applyRename(formal)) match { + case Some(a) => Right(CompatArg(actual, a) :: agg) + case None => + Left( + ParamMismatch( + s"Unable to match source var $formal to ${applyRename(formal)} in target list ${targetArgs.keys.toList}" + ) + ) + } + }) + } + targetLHS = target.lhs.toMap + lhs <- source.lhs.foldLeft(Right(List()): Either[InstFailureReason, List[CompatArg]]) { + case (agg, (formal, actual)) => + agg.flatMap(agg => { + targetLHS.get(applyRename(formal)) match { + case Some(a) => Right(CompatArg(actual, a) :: agg) + case None => + Left( + ParamMismatch( + s"Unable to match outparam $formal to ${applyRename(formal)} in target list ${target.lhs}" + ) + ) + } + }) + } + } yield (AckInv(name, lhs, args)) + } + + def instantiateAxioms( + sourceEntry: Block, + targetEntry: Block, + renameSourceExpr: Expr => Expr, + renameTargetExpr: Expr => Expr, + paramMapping: TransformDataRelationFun + ): List[(Expr, String)] = { + val seen = mutable.Set[CFGPosition]() + var invariant = Set[(Expr, String)]() + + def getSucc(p: CFGPosition) = { + // seen.add(p) + var n = IntraProcIRCursor.succ(p) + while ( + (n.size == 1) && (n.head match { + case s: SideEffectStatement => false + case _ => true + }) + ) { + // skip statements within stright lines of execution + n = IntraProcIRCursor.succ(n.head) + } + val r = n.filterNot(seen.contains(_)).map { + case stmt: SideEffectStatement => (Some(stmt), stmt) + case s => (None, s) + } + r + } + + val (succ, succMemoStat) = memoised(getSucc) + + val q = mutable.Queue[((Option[SideEffectStatement], CFGPosition), (Option[SideEffectStatement], CFGPosition))]() + val 
start = ((None, sourceEntry), (None, targetEntry)) + q.enqueue(start) + + def flatMapSucc(s: CFGPosition): Seq[SideEffectStatement] = { + succ(s).toSeq.flatMap { + case (None, r) => flatMapSucc(r) + case (Some(x), r) => Seq(x) + } + } + + val seenQ = mutable.Set[((Option[SideEffectStatement], CFGPosition), (Option[SideEffectStatement], CFGPosition))]() + + while { + // should probably fix the traversal order to avoid re-queuing but this works for now + var c = q.dequeue + while (seenQ.contains(c) && q.nonEmpty) { + c = q.dequeue + } + seenQ += c + + val ((srcCall, srcPos), (tgtCall, tgtPos)) = c + + def advanceBoth() = { + for (s <- succ(srcPos)) { + for (t <- succ(tgtPos)) { + q.enqueue((s, t)) + } + } + } + + def advanceSrc() = { + for (s <- succ(srcPos)) { + q.enqueue((s, (tgtCall, tgtPos))) + } + } + + def advanceTgt() = { + for (t <- succ(tgtPos)) { + q.enqueue(((srcCall, srcPos), t)) + } + } + + def label(s: CFGPosition) = s match { + case p: Procedure => p.name + case b: Block => b.label + case s: Command => s.getClass.getSimpleName + } + (srcCall, tgtCall) match { + case (None, None) => + advanceBoth() + // case (None, Some(x)) if seen.contains(x) => advanceBoth() + // case (Some(x), None) if seen.contains(x) => advanceBoth() + case (None, Some(_)) => advanceSrc() + case (Some(_), None) => advanceTgt() + case (Some(src), Some(tgt)) /* if !(seen.contains(src) && seen.contains(tgt)) */ => { + seen.add(src) + seen.add(tgt) + + instantiateAxiomInstance(paramMapping)(src, tgt) match { + case Right(inv) => { + invariant = invariant + ((inv.toPredicate(renameSourceExpr, renameTargetExpr)) -> inv.name) + advanceBoth() + } + case Left(InstFailureReason.ParamMismatch(err)) => + tvLogger.warn(s"Ackermannisation failure (side effect func params): $err; ${src} ${tgt}") + case Left(_) => () + } + } + case _ => () + } + (q.nonEmpty) + } do {} + + tvLogger.debug(s"Ackermann hitrate : ${succMemoStat().hitRate} ${succMemoStat()}") + + tvLogger.debug(s"Ackermann inv count: 
${invariant.size}") + val invs = invariant.toSet + tvLogger.debug(s"Ackermann inv dedup count: ${invs.size}") + + invs.toList + } +} diff --git a/src/main/scala/ir/transforms/validate/SSADAG.scala b/src/main/scala/ir/transforms/validate/SSADAG.scala new file mode 100644 index 0000000000..4a49355f88 --- /dev/null +++ b/src/main/scala/ir/transforms/validate/SSADAG.scala @@ -0,0 +1,259 @@ +package ir.transforms.validate + +import ir.* +import ir.transforms.Substitute +import util.tvLogger + +import scala.collection.mutable + +import cilvisitor.* + +object SSADAG { + + /** + * Convert an acyclic CFA to a transition encoding + * + * Returns the SSA renaming for each block entry in the CFA. + */ + def transform(frames: Map[String, CallParamMapping], p: Procedure, liveVarsBefore: Map[String, Set[Variable]]) = { + + ssaTransform(p, liveVarsBefore) + } + + private class Passify extends CILVisitor { + override def vstmt(s: Statement) = s match { + case l @ SideEffectStatement(s, n, lhs, rhs) => { + // assume ackermann + // FIXME: should probably go in the ackermann pass? + ChangeTo(List()) + } + case SimulAssign(assignments, _) => { + ChangeTo(List(Assume(boolAnd(assignments.map(polyEqual))))) + } + case _ => SkipChildren() + } + } + + def passify(p: Procedure) = { + visit_proc(Passify(), p) + } + + def convertToMonadicSideEffect(frames: Map[String, CallParamMapping], p: Procedure) = { + + class MonadicConverter(frames: Map[String, CallParamMapping]) extends CILVisitor { + val SF = SideEffectStatementOfStatement(frames) + override def vstmt(s: Statement) = s match { + case a @ SF(s) if a.isInstanceOf[Assume] => ChangeTo(List(a, s)) + case SF(s) => ChangeTo(List(s)) + case _ => SkipChildren() + } + } + + visit_proc(MonadicConverter(frames), p) + + } + + def blockDoneVar(b: Block) = { + LocalVar(b.label + "_done", BoolType) + } + + /** + * Convert an acyclic CFA to a transition encoding + * + * Returns the SSA renaming for each block entry in the CFA. 
+ */ + def ssaTransform( + p: Procedure, + liveVarsBefore: Map[String, Set[Variable]] + ): (((String, Expr) => Expr), Map[BlockID, Map[Variable, Variable]]) = { + + var renameCount = 0 + val stRename = mutable.Map[Block, Map[Variable, Variable]]() + val renameBefore = mutable.Map[Block, Map[Variable, Variable]]() + + def blockDone(b: Block) = blockDoneVar(b) + + var count = Map[String, Int]() + + def freshName(v: Variable) = + renameCount = count.get(v.name).getOrElse(0) + 1 + count = count + (v.name -> renameCount) + v match { + case l: LocalVar => l.copy(varName = l.name + "_AT" + renameCount, index = 0) + case l: GlobalVar => l.copy(name = l.name + "_AT" + renameCount) + } + + class RenameRHS(rn: Variable => Option[Variable]) extends CILVisitor { + override def vrvar(v: Variable) = ChangeTo(rn(v).getOrElse(v)) + override def vexpr(e: Expr) = { + ChangeTo(Substitute(rn, false)(e).getOrElse(e)) + } + override def vstmt(s: Statement) = s match { + case se @ SideEffectStatement(s, n, lhs, rhs) => + se.rhs = rhs.map((f, e) => (f, visit_expr(this, e))) + SkipChildren() + case _ => DoChildren() + + } + } + + def renameRHS(rename: Variable => Option[Variable])(c: Command): Command = c match { + // rename all rvars + case s: Statement => visit_stmt(RenameRHS(rename), s).head + case s: Jump => visit_jump(RenameRHS(rename), s) + } + + ir.transforms.reversePostOrder(p) + val worklist = mutable.PriorityQueue[Block]()(Ordering.by(_.rpoOrder)) + worklist.addAll(p.blocks) + + class RenameLHS(subst: Variable => Option[Variable]) extends CILVisitor { + + override def vstmt(s: Statement) = s match { + case s @ SideEffectStatement(_, _, lhs, _) => + s.lhs = lhs.map((f, e) => (f, visit_lvar(this, e))) + SkipChildren() + case _ => DoChildren() + + } + + override def vlvar(v: Variable) = subst(v) match { + case Some(vn) => ChangeTo(vn) + case none => SkipChildren() + } + } + + def renameLHS(substs: Map[Variable, Variable], s: Statement) = s match { + case s: Statement => 
visit_stmt(RenameLHS(substs.get), s) + } + + def mkIte(cases: List[(Expr, Expr)]) = { + /* assume cases are total, i.e. last guard of cases can be discarded */ + val rt = cases.head._2.getType + FApplyExpr("ite", cases.flatMap { case (c, b) => List(c, b) }, rt) + } + + val phiLabel = "TVSSADAGPHI" + + var phis = 0 + while (worklist.nonEmpty) { + val b = worklist.dequeue() + + var blockDoneCond = List[Expr](boolOr(b.prevBlocks.map(blockDone).toList)) + + def live(v: Variable) = + liveVarsBefore + .get(b.label) + .forall(_.contains(v)) + + val ite = false + + if (b.prevBlocks.nonEmpty) then { + val defines: Seq[Variable] = + (b.prevBlocks.toSeq.flatMap(b => stRename.get(b).toSeq.flatMap(_.map(_._1).filter(live)))).toSet.toSeq + + var nrenaming = mutable.Map.from(Map[Variable, Variable]()) + + defines.foreach((v: Variable) => { + val defsToJoin = b.prevBlocks.map(b => b -> stRename.get(b).flatMap(_.get(v)).getOrElse(v)) + val inter = defsToJoin.map(_._2).toSet + if (inter.size == 1) { + nrenaming(v) = inter.head + } else { + + val fresh = freshName(v) + + val grouped = defsToJoin.groupBy(_._2).map { + case (ivar, blockset) => { + val blocks = blockset.map(_._1) + (boolOr(blocks.map(blockDone)), ivar) + } + } + + val phiscond = if (ite) { + Seq(LocalAssign(fresh, mkIte(grouped.toList), Some(phiLabel))) + } else { + val cond = boolAnd(grouped.map { case (cond, lhs) => + BinaryExpr(BoolIMPLIES, cond, polyEqual(lhs, fresh)) + }) + Seq(Assume(cond, Some(phiLabel), Some(phiLabel))) + } + + // phis += grouped.length + b.statements.prependAll(phiscond) + + nrenaming(v) = fresh + } + }) + + stRename(b) = nrenaming.toMap + } else { + val rn = mutable.Map.from(renameBefore.getOrElse(b, Map())) + stRename(b) = rn.toMap + rn + } + + if (!b.parent.entryBlock.contains(b)) { + renameBefore(b) = stRename.getOrElse(b, Map()) + } + + val renaming = mutable.Map.from(renameBefore.getOrElse(b, Map())) + + def isPhi(s: Statement) = s match { + case a if a.label.contains(phiLabel) => true + 
case _ => false + } + + for (s <- b.statements.toList.filterNot(isPhi)) { + val c = renameRHS(renaming.get)(s) // also modifies in-place + + c match { + case a @ SideEffectStatement(s, n, lhs, rhs) => { + // note this matches some assume statements + // where checkSecurity = true + val rn = lhs + .map((formal, v) => { + val freshDef = freshName(v) + renaming(v) = freshDef + v -> freshDef + }) + .toMap + + renameLHS(rn, a) + } + case a @ Assume(cond, _, _, _) => + blockDoneCond = cond :: blockDoneCond + b.statements.remove(a) + case a: Assign => { + a.assignees.foreach(v => { + val freshDef = freshName(v) + renameLHS(Map(v -> freshDef), a) + renaming(v) = freshDef + }) + } + case _ => () + } + } + + renameRHS(renaming.get)(b.jump) + stRename(b) = renaming.toMap + + val c = + if (b.parent.entryBlock.contains(b) || b.label.endsWith("SYNTH_ENTRY")) then TrueLiteral + else boolAnd(blockDoneCond) + + b.statements.append(LocalAssign(blockDone(b), c)) + if (b.label.contains("SYNTH_EXIT")) { + b.statements.append(Assert(blockDone(b), Some("blockdone"))) + } + } + + tvLogger.debug("Phi node count: " + phis) + val renameBeforeLabels = renameBefore.map((b, r) => b.label -> r).toMap + + ( + (b, c) => renameBeforeLabels.get(b).map(b => visit_expr(RenameRHS(b.get), c)).getOrElse(c), + renameBeforeLabels.toMap + ) + } +} diff --git a/src/main/scala/ir/transforms/validate/SimplifyPipeline.scala b/src/main/scala/ir/transforms/validate/SimplifyPipeline.scala new file mode 100644 index 0000000000..c441f9e0bd --- /dev/null +++ b/src/main/scala/ir/transforms/validate/SimplifyPipeline.scala @@ -0,0 +1,334 @@ +package ir.transforms.validate +import ir.* +import util.SMT.SatResult +import util.{Logger, SimplifyMode, tvEvalLogger} + +import java.io.{BufferedWriter, File, FileWriter} + +import cilvisitor.{visit_proc, visit_prog} + +/** + * Translation-validated simplification pipeline. 
+ */ + +def simplifyCFG(p: Program) = { + p.procedures.foreach(transforms.RemoveUnreachableBlocks.apply) + for (p <- p.procedures) { + while (transforms.coalesceBlocks(p)) {} + transforms.removeEmptyBlocks(p) + } +} + +def simplifyCFGValidated(config: TVJob, p: Program): TVJob = { + TranslationValidator.forTransform("simplifyCFG", simplifyCFG)(p, config) +} + +val DSAInvariant = (u: Map[String, Map[BlockID, (Map[Variable, Variable], Map[Variable, Variable])]]) => { + def sourceToTarget(p: ProcID, b: Option[BlockID])(v: Variable | Memory) = v match { + case l @ LocalVar(n, t, i) => Seq(LocalVar(l.varName, t)) + case g => Seq(g) + } + InvariantDescription(sourceToTarget) +} + +def dynamicSingleAssignment(config: TVJob, p: Program) = { + + TranslationValidator.forTransform("DSA", transforms.OnePassDSA().applyTransform, DSAInvariant)(p, config) +} + +def dsaCopyPropCombined(config: TVJob, p: Program) = { + // not working reliably + + def dsa(p: Program) = transforms.OnePassDSA().applyTransform(p) + def copyprop(p: Program) = { + p.procedures.foreach(ir.eval.AlgebraicSimplifications(_)) + val r = p.procedures.map(p => p.name -> transforms.OffsetProp.transform(p)).toMap + transforms.CleanupAssignments().transform(p) + r + } + + def transform(p: Program) = { + val dr = dsa(p) + transforms.applyRPO(p) + val cp = copyprop(p) + (dr, cp) + } + + val bidiDSAInv = (dsares: Map[ProcID, Map[BlockID, (Map[Variable, Variable], Map[Variable, Variable])]]) => { + + def sourceToTarget(p: ProcID, b: Option[BlockID])(v: Variable | Memory) = v match { + case l @ LocalVar(n, t, i) => Seq(LocalVar(l.varName, t)) + case g => Seq(g) + } + + def targetToSource(p: ProcID, b: Option[BlockID])(v: Variable | Memory) = { + v match { + case m: Memory => Seq() + case v: Variable => { + dsares + .get(p) + .flatMap(dsares => + b.flatMap(b => + dsares + .get(b) + .flatMap(r => { + val (beforeRn, afterRn) = r + beforeRn.get(v) + }) + ) + ) + .toSeq + } + } + } + + InvariantDescription(sourceToTarget, 
targetToSource) + } + + def invariant( + dr: Map[String, Map[String, (Map[Variable, Variable], Map[Variable, Variable])]], + cp: Map[String, Map[Variable, Expr]] + ) = copypropInvariant(cp) compose bidiDSAInv(dr) + + TranslationValidator.forTransform("DSA-CopyProp", transform, invariant)(p, config) +} + +val copypropInvariant = (results: Map[String, Map[Variable, Expr]]) => { + def flowFacts(b: String): Map[Variable, Expr] = { + results.getOrElse(b, Map()) + } + + val revResults: Map[ProcID, Map[Variable, Set[Variable]]] = results.map { case (p, r) => + val m: Map[Variable, Set[Variable]] = (r.toSeq + .collect { case (v1: Variable, v2: Variable) => + v2 -> v1 + } + .groupBy(_._1) + .map { case (k, v) => + (k, v.map(_._2).toSet) + }) + p -> m + }.toMap + + def renamingTgt(proc: ProcID, b: Option[BlockID])(v: Variable | Memory) = v match { + case v: Variable => results.get(proc).flatMap(_.get(v)).toSeq + case _ => Seq() + } + + def renaming(proc: ProcID, b: Option[BlockID])(v: Variable | Memory) = v match { + case g => + Seq(g) ++ (v match { + case v: Variable => + val rr = revResults.get(proc).toSeq.flatMap(_.get(v).toSeq.flatten) + results.get(proc).flatMap(_.get(v)) ++ rr + case _ => Seq() + }) + } + + InvariantDescription(renaming, renamingTgt) +} + +def copyProp(config: TVJob, p: Program) = { + + def transform(p: Program): Map[String, Map[Variable, Expr]] = { + p.procedures.foreach(ir.eval.AlgebraicSimplifications(_)) + val r = p.procedures.map(p => p.name -> transforms.OffsetProp.transform(p)).toMap + transforms.CleanupAssignments().transform(p) + r + } + + TranslationValidator.forTransform("CopyProp", transform, copypropInvariant)(p, config) +} + +def parameters(config: TVJob, ctx: IRContext) = { + + val localInTarget = ctx.program.procedures.view.map { case p => + p.name -> (freeVarsPos(p).collect { case l: LocalVar => + l + }.toSet) + }.toMap + + val entryBlocks = ctx.program.procedures.collect { + case p if p.entryBlock.isDefined => p.name -> 
p.entryBlock.get.label + }.toMap + + val returnBlocks = ctx.program.procedures.collect { + case p if p.returnBlock.isDefined => p.name -> p.returnBlock.get.label + }.toMap + + // val (validator, res) = + // validatorForTransform(p => transforms.liftProcedureCallAbstraction(p, Some(ctx.specification)))(ctx.program) + + var res: Option[specification.Specification] = None + + val invariant = (result: Option[specification.Specification]) => { + res = result + + def sourceToTarget(p: ProcID, b: Option[String])(v: Variable | Memory): Seq[Expr] = v match { + // in/out params only map to the registers at the procedure entry and exit + case LocalVar(s"${i}_in", t, 0) if b.forall(_.endsWith("ENTRY")) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_out", t, 0) if b.forall(_.endsWith("EXIT")) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_in", t, 0) if b.forall(b => entryBlocks.get(p).contains(b)) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_out", t, 0) if b.forall(b => returnBlocks.get(p).contains(b)) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_in", t, 0) if b.forall(_ == p) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_out", t, 0) if b.forall(_ == p) => Seq(GlobalVar(s"$i", t)) + case LocalVar(s"${i}_in", t, 0) => + Seq() + case LocalVar(s"${i}_out", t, 0) => + Seq() + + case local @ LocalVar(n, t, 0) if localInTarget.get(p).exists(_.contains(local)) => + // local variables + Seq(local) + case LocalVar(n, t, 0) => + // the rest map to global variables with the same name + Seq(GlobalVar(n, t)) + case g => Seq(g) + } + + InvariantDescription(sourceToTarget) + } + + val vr = TranslationValidator.forTransform( + "Parameters", + p => transforms.liftProcedureCallAbstraction(p, Some(ctx.specification)), + invariant + )(ctx.program, config) + + (vr, ctx.copy(specification = res.get)) + +} + +def guardCleanupTransforms(p: Program) = { + def simplifyGuards(prog: Program) = { + (prog.procedures + .map(p => + transforms.fixupGuards(p) + val gvis = 
transforms.GuardVisitor(true) + visit_proc(gvis, p) + p.name -> gvis.replaced + )) + .toMap + } + + def deadAssignmentElimination(prog: Program) = { + transforms.CleanupAssignments().transform(prog) + visit_prog(transforms.CleanupAssignments(), prog) + } + + p.procedures.foreach(ir.eval.AlgebraicSimplifications(_)) + val guardProp = simplifyGuards(p) + Logger.debug(s"GuardCleanup replaced guards: $guardProp") // was a bare println: route diagnostics through the project logger instead of stdout + p.procedures.foreach(p => { + ir.eval.AlgebraicSimplifications(p) + ir.eval.AssumeConditionSimplifications(p) + ir.eval.AlgebraicSimplifications(p) + ir.eval.cleanupSimplify(p) + }) + deadAssignmentElimination(p) + simplifyCFG(p) + transforms.removeDuplicateGuard(p) + guardProp +} + +def guardCleanup(config: TVJob, p: Program) = { + TranslationValidator.forTransform("GuardCleanup", guardCleanupTransforms, copypropInvariant)(p, config) +} + +def nop(config: TVJob, p: Program) = { + TranslationValidator.forTransform("NOP", p => p)(p, config) +} + +def assumePreservedParams(config: TVJob, p: Program) = { + // val (validator, asserts) = validatorForTransform(transforms.CalleePreservedParam.transform)(p) + // validator.getValidationSMT(config, , introducedAsserts = asserts.toSet) + TranslationValidator.forTransform( + "AssumeCallPreserved", + transforms.CalleePreservedParam.transform, + asserts => InvariantDescription(introducedAsserts = asserts.toSet) + )(p, config) + +} + +def validatedSimplifyPipeline(ctx: IRContext, mode: util.SimplifyMode): (TVJob, IRContext) = { + // passing through ircontext like this just for spec transform is horrible + // also the translation validation doesn't really consider spec at all + // maybe it should; in ackermann phase and observable variables... 
+ val p = ctx.program + var config = mode match { + case SimplifyMode.ValidatedSimplify(verifyMode, filePrefix, dryRun) => + TVJob(outputPath = filePrefix, verify = verifyMode, debugDumpAlways = true, dryRun = dryRun) + case _ => TVJob(None, None) + } + + tvEvalLogger.debug { + val counter = ir.transforms.CountStatements() + val _ = visit_prog(counter, p) + "tv-eval-marker: before-stmt-count=" + counter.count + } + + var counter = ir.transforms.CountGuardStatements() + val _ = visit_prog(counter, p) + counter.reportToLog("before") + + transforms.applyRPO(p) + // nop(config, p) + // Logger.writeToFile(File("beforeParams.il"), translating.PrettyPrinter.pp_prog(ctx.program)) + val (res, nctx) = parameters(config, ctx) + config = res + config = assumePreservedParams(config, p) + assert(ir.invariant.readUninitialised(ctx.program)) + transforms.applyRPO(p) + config = simplifyCFGValidated(config, p) + assert(ir.invariant.readUninitialised(ctx.program)) + transforms.applyRPO(p) + config = dynamicSingleAssignment(config, p) + assert(ir.invariant.readUninitialised(ctx.program)) + transforms.applyRPO(p) + config = copyProp(config, p) + + assert(ir.invariant.readUninitialised(ctx.program)) + transforms.applyRPO(p) + config = guardCleanup(config, p) + + assert(ir.invariant.readUninitialised(ctx.program)) + counter = ir.transforms.CountGuardStatements() + val _ = visit_prog(counter, p) + counter.reportToLog("after") + + tvEvalLogger.debug { + val counter = ir.transforms.CountStatements() + val _ = visit_prog(counter, p) + "tv-eval-marker: after-stmt-count=" + counter.count + } + + val failed = config.results.filter(_.verified.exists(_.isInstanceOf[SatResult.SAT])) + + if (failed.nonEmpty) { + val fnames = failed.map(f => s" ${f.runName}::${f.proc} ${f.smtFile}").reverse.mkString("\n ") + // Logger.error(s"Failing cases: $fnames") + throw Exception(s"TranslationValidationFailed:\n $fnames") + } else if (config.verify.isDefined) { + Logger.info("[!] 
Translation validation passed") + } + + config.outputPath.foreach(p => { + val csv = (config.results.map(_.toCSV).groupBy(_._1).toList match { + case (h, vs) :: Nil => h :: vs.map(_._2) + case _ => { + Logger.error("Broken header structure for csv metrics file") + List() + } + }).mkString("\n") + Logger.writeToFile(File(p + "/stats.csv"), csv) + }) + + Logger.info("[!] Simplify :: Writing simplification validation") + val w = BufferedWriter(FileWriter("rewrites.smt2")) + try ir.eval.SimplifyValidation.makeValidation(w) + finally w.close() // always release the file handle, even if makeValidation throws + + // ir.transforms.genStackAllocationSpec(p) + + (config, nctx) +} diff --git a/src/main/scala/ir/transforms/validate/TransitionSystem.scala b/src/main/scala/ir/transforms/validate/TransitionSystem.scala new file mode 100644 index 0000000000..493dfcc4c4 --- /dev/null +++ b/src/main/scala/ir/transforms/validate/TransitionSystem.scala @@ -0,0 +1,262 @@ +package ir.transforms.validate +import analysis.Loop +import ir.* +import ir.dsl.IRToDSL + +import scala.collection.mutable + +object PCMan { + val assumptionFailLabel = "ASSUMEFAIL" + val assertionFailLabel = "ASSERTFAIL" + + val assertFailBlockLabel = "ASSERTFAIL" + + val allocatedPCS = mutable.Map[String, BitVecLiteral]() + var pcCounter = 0 + def PCSym(s: String) = { + allocatedPCS.getOrElseUpdate( + s, { + pcCounter += 1 + BitVecLiteral(pcCounter, 64) + } + ) + } + + def setPCLabel(label: String) = { + val pcVar = TransitionSystem.programCounterVar + LocalAssign(pcVar, PCSym(label), Some(label)) + } + + def pcGuard(label: String) = { + val pcVar = TransitionSystem.programCounterVar + Assume(BinaryExpr(EQ, pcVar, PCSym(label)), Some(s"PC = $label")) + } +} + +case class CutPointMap(cutLabelBlockInTr: Map[String, Block], cutLabelBlockInProcedure: Map[String, Block]) + +object TransitionSystem { + + import PCMan.* + + val traceType = BoolType + val programCounterVar = GlobalVar("SYNTH_PC", BitVecType(64)) + val traceVar = GlobalVar("TRACE", traceType) + + def procToTransition(p: Procedure, 
loops: List[Loop], cutJoins: Boolean = false) = { + + val pcVar = programCounterVar + + // cut point in transition system program + var cutPoints = Map[String, Block]() + + // cut point block in original program + var cutPointRealBlockBegin = Map[String, Block]() + + val synthEntryJump = GoTo(Seq()) + val synthEntry = Block(s"${p.name}_SYNTH_ENTRY", None, Seq(), synthEntryJump) + val synthExit = Block(s"${p.name}_SYNTH_EXIT", None, Seq()) + + cutPoints = cutPoints.updated("EXIT", synthExit) + + cutPoints = cutPoints.updated("ENTRY", synthEntry) + + p.addBlocks(Seq(synthEntry, synthExit)) + + p.entryBlock.foreach(e => { + e.statements.prepend(pcGuard("ENTRY")) + synthEntryJump.addTarget(e) + cutPointRealBlockBegin = cutPointRealBlockBegin.updated("ENTRY", e) + }) + + p.returnBlock.foreach(e => { + e.jump match { + case r: Return => { + val outAssigns = r.outParams.map((formal, actual) => { + val l = LocalAssign(formal, actual) + l.comment = Some("synth return param") + l + }) + + e.replaceJump(GoTo(synthExit)) + e.statements.append(setPCLabel("RETURN")) + e.statements.appendAll(outAssigns) + + val nb = e.createBlockBetween(synthExit, "returnblocknew") + cutPoints = cutPoints.updated("RETURN", nb) + cutPointRealBlockBegin = cutPointRealBlockBegin.updated("RETURN", nb) + + } + case _ => ??? 
+ } + + }) + + p.entryBlock = synthEntry + p.returnBlock = synthExit + + var loopCount = 0 + + for (l <- loops.filter(l => p.blocks.contains(l.header))) { + loopCount += 1 + + val backedges = l.backEdges.toList.sortBy(e => s"${e.from.label}_${e.to.label}") + val label = s"Loop${loopCount}" + synthEntryJump.addTarget(l.header) + + val nb = synthEntry.createBlockBetween(l.header, "cut_join_to_" + label) + nb.statements.prepend(pcGuard(label)) + + cutPoints = cutPoints.updated(label, l.header) + cutPointRealBlockBegin = cutPointRealBlockBegin.updated(label, l.header) + for (backedge <- backedges) { + assert(l.header == backedge.to) + backedge.from.statements.append(LocalAssign(pcVar, PCSym(label), Some(label))) + backedge.from.replaceJump(GoTo(synthExit)) + } + } + + var joinCount = 0 + if (cutJoins) { + + val cuts = p.blocks + .filter(c => + c.prevBlocks.size > 1 + && c.prevBlocks.flatMap(_.nextBlocks).forall(_ == c) + && !p.returnBlock.contains(c) + && !p.entryBlock.contains(c) + ) + .toList + .sortBy(_.label) + + for (c <- cuts) { + joinCount = joinCount + 1 + val label = s"Join${joinCount}" + cutPoints = cutPoints.updated(label, c) + + for (incoming <- c.prevBlocks) { + incoming.statements.append(LocalAssign(pcVar, PCSym(label), Some(label))) + incoming.replaceJump(GoTo(synthExit)) + } + + synthEntryJump.addTarget(c) + val nb = synthEntry.createBlockBetween(c, "cut_join_to_" + label) + nb.statements.prepend(pcGuard(label)) + + } + } + + for (s <- p) { + s match { + case u: Unreachable if u.parent != synthExit => { + u.parent.statements.append(PCMan.setPCLabel(PCMan.assumptionFailLabel)) + u.parent.replaceJump(GoTo(synthExit)) + } + case g: GoTo if g.targets.isEmpty => { + g.parent.statements.append(PCMan.setPCLabel(PCMan.assumptionFailLabel)) + g.parent.replaceJump(GoTo(synthExit)) + } + case _ => () + } + + } + synthExit.replaceJump(Return()) + CutPointMap(cutPoints, cutPointRealBlockBegin) + + } + + def toTransitionSystemInPlace(p: Procedure): CutPointMap = { + 
require(p.entryBlock.isDefined) + + val loops = analysis.LoopDetector.identify_loops(p.entryBlock.get) + val floops = loops.loops_o + val cutPoints = procToTransition(p, floops) + p.formalInParam.clear() + p.formalOutParam.clear() + + cutPoints + } + + /** + * Converts each procedure to a transition system + */ + def toTransitionSystemClone(iprogram: Program) = { + + val program = IRToDSL.convertProgram(iprogram).resolve + + val loops = analysis.LoopDetector.identify_loops(program) + val floops = loops.loops_o + + val cutPoints = program.procedures + .map(p => { + p -> procToTransition(p, floops) + }) + .toMap + + (program, cutPoints) + } + + def removeUnreachableBlocks(p: Procedure) = { + val reachable = p.entryBlock.get.forwardIteratorFrom.toSet + val unreachable = p.blocks.filterNot(reachable.contains).toList + p.removeBlocksDisconnect(unreachable) + } + + /** + * Convert asserts in program to a jump to exit with a specific PC set. + * + * @param [[introdAsserts]] specifies which assertions set the pc to the [[assumptionFailLabel]] + * rather than the [[assertionFailLabel]]. 
+ * + */ + def totaliseAsserts(proc: Procedure, introdAsserts: Set[String] = Set()) = { + val b = Block(assertFailBlockLabel) + proc.addBlock(b) + AssertsToPC(b, introdAsserts).transform(proc) + } + + private class AssertsToPC(val exitBl: Block, introdAsserts: Set[String] = Set()) { + + var count = 0 + + def transform(p: Procedure): Unit = { + val asserts = p.collect { case a: Assert => + a + } + + transform(asserts) + } + + def transform(s: Iterable[Assert]): Unit = { + for (stmt <- s) { + count += 1 + val label = s"assert$count" + + val bl = stmt.parent + val successor = bl.splitAfterStatement(stmt, label + "Pass") + // bl ends in Assert + // successor is rest of block + + bl.statements.remove(stmt) + bl.statements.append(Assume(stmt.body)) + + successor.statements.prepend(Assume(stmt.body, Some("assertpass"))) + + val failureLabel = + if (stmt.label.isDefined && (introdAsserts.contains(stmt.label.get))) then PCMan.assumptionFailLabel + else PCMan.assertionFailLabel + + val falseBranch = Block( + bl.label + label + "Fail", + None, + Seq(Assume(UnaryExpr(BoolNOT, stmt.body)), PCMan.setPCLabel(failureLabel)), + GoTo(Seq(exitBl)) + ) + + bl.parent.addBlock(falseBranch) + + bl.jump.asInstanceOf[GoTo].addTarget(falseBranch) + } + } + } +} diff --git a/src/main/scala/ir/transforms/validate/TranslationValidate.scala b/src/main/scala/ir/transforms/validate/TranslationValidate.scala new file mode 100644 index 0000000000..faa8f35b7b --- /dev/null +++ b/src/main/scala/ir/transforms/validate/TranslationValidate.scala @@ -0,0 +1,1422 @@ +package ir.transforms.validate + +import analysis.ProcFrames.* +import cats.collections.DisjointSets +import ir.* +import ir.cilvisitor.* +import translating.PrettyPrinter.* +import util.SMT.* +import util.{LogLevel, PerformanceTimer, tvLogger} + +import java.io.File + +/** + * Result of a translation validation task for a single procedure and transform step. 
+ */ +case class TVResult( + runName: String, + proc: String, + verified: Option[SatResult], + smtFile: Option[String], + verifyTime: Map[String, Long] +) { + def toCSV = { + val veri = verified match { + case Some(SatResult.UNSAT) => "unsat" + case Some(s: SatResult.SAT) => "sat" + case Some(_) => "unknown" + case None => "disabled" + } + val times = verifyTime.toList + .sortBy(_._1) + .map { case (n, t) => + t.toString + } + .mkString(",") + + val timesHeader = verifyTime.toList + .sortBy(_._1) + .map { case (n, t) => + n + } + .mkString(",") + + val header = "pass,procedure,outcome," + timesHeader + val row = s"$runName,$proc,$veri,$times" + + (header, row) + } +} + +/** + * Configuration and result list for a translation validation task of a program (possibly including multiple separate passes). + */ +case class TVJob( + outputPath: Option[String], + verify: Option[util.SMT.Solver] = None, + results: List[TVResult] = List(), + debugDumpAlways: Boolean = false, + /* minimum number of statements in source and target combined to trigger case analysis */ + splitLargeProceduresThreshold: Option[Int] = Some(60), + dryRun: Boolean = false +) { + + lazy val noneFailed = { + !(results.exists(_.verified.exists(_.isInstanceOf[SatResult.SAT]))) + } +} + +/** + * Describes the mapping from source variable to target expression at a given Block ID in the source program. + */ +type TransformDataRelationFun = (ProcID, Option[BlockID]) => (Variable | Memory) => Seq[Expr] + +/** + * Closures describing the relationship between source and target programs. + */ +case class InvariantDescription( + /** The way live variables at each cut in the source program relate to equivalent expressions or variables in the target. + * + * NOTE: !!! The first returned value of this is also used to map procedure call arguments in the source + * program to the equivalent arguments in the target program. 
+ * */ + renamingSrcTgt: TransformDataRelationFun = (_, _) => e => Seq(e), + + /** + * Describes how live variables at a cut in the target program relate to equivalent variables in the source. + * + */ + renamingTgtSrc: TransformDataRelationFun = (_, _) => _ => Seq(), + + /** + * Set of values of [ir.Assert.label] for assertions introduced in this pass, which should + * be ignored as far as translation validation is concerned. + */ + introducedAsserts: Set[String] = Set() +) { + def compose(i: InvariantDescription) = { + InvariantDescription(composeDRFun(renamingSrcTgt, i.renamingSrcTgt), composeDRFun(i.renamingTgtSrc, renamingTgtSrc)) + } + +} + +type CutLabel = String +type BlockID = String +type ProcID = String + +def composeDRFun(a: TransformDataRelationFun, b: TransformDataRelationFun): TransformDataRelationFun = { + + class subst(funct: TransformDataRelationFun)(p: ProcID, bl: Option[BlockID]) extends CILVisitor { + + def sub(v: Variable | Memory) = { + funct(p, bl)(v) match { + case e :: Nil => ChangeTo(e) + case Nil => throw Exception("none") + } + } + override def vexpr(e: Expr) = e match { + case v: Variable => sub(v) + case v: Memory => sub(v) + case _ => DoChildren() + } + } + + def toSource(funct: TransformDataRelationFun)(p: ProcID, bl: Option[BlockID])(e: Expr): Option[Expr] = { + try { + Some(visit_expr(subst(funct)(p, bl), e)) + } catch { + case scala.util.control.NonFatal(_) => None // a failed substitution means "no mapping"; fatal VM errors must still propagate + } + } + + def composed(p: ProcID, bl: Option[BlockID])(v: Variable | Memory) = { + a(p, bl)(v).flatMap { case e: Expr => + toSource(b)(p, bl)(e) + } + } + + composed +} + +enum FormalParam { + case Global(v: Memory | GlobalVar) + case FormalParam(n: String, t: IRType) +} + +/** + * Describe renaming for a function call parameter list, map from variable to the (formal, actual) pair, + * if actual is Some() it is invariant at any call site. 
+ */ +type ParameterRenamingFun = (Variable | Memory) => (Variable, Option[Expr]) + +def polyEqual(e1: Expr, e2: Expr) = { + (e1.getType, e2.getType) match { + case (l, r) if l == r => BinaryExpr(EQ, e1, e2) + case (BitVecType(sz1), BitVecType(sz2)) if sz1 > sz2 => BinaryExpr(EQ, e1, ZeroExtend(sz1 - sz2, e2)) + case (BitVecType(sz1), BitVecType(sz2)) if sz1 < sz2 => BinaryExpr(EQ, ZeroExtend(sz2 - sz1, e1), e2) + case (a, b) => throw Exception(s"wierd type $a == $b") + } +} + +def combineProcs(p1: Procedure, p2: Procedure): Program = { + import ir.dsl.* + import IRToDSL.* + import scala.collection.immutable.ArraySeq + val entryName = p1.name + "_P_ENTRY" + val eproc = EventuallyProcedure( + p1.procName + "_par_" + p2.procName, + Map(), + Map(), + Seq(block(entryName, goto(p1.entryBlock.get.label, p2.entryBlock.get.label))) ++ (p1.blocks ++ p2.blocks).toSet + .map(convertBlock) + .to(ArraySeq), + Some(entryName), + p2.returnBlock.map(_.label), + p1.address + ) + + val n = eproc.copy(blocks = eproc.blocks) + EventuallyProgram(n).resolve +} + +/** + * For a monadic transition system, renaming to partition variables and functions. 
/**
 * Visitor that moves what it visits into a namespace: block labels, variable names and memory
 * names are prefixed with the namespace followed by `__`.
 */
class NamespaceState(val namespace: String) extends CILVisitor {

  /** `name` qualified with this namespace. */
  private def qualified(name: String): String = namespace + "__" + name

  /** Remove this namespace's prefix from `n`, if present. */
  def stripNamespace(n: String) = n.stripPrefix(namespace + "__")

  // Block labels are renamed in place; the block's contents are then visited.
  override def vblock(b: Block) = {
    b.label = qualified(b.label)
    DoChildren()
  }

  // Expressions are only descended into; renaming of function applications is currently disabled.
  override def vexpr(e: Expr) = DoChildren()

  override def vlvar(v: Variable) = v match {
    case local: LocalVar => ChangeTo(local.copy(varName = qualified(local.varName)))
    case global: GlobalVar => ChangeTo(global.copy(name = qualified(global.name)))
  }

  // Variable reads are renamed exactly like variable writes.
  override def vrvar(v: Variable) = vlvar(v)

  override def vmem(m: Memory) = m match {
    case shared: SharedMemory => ChangeTo(shared.copy(name = qualified(shared.name)))
    case stack: StackMemory => ChangeTo(stack.copy(name = qualified(stack.name)))
  }
}
/**
 * Power-set domain of live variables within one procedure, with one transfer rule per command
 * kind. The analysis is expected to be run backwards.
 *
 * @param frames call parameter mapping per procedure name; used to view commands as side-effect
 *               statements and to find the formal outputs that are live at a return
 */
class IntraLiveVarsDomainSideEffect(frames: Map[String, CallParamMapping])
    extends transforms.PowerSetDomain[Variable] {

  // Extractor viewing a command as a SideEffectStatement.
  val SideEffect = SideEffectStatementOfStatement(frames)

  /** Live set before `a` given the live set `s` after it: drop what `a` defines, add what it reads. */
  def transfer(s: Set[Variable], a: Command): Set[Variable] = a match {
    case SideEffectStatement(_, _, outs, ins) =>
      (s -- outs.map(_._2)) ++ ins.flatMap(_._2.variables)
    // Tried before the plain statement cases below, so any command this extractor accepts is handled here.
    case SideEffect(SideEffectStatement(_, _, outs, ins)) =>
      (s -- outs.map(_._2)) ++ ins.flatMap(_._2.variables)
    case assign: LocalAssign => (s - assign.lhs) ++ assign.rhs.variables
    case assign: MemoryAssign => (s - assign.lhs) ++ assign.rhs.variables
    case simul: SimulAssign =>
      (s -- simul.assignments.map(_._1)) ++ simul.assignments.flatMap(_._2.variables)
    case load: MemoryLoad => (s - load.lhs) ++ load.index.variables
    case store: MemoryStore => s ++ store.index.variables ++ store.value.variables
    case assume: Assume => s ++ assume.body.variables
    case assertion: Assert => s ++ assertion.body.variables
    case call: IndirectCall => s + call.target
    // Direct calls have no rule of their own here.
    case _: DirectCall => ???
    case ret: Return =>
      // Formal outputs of the enclosing procedure, looked up by the procedure's name.
      val outFormal = frames(ret.parent.parent.name).lhs.flatMap {
        case (formal: (Variable | Memory), actual) =>
          Seq(SideEffectStatementOfStatement.param(formal)._2) ++ actual
        case (_, actual) => actual.toSeq
      }

      (s -- ret.outParams.map(_._1)) ++ outFormal ++ ret.outParams.flatMap(_._2.variables) ++ Seq(
        TransitionSystem.traceVar,
        TransitionSystem.programCounterVar
      )
    // Commands that neither define nor read variables.
    case _: GoTo | _: Unreachable | _: NOP => s
  }
}
/**
 * Collect the bodies of all `Assume` and `Assert` statements of `proc`, in the order they are met
 * when iterating forwards from the entry block.
 *
 * Jumps and blocks are skipped; any other kind of statement raises an exception.
 */
def extractProg(proc: Procedure): Iterable[Expr] = {
  val entry = proc.entryBlock.get

  // Conditions are prepended while folding and put back in visiting order at the end.
  val collected = entry.forwardIteratorFrom.foldLeft(List.empty[Expr]) { (acc, position) =>
    position match {
      case Assume(cond, _, _, _) => cond :: acc
      case Assert(cond, _, _) => cond :: acc
      case _: Jump => acc
      case _: Block => acc
      case other => throw Exception(s"Program has other statements : $other")
    }
  }

  collected.reverse
}
/**
 * Merge two partial renamings into one list of compatible argument pairs.
 *
 * Every entry `(k, Some(v))` of either list contributes `CompatArg(toVariable(k), toVariable(v))`.
 * An entry `(k, None)` carries no information and contributes nothing; previously such a key,
 * when it had no counterpart in the other list, aborted with `NotImplementedError`.
 *
 * @param l         first renaming, as (expression, optional counterpart) pairs
 * @param l2        second renaming, in the same form
 * @param intersect when true, keep only pairs whose counterpart is known as a counterpart and
 *                  whose key is known as a key in the merged domains
 * @throws Exception if the two lists give different counterparts for the same key, or different
 *                   keys for the same counterpart
 */
def mergeCompat(
  l: List[(Expr, Option[Expr])],
  l2: List[(Expr, Option[Expr])],
  intersect: Boolean = false
): List[CompatArg] = {
  def disagree(v1: Expr, v2: Expr) =
    Exception(s"provided src -> target and target -> src renamings disagree $v1 != $v2")

  // Forward view (key -> optional counterpart) and inverted view (counterpart -> key) of each list.
  val srcsrc: Map[Expr, Option[Expr]] = l.toMap
  val srctgt = l.collect { case (k, Some(v)) => (v, k) }.toMap

  val tgtsrc = l2.toMap
  val tgttgt = l2.collect { case (k, Some(v)) => (v, k) }.toMap

  val srcDom: Set[Expr] = (srcsrc.keys ++ tgtsrc.keys).toSet
  val tgtDom: Set[Expr] = (srctgt.keys ++ tgttgt.keys).toSet

  val srcImg: Iterable[(Expr, Expr)] = srcDom.flatMap(k =>
    (srcsrc.get(k).flatten, tgtsrc.get(k).flatten) match {
      case (Some(v), None) => Some(k -> v)
      case (None, Some(v)) => Some(k -> v)
      case (Some(v1), Some(v2)) if v1 == v2 => Some(k -> v1)
      case (Some(v1), Some(v2)) => throw disagree(v1, v2)
      // The key is only ever mapped to None: no information, so no pair.
      case (None, None) => None
    }
  )

  val tgtImg: Iterable[(Expr, Expr)] = tgtDom.flatMap(k =>
    (srctgt.get(k), tgttgt.get(k)) match {
      case (Some(v), None) => Some(v -> k)
      case (None, Some(v)) => Some(v -> k)
      case (Some(v1), Some(v2)) if v1 == v2 => Some(v1 -> k)
      case (Some(v1), Some(v2)) => throw disagree(v1, v2)
      case (None, None) => None
    }
  )

  val merged =
    if intersect then srcImg.filter(st => tgtDom.contains(st._2)) ++ tgtImg.filter(st => srcDom.contains(st._1))
    else srcImg.toSet ++ tgtImg

  merged.map { case (s, t) => CompatArg(toVariable(s), toVariable(t)) }.toList
}
/**
 * Infer, for every procedure of `program`, the call signature used for the source program and the
 * corresponding signature of the target program obtained through `renaming`.
 *
 * Outputs are the formal out-parameters plus the modified global variables of the procedure's
 * frame; inputs are the formal in-parameters plus the read global variables. Both sides of every
 * signature start with the trace variable and the program counter variable.
 *
 * @param program    program whose procedures are described
 * @param afterFrame frame per procedure name; a procedure without an entry gets an empty `Frame()`
 * @param renaming   maps a source variable/memory to its target counterpart(s) at a block; inputs
 *                   are renamed at the entry block label and outputs at the return block label
 *                   (the procedure name stands in when that block is missing). When several
 *                   counterparts are given, the first variable or memory is used.
 * @return procedure name -> (source signature, target signature)
 * @throws Exception if `renaming` yields no variable or memory for some parameter
 */
def getFunctionSigsRenaming(
  program: Program,
  afterFrame: Map[String, Frame],
  renaming: TransformDataRelationFun
): Map[String, (CallParamMapping, CallParamMapping)] = {

  val invParam = List(TransitionSystem.traceVar, TransitionSystem.programCounterVar).map(a => (a, Some(a)))

  // Pair a formal with the variable standing for it at call sites, if there is one.
  def param(v: Variable | Memory): (Variable | Memory, Option[Variable]) = v match {
    case g: GlobalVar => (g -> Some(g))
    case g: LocalVar => (g -> None)
    case m: Memory => (m -> Some(SideEffectStatementOfStatement.traceVar(m)))
  }

  def getParams(p: Procedure, frame: Frame) = {
    def paramTgt(entry: Boolean)(v: Variable | Memory) = {
      val bl =
        if entry then Some(p.entryBlock.map(_.label).getOrElse(p.name))
        else Some(p.returnBlock.map(_.label).getOrElse(p.name))

      renaming(p.name, bl)(v)
        .collect { case n: (Variable | Memory) => param(n) }
        .headOption
        .getOrElse(throw Exception(s"Param corresponding to $v at ${p.name} $bl undefined"))
    }

    val lhs: List[Variable | Memory] =
      p.formalOutParam.toList ++ frame.modifiedGlobalVars.toList // ++ frame.modifiedMem.toList
    val rhs: List[Variable | Memory] =
      p.formalInParam.toList ++ frame.readGlobalVars.toList // ++ frame.readMem.toList

    val lhsSrc = lhs.map(param)
    val rhsSrc = rhs.map(param)

    val lhsTgt = lhs.map(paramTgt(false))
    val rhsTgt = rhs.map(paramTgt(true))

    (
      CallParamMapping(invParam ++ lhsSrc, invParam ++ rhsSrc),
      CallParamMapping(invParam ++ lhsTgt, invParam ++ rhsTgt)
    )
  }

  program.procedures.map(p => p.name -> getParams(p, afterFrame.getOrElse(p.name, Frame()))).toMap
}
/**
 * Build the invariant stating which source and target variables correspond at each cut point.
 *
 * The result consists of, in order:
 *  - for every source cut label, the correspondence of the source procedure's globals;
 *  - for every cut label, the correspondence of live variables given by the two renamings, kept
 *    only when every variable of the renamed expression is live or defined at the other side's cut;
 *  - the correspondence of call inputs at "ENTRY" and of call outputs at "RETURN".
 *
 * @param renamingSrcTgt source variable -> target expressions, per procedure and block
 * @param renamingTgtSrc target variable -> source expressions, per procedure and block
 * @throws Exception if a cut label is present in only one of the two procedures
 */
def getEqualVarsInvariantRenaming(
  // block label -> variable -> renamed variable
  i: InterproceduralInfo,
  sourceInfo: ProcInfo,
  targetInfo: ProcInfo,
  renamingSrcTgt: TransformDataRelationFun = (_, _) => e => Seq(e),
  renamingTgtSrc: TransformDataRelationFun = (_, _) => e => Seq(e)
) = {

  val globalPairs = globalsForSourceProc(i, sourceInfo)(renamingSrcTgt(sourceInfo.name, None)).toList
  val globals = globalPairs.collect { case (src, Some(tgt)) => CompatArg(toVariable(src), toVariable(tgt)) }.toList

  // Variable standing for a formal parameter or memory.
  def paramRepr(p: Variable | Memory): Variable = SideEffectStatementOfStatement.param(p)._2

  // Relate one (source, target) pair of call parameters: by their actuals when both are known,
  // otherwise by the variables representing the formals.
  def paramCompat(pair: ((EffCallFormalParam, Option[Expr]), (EffCallFormalParam, Option[Expr]))) = pair match {
    case ((_, Some(srcActual)), (_, Some(tgtActual))) =>
      Seq(CompatArg(srcActual, tgtActual))
    case ((srcFormal: (Variable | Memory), _), (tgtFormal: (Variable | Memory), _)) =>
      Seq(CompatArg(paramRepr(srcFormal), paramRepr(tgtFormal)))
    case _ => Seq()
  }

  val inparams =
    Inv.CutPoint("ENTRY", sourceInfo.callParams.rhs.toSeq.zip(targetInfo.callParams.rhs).flatMap(paramCompat))

  val outparams =
    Inv.CutPoint("RETURN", sourceInfo.callParams.lhs.toSeq.zip(targetInfo.callParams.lhs).flatMap(paramCompat))

  // Formal in/out parameters are not listed separately: they should be live at entry and return.
  // TODO: can probably just set at entry and let the liveness sort the rest out?
  val globalsInvEverywhere =
    sourceInfo.cutBlockLabels.keys
      .map(cut => Inv.CutPoint(cut, globals.toList))
      .toList

  val cuts = (targetInfo.cutBlockLabels.keys ++ sourceInfo.cutBlockLabels.keys).toSet.toList

  tvLogger.debug(s"cuts source: ${sourceInfo.cutBlockLabels}\ntarget: ${targetInfo.cutBlockLabels}")

  val invs = cuts.map { label =>
    if (!targetInfo.cutBlockLabels.contains(label) || !sourceInfo.cutBlockLabels.contains(label)) {
      throw Exception(
        s"Mismatched cut labels (missing $label)\nsource: ${sourceInfo.cutBlockLabels}\ntarget: ${targetInfo.cutBlockLabels}"
      )
    }
    val tgtCut = targetInfo.cutBlockLabels(label)
    val srcCut = sourceInfo.cutBlockLabels(label)

    val srcLives = sourceInfo.liveVars.get(srcCut).toSet.flatten
    val tgtLives = targetInfo.liveVars.get(tgtCut).toSet.flatten

    val tgtDefines = targetInfo.defines(tgtCut)
    val srcDefines = sourceInfo.defines(srcCut)

    // Source-live variables paired with those target images that are available at the target cut.
    val fromSource = srcLives
      .map(live => live -> renamingSrcTgt(sourceInfo.name, Some(srcCut))(live))
      .flatMap { case (live, images) =>
        images
          .filter(_.variables.forall(v => tgtDefines.contains(v) || tgtLives.contains(v)))
          .map(image => CompatArg(toVariable(live), toVariable(image)))
      }

    // Target-live variables paired with those source images that are available at the source cut.
    val fromTarget = tgtLives
      .map(live => live -> renamingTgtSrc(sourceInfo.name, Some(tgtCut))(live))
      .flatMap { case (live, images) =>
        images
          .filter(_.variables.forall(v => srcDefines.contains(v) || srcLives.contains(v)))
          .map(image => CompatArg(toVariable(image), toVariable(live)))
      }

    Inv.CutPoint(label, (fromSource.toSet ++ fromTarget).toList, Some(s"INVARIANT at $label"))
  }.toList

  globalsInvEverywhere ++ invs ++ Seq(inparams) ++ Seq(outparams)
}
target.cutBlockLabels(label) + val tgtLives = target.liveVars.get(tgtCut).toSet.flatten + val tgtDefines = target.defines(tgtCut) + val m = flowFactTgtTgt + .collect { + case (v, e) + if tgtLives.contains(v) && (e.variables).forall(e => + tgtDefines.contains(e) || tgtLives.contains(e) + ) /*&& e.variables.forall(tgtLives.contains) */ => + List(TargetTerm(BinaryExpr(EQ, v, e))) + case (v, e) + if tgtDefines.contains(v) && (e.variables) + .forall(tgtLives.contains) && e.variables.nonEmpty /*&& e.variables.forall(tgtLives.contains) */ => + List(TargetTerm(BinaryExpr(EQ, v, e))) + case (v, e) => + List() + } + .toList + .flatten + val i = Inv.CutPoint(label, m, Some(s"FLOWFACT at $label")) + i + } + }).toList + + invs + } + + /** + * Dump some debug logs comparing source and target programs from the model retuned when [sat], to get an idea + * of what when wrong in the validation. + */ + private def processModel( + source: ProcInfo, + target: ProcInfo, + combinedProc: Procedure, + prover: SMTProver, + invariant: Seq[Expr], + renaming: TransformDataRelationFun = (_, _) => e => Seq(e), + sourceEntry: String, + targetEntry: String + ) = { + val eval = prover.getEvaluator() + + val done: Set[Block] = combinedProc.blocks + .map(b => { + eval.evalExpr(SSADAG.blockDoneVar(b)) match { + case Some(TrueLiteral) => Seq(b) + case _ => Seq() + } + + }) + .flatten + .toSet + + val cutMap = source.cuts.cutLabelBlockInProcedure.map { case (cl, b) => + PCMan.PCSym(cl) -> cl + }.toMap + tvLogger.info(s"Cut point labels: $cutMap") + + case object Conj { + def unapply(e: Expr): Option[List[Expr]] = e match { + case BinaryExpr(BoolAND, a, b) => Some(List(a, b)) + case AssocExpr(BoolAND, a) => Some(a.toList) + case n if n.getType == BoolType => Some(List(n)) + case _ => None + } + } + + val toUnion = combinedProc.flatMap { + case Assume(Conj(conjuncts), _, _, _) => + conjuncts.map(c => { + val v = c.variables + if (v.size > 1) then v else Seq() + }) + case Assert(Conj(conjuncts), _, _) => + 
conjuncts.map(c => { + val v = c.variables + if (v.size > 1) then v else Seq() + }) + case _ => Seq() + } + + val (variableDependencies, variableSets) = toUnion + .foldLeft(DisjointSets[Variable]())((ds, variables) => + variables.toList match { + case h :: tl => + tl.foldLeft(ds + h)((ds, v) => (ds + v).union(h, v)._1) + case _ => ds + } + ) + .toSets + + for (i <- invariant) { + eval.evalExpr(i) match { + case Some(FalseLiteral) => + tvLogger.error(s"Part of invariant failed: $i") + i match { + case BinaryExpr(BoolIMPLIES, BinaryExpr(EQ, pc, b: BitVecLiteral), Conj(conjuncts)) => { + + val ec = eval.evalExpr(exprInSource(TransitionSystem.programCounterVar)) match { + case Some(b: BitVecLiteral) => Some(b) + case _ => None + } + + tvLogger.error( + s" Specifically: at cut point transition ${ec.flatMap(cutMap.get)} --> ${cutMap.get(b)} ($ec -> $b) " + ) + val vars = (conjuncts) + .collect(c => { + eval.evalExpr(c) match { + case Some(FalseLiteral) => { + tvLogger.error(s" $c is false") + c.variables + .map(v => + variableDependencies.find(v)._2 match { + case Some(v) => v + case None => ??? 
+ } + ) + .toSet + } + case _ => Set() + } + }) + .toSet + .flatten + + } + case _ => () + } + case _ => () + } + } + + class CollapsePhi extends CILVisitor { + + override def vstmt(s: Statement) = s match { + case ass @ Assume(Conj(xs), _, _, _) => { + val n = xs.toSeq.flatMap { + case bdy @ BinaryExpr(BoolIMPLIES, bld, rhs) => { + eval.evalExpr(bld) match { + case Some(FalseLiteral) => Seq() + case Some(TrueLiteral) => Seq(rhs) + case _ => Seq(bdy) + } + } + case x => Seq(x) + } + + ass.body = boolAnd(n) + SkipChildren() + } + case _ => SkipChildren() + } + } + + def getTrace(starting: String): Unit = { + + val b = combinedProc.blocks.find(_.label == starting).get + + def isReached(l: Block) = { + eval.evalExpr(SSADAG.blockDoneVar(l)) match { + case Some(TrueLiteral) => true + case _ => false + } + } + + if (!isReached(b)) { + return () + } + + def pt(b: Block, indent: Int = 0): Unit = { + if (isReached(b)) { + tvLogger.info(" ".repeat(indent * 2) + b.label) + } + + var n = b.nextBlocks.filter(isReached) + + while (n.size == 1) { + tvLogger.info(" ".repeat(indent * 2) + n.head.label) + n = n.head.nextBlocks.filter(isReached) + } + + for (nn <- n) { + pt(nn, indent + 1) + } + } + + pt(b) + + } + + class ComparVals extends CILVisitor { + override def vstmt(statement: Statement) = statement match { + case a @ Assert(BinaryExpr(BoolIMPLIES, Conj(prec), Conj(ante)), Some(com), _) if com.startsWith("ack") => { + // ackermann structure + + val triggered = prec.map(eval.evalExpr).forall(_.contains(TrueLiteral)) + + val bad = prec + .flatMap(e => { + eval.evalExpr(e) match { + case Some(TrueLiteral) => Seq() + case _ => e.variables + } + }) + .map(v => v -> eval.evalExpr(v)) + .map(_.toString) + + val precedent = + if (triggered) then "true" + else prec.map(p => pp_expr(p) + ":=" + eval.evalExpr(p).map(pp_expr)).mkString(" && ") + "\n" + + val reason = precedent + " ==> " + ante + .map(p => "(" + pp_expr(p) + " is " + eval.evalExpr(p).map(pp_expr) + ")") + .mkString("\n&& 
") + a.comment = Some(com + "\n " + reason + "\n vars: " + bad) + SkipChildren() + } + case a => { + val vars = freeVarsPos(a).filter(v => v.name.startsWith("source__") || v.name.startsWith("target__")) + val pcomment = a.comment.getOrElse("") + val proc = statement.parent.parent.name + val blockLabel = Some(afterRenamer.stripNamespace(statement.parent.label)) + val compar = vars + .filter(_.name.startsWith("source__")) + .map(b => + val name = b.name.stripPrefix("source__").stripPrefix("target__") + val (sv, tv) = b match { + case GlobalVar(v, ty) => { + val s = GlobalVar(name, ty) + val t = (renaming(proc, blockLabel)(s)).headOption.getOrElse(s) + (exprInSource(s), exprInTarget(t)) + } + case LocalVar(v, ty, i) => { + val s = LocalVar(name, ty) + // FIXME: seq + val t = renaming(proc, blockLabel)(s).headOption.getOrElse(s) + (exprInSource(s), exprInTarget(t)) + } + } + val eq = eval.evalExpr(BinaryExpr(EQ, sv, tv)) + val (s, t) = (eval.evalExpr(sv), eval.evalExpr(tv)) + eq match { + case Some(TrueLiteral) => s"($name matches)" + case Some(FalseLiteral) => s"($sv NOT MATCHING $tv : $s != $t)" + case None => s"($name $s != $t)" + case _ => ??? 
+ } + ) + .mkString(", ") + a.comment = Some(pcomment + " " + compar) + SkipChildren() + } + } + } + + tvLogger.info("Trace source:") + getTrace(sourceEntry) + tvLogger.info("Trace target:") + getTrace(targetEntry) + + visit_proc(CollapsePhi(), combinedProc) + visit_proc(ComparVals(), combinedProc) + + ir.dotBlockGraph(combinedProc.blocks.toList, done) + + } + + def inferInvariant( + interproc: InterproceduralInfo, + invariant: InvariantDescription, + sourceInfo: ProcInfo, + targetInfo: ProcInfo + ): List[Inv] = { + + val equalVarsInvariant = + getEqualVarsInvariantRenaming( + interproc, + sourceInfo, + targetInfo, + invariant.renamingSrcTgt, + invariant.renamingTgtSrc + ) + + val cuts = + targetInfo.cuts.cutLabelBlockInProcedure.map(_._1) ++ sourceInfo.cuts.cutLabelBlockInProcedure.map(_._1) + + val alwaysInv = List( + CompatArg(TransitionSystem.programCounterVar, TransitionSystem.programCounterVar), + CompatArg(TransitionSystem.traceVar, TransitionSystem.traceVar) + ) + + val invEverywhere = cuts.toList.map(label => Inv.CutPoint(label, alwaysInv, Some(s"GlobalConstraint$label"))) + + // val factsInvariant = getFlowFactsInvariant(sourceInfo, targetInfo, invariant.flowFacts(sourceInfo.name)) + + val concreteInvariant = equalVarsInvariant ++ invEverywhere + + concreteInvariant + } + + private def validateSMTSingleProc( + config: TVJob, + interproc: InterproceduralInfo, + runName: String, + splitName: String, + procTransformed: Procedure, + invariant: InvariantDescription, + concreteInvariant: List[Inv], + sourceInfo: ProcInfo, + targetInfo: ProcInfo + ): TVResult = { + val runNamePrefix = "$" + runName + "$" + procTransformed.name + "$" + splitName + val proc = procTransformed + + tvLogger.info("Generating TV for : " + runNamePrefix) + + val timer = PerformanceTimer(runNamePrefix, LogLevel.DEBUG, tvLogger) + + val source = sourceInfo.transition // afterProg.get.procedures.find(_.name == proc.name).get + val target = targetInfo.transition // 
beforeProg.get.procedures.find(_.name == proc.name).get + + // val preInv = invariant + + val ackInv = + Ackermann.instantiateAxioms( + sourceInfo.transition.entryBlock.get, + targetInfo.transition.entryBlock.get, + exprInSource, + exprInTarget, + invariant.renamingSrcTgt + ) + + val preInv = (concreteInvariant.map( + invToPredicateInState( + e => sourceInfo.renameSSA(sourceInfo.cuts.cutLabelBlockInTr("ENTRY").label, e), + e => targetInfo.renameSSA(targetInfo.cuts.cutLabelBlockInTr("ENTRY").label, e) + ) + ) ++ Seq( + Assume( + BinaryExpr( + EQ, + exprInSource(TransitionSystem.programCounterVar), + exprInTarget(TransitionSystem.programCounterVar) + ), + Some("GLOBALINVSOURCE") + ) + )) + + val primedInv = concreteInvariant + .map( + invToPredicateInState( + e => sourceInfo.renameSSA(sourceInfo.cuts.cutLabelBlockInTr("EXIT").label, e), + e => targetInfo.renameSSA(targetInfo.cuts.cutLabelBlockInTr("EXIT").label, e) + ) + ) + .map(_.body) + + visit_proc(afterRenamer, source) + visit_proc(beforeRenamer, target) + + SSADAG.passify(source) + SSADAG.passify(target) + + timer.checkPoint("passify") + + // build smt query + var b = translating.BasilIRToSMT2.SMTBuilder() + val solver = config.verify.map(solver => util.SMT.SMTSolver(Some(1000), solver)) + val prover = solver.map(_.getProver(true)) + + // b.addCommand("set-logic", "QF_UFBV") + + var count = 0 + + lazy val newProg = if (config.debugDumpAlways || config.verify.isDefined) { + Some(combineProcs(source, target)) + } else None + lazy val npe = newProg.map(_.mainProcedure.entryBlock.get) + + val cutR = sourceInfo.cutRestict.foreach(cutLabel => { + b.addAssert(exprInSource(BinaryExpr(EQ, PCMan.PCSym(cutLabel), TransitionSystem.programCounterVar))) + }) + + count = 0 + for (e <- preInv) { + count += 1 + val l = e.comment match { + case None => Some(s"inv$count") + case Some(s) => Some(s"${s.replace(' ', '_')}_inv$count") + } + b.addAssert(e.body, Some(s"inv$count")) + prover.map(_.addConstraint(e.body)) + 
npe.map(_.statements.append(Assert(e.body, l))) + } + + count = 0 + for ((ack, ackn) <- ackInv) { + count += 1 + val l = Some(s"ackermann$ackn$count") + npe.map(_.statements.append(Assert(ack, l))) + prover.map(_.addConstraint(ack)) + b.addAssert(ack, l) + } + count = 0 + for (i <- extractProg(source)) { + count += 1 + b.addAssert(i, Some(s"source$count")) + prover.map(_.addConstraint(i)) + } + count = 0 + for (i <- extractProg(target)) { + count += 1 + prover.map(_.addConstraint(i)) + b.addAssert(i, Some(s"tgt$count")) + } + + val sourceAssumeFail = + BinaryExpr( + EQ, + exprInSource( + sourceInfo.renameSSA( + sourceInfo.cuts.cutLabelBlockInTr("EXIT").label.stripPrefix("source__"), + TransitionSystem.programCounterVar + ) + ), + PCMan.PCSym(PCMan.assumptionFailLabel) + ) + + val pinv = UnaryExpr(BoolNOT, BinaryExpr(BoolOR, sourceAssumeFail, AssocExpr(BoolAND, primedInv.toList))) + npe.map(_.statements.append(Assert(pinv, Some("InvPrimed")))) + b.addAssert(pinv, Some("InvPrimed")) + + val requ = Seq("ASSUMEFAIL", "ASSERTFAIL", "ENTRY", "EXIT") + val pcPost = boolOr( + (targetInfo.cutBlockLabels.keys ++ sourceInfo.cutBlockLabels.keys ++ requ).toSet.toList + .map(cutLabel => PCMan.PCSym(cutLabel)) + .map(cutSym => BinaryExpr(EQ, cutSym, TransitionSystem.programCounterVar)) + ) + + val s = exprInSource( + sourceInfo.renameSSA(sourceInfo.cuts.cutLabelBlockInTr("EXIT").label.stripPrefix("source__"), pcPost) + ) + b.addAssert(s, Some("PCDomainPostSource")) + + // val rn = renameSrcSSA((BinaryExpr(EQ, TransitionSystem.programCounterVar, PCMan.PCSym(cutLabel)))) + b.addAssert( + exprInTarget( + targetInfo.renameSSA(targetInfo.cuts.cutLabelBlockInTr("EXIT").label.stripPrefix("target__"), pcPost) + ), + Some("PCDomainPostTarget") + ) + + prover.map(_.addConstraint(pinv)) + timer.checkPoint("extract prog") + + val smtPath = config.outputPath.map(f => s"$f/${runNamePrefix}.smt2") + + smtPath.foreach(fname => { + b.writeCheckSatToFile(File(fname)) + tvLogger.info(s"Write query 
$fname") + timer.checkPoint("writesmtfile") + }) + + val verified = prover.map(prover => { + val r = prover.checkSat(Some(1000)) + timer.checkPoint("checksat") + (prover, r) + }) + + config.outputPath.foreach(path => { + tvLogger.writeToFile(File(s"${path}/${runNamePrefix}.il"), translating.PrettyPrinter.pp_proc(procTransformed)) + }) + + if (config.debugDumpAlways) { + config.outputPath.foreach(path => { + newProg.foreach(newProg => + tvLogger.writeToFile( + File(s"${path}/${runNamePrefix}-combined.il"), + translating.PrettyPrinter.pp_prog(newProg) + ) + ) + // tvLogger.writeToFile(File(s"${path}/${runNamePrefix}.il"), translating.PrettyPrinter.pp_proc(procTransformed)) + }) + } + + verified.foreach((prover, _) => { + val res = prover.checkSat() + res match { + case SatResult.UNSAT => tvLogger.info("unsat") + case SatResult.SAT(m) => { + tvLogger.error(s"sat ${runNamePrefix} (verify failed)") + + val g = processModel( + sourceInfo, + targetInfo, + newProg.get.mainProcedure, + prover, + primedInv.toList, + invariant.renamingSrcTgt, + source.entryBlock.get.label, + target.entryBlock.get.label + ) + + config.outputPath.foreach(path => { + tvLogger.writeToFile(File(s"${path}/${runNamePrefix}-counterexample-combined-${proc.name}.dot"), g) + if (!config.debugDumpAlways) { + tvLogger.writeToFile( + File(s"${path}/${runNamePrefix}-combined.il"), + translating.PrettyPrinter.pp_prog(newProg.get) + ) + tvLogger.writeToFile( + File(s"${path}/${runNamePrefix}.il"), + translating.PrettyPrinter.pp_proc(procTransformed) + ) + } + }) + } + case SatResult.Unknown(m) => tvLogger.info(s"unknown: $m") + } + timer.checkPoint("model-extract-debug") + }) + + prover.foreach(prover => { + prover.close() + solver.get.close() + }) + + if (config.verify.contains(Solver.CVC5)) { + io.github.cvc5.Context.deletePointers() + } + + // throw away transition system + source.clearBlocks() + target.clearBlocks() + + TVResult(runName, proc.name, verified.map(_._2), smtPath, timer.checkPoints().toMap) + 
} + + /** + * Generate an SMT query for the product program, + * + * @param invariant.renamingSrcTgt + * function describing the transform as mapping from variable -> expression at a given block in the resulting program, + * using a lambda Option[BlockId] => Variable | Memory => Option[Expr] + * + * @param filePrefix + * Where filepath prefix where the SMT query is written to. + * + * + * Returns a map from proceudre -> smt query + */ + def getValidationSMT( + program: Program, + config: TVJob, + runName: String, + targetProgClone: Program, + sourceProgClone: Program, + invariant: InvariantDescription + ): TVJob = { + + val framesTarget = inferProcFrames(targetProgClone).map((k, v) => (k.name, v)).toMap + val framesSource = inferProcFrames(sourceProgClone).map((k, v) => (k.name, v)).toMap + + val interesting = program.procedures + .filterNot(_.isExternal.contains(true)) + .filterNot(_.procName.startsWith("indirect_call_launchpad")) + .filter(n => framesTarget.contains(n.name)) + .filter(n => framesSource.contains(n.name)) + + val paramMapping: Map[String, (CallParamMapping, CallParamMapping)] = + getFunctionSigsRenaming(sourceProgClone, framesSource, invariant.renamingSrcTgt) + + val sourceParams: Map[String, CallParamMapping] = paramMapping.toSeq.map { case (pn, (source, target)) => + (pn, source) + }.toMap + val targetParams: Map[String, CallParamMapping] = paramMapping.toSeq.map { case (pn, (source, target)) => + (pn, target) + }.toMap + + val interproc = InterproceduralInfo(sourceProgClone, framesSource, framesTarget, sourceParams, targetParams) + + def procToTrInplace(p: Procedure, params: Map[String, CallParamMapping], introducedAsserts: Set[String]) = { + + ir.transforms.reversePostOrder(p) + val liveVars: Map[String, Set[Variable]] = getLiveVars(p, params) + (PCMan.assertFailBlockLabel -> Set()) + + val cuts = TransitionSystem.toTransitionSystemInPlace(p) + + TransitionSystem.totaliseAsserts(p, introducedAsserts) + + 
TransitionSystem.removeUnreachableBlocks(p) + + SSADAG.convertToMonadicSideEffect(params, p) + + ProcInfo(p.name, p, liveVars, cuts, params(p.name), (_, _) => ???, Map()) + } + + def ssaProcInfo(p: ProcInfo, params: Map[String, CallParamMapping]) = { + ir.transforms.reversePostOrder(p.transition) + + // val liveVars: Map[String, Set[Variable]] = getLiveVars(p.transition, params) + + val (renameSSA, defines) = SSADAG.ssaTransform(p.transition, p.liveVars) + + p.copy(ssaRenamingFun = renameSSA, ssaDefines = defines) + } + + var result = config + interesting.foreach(proc => { + val sourceProc = sourceProgClone.procedures.find(_.name == proc.name).get + val targetProc = targetProgClone.procedures.find(_.name == proc.name).get + + val source = procToTrInplace(sourceProc, sourceParams, invariant.introducedAsserts) + val target = procToTrInplace(targetProc, targetParams, Set()) + + val runNamePrefix = runName + "-" + proc.name + tvLogger.debug(runNamePrefix) + + config.outputPath.foreach(path => { + tvLogger.writeToFile(File(s"${path}/${runNamePrefix}.il"), translating.PrettyPrinter.pp_proc(proc)) + }) + + val concreteInvariant = inferInvariant(interproc, invariant, source, target) + + if (!config.dryRun) { + if ( + config.splitLargeProceduresThreshold + .exists(_ <= (procComplexity(source.transition) + procComplexity(target.transition))) + ) { + + var splits = 0 + val splitProcs = chooseCuts(source, target) + tvLogger.info(s"Splitting ${proc.name} into ${splitProcs.size} elements") + for ((sourceSplit, targetSplit) <- splitProcs) { + splits += 1 + + val sourceSSA = ssaProcInfo(sourceSplit, sourceParams) + val targetSSA = ssaProcInfo(targetSplit, targetParams) + + val res = validateSMTSingleProc( + result, + interproc, + runName, + "split-" + splits, + proc, + invariant, + concreteInvariant, + sourceSSA, + targetSSA + ) + result = result.copy(results = res :: result.results) + } + + } else { + val sourceSSA = ssaProcInfo(source, sourceParams) + val targetSSA = 
ssaProcInfo(target, targetParams) + + val res = validateSMTSingleProc( + result, + interproc, + runName, + "thesplit", + proc, + invariant, + concreteInvariant, + sourceSSA, + targetSSA + ) + result = result.copy(results = res :: result.results) + } + } + }) + + result + } + + def procComplexity(p: Procedure) = { + p.foldLeft(1)((a, _) => a + 1) + } + + def chooseCuts(source: ProcInfo, target: ProcInfo) = { + + require(source.cutBlockLabels.keys.toSet == target.cutBlockLabels.keys.toSet) + + def findCut(b: Block, thecuts: Map[BlockID, String]) = { + var search: Block = b + + while (!thecuts.contains(search.label)) { + assert(search.nextBlocks.size == 1) + + search = search.nextBlocks.head + + } + + thecuts(search.label) -> b + } + + val sourceCuts = source.cutBlockLabels.map((lbl, block) => (block, lbl)).toMap + val targetCuts = target.cutBlockLabels.map((lbl, block) => (block, lbl)).toMap + + val nextS = source.transition.entryBlock.get.nextBlocks.map(findCut(_, sourceCuts)) + val nextT = target.transition.entryBlock.get.nextBlocks.map(findCut(_, targetCuts)) + + assert(nextS.map(_._1).toSet.size == nextT.map(_._1).toSet.size) + assert(nextS.map(_._2).toSet.size == nextT.map(_._2).toSet.size) + val targetBlockForCut = nextT.toMap + + /** + * slice: the only target of proc.entryBlock to keep + * - must be a successor of proc.transition.entryBlock + */ + def cloneWithSlice(slice: BlockID, cutLabel: String, proc: ProcInfo) = { + val np = ir.dsl.IRToDSL.cloneSingleProcedure(proc.transition) + + val bl = proc.transition.blocks.find(_.label == slice).get + + val jump = np.mainProcedure.entryBlock.get.jump match { + case g: GoTo => g + case _ => throw Exception("unexpected") + } + + val toRemove = jump.targets.filterNot(_.label == slice).toList + toRemove.foreach(jump.removeTarget) + + TransitionSystem.removeUnreachableBlocks(np.mainProcedure) + proc.copy(transition = np.mainProcedure, cutRestict = Some(cutLabel)) + } + + nextS.map { + case (cutLabel, sourceBlock) => { + 
val newsource = cloneWithSlice(sourceBlock.label, cutLabel, source) + + val targetBlock = targetBlockForCut(cutLabel).label + + val newtarget = cloneWithSlice(targetBlock, cutLabel, target) + + (newsource, newtarget) + } + } + } + + def forTransform[T]( + transformName: String, + transform: Program => T, + invariant: T => InvariantDescription = (_: T) => InvariantDescription() + ): ((Program, TVJob) => TVJob) = { (p: Program, tvconf: TVJob) => + { + val before = ir.dsl.IRToDSL.convertProgram(p).resolve + + val beforeprocs = before.nameToProcedure + for (p <- p.procedures) { + assert(p.blocks.map(_.label).corresponds(beforeprocs(p.procName).blocks.map(_.label).toList)(_.equals(_))) + } + + val r = transform(p) + val inv = invariant(r) + val after = ir.dsl.IRToDSL.convertProgram(p).resolve + getValidationSMT(p, tvconf, transformName, before, after, inv) + } + } + +} diff --git a/src/main/scala/specification/Specification.scala b/src/main/scala/specification/Specification.scala index 4734fb4866..684b543d55 100644 --- a/src/main/scala/specification/Specification.scala +++ b/src/main/scala/specification/Specification.scala @@ -1,9 +1,10 @@ package specification import boogie.* -import ir.* +import ir.Expr import ir.dsl.given +case class LoopInvariant(header: String, inv: List[Expr], comment: Option[String] = None) trait SymbolTableEntry { val name: String val size: Int diff --git a/src/main/scala/translating/GTIRBLoader.scala b/src/main/scala/translating/GTIRBLoader.scala index 7026ed5be2..d8f9a2fa25 100644 --- a/src/main/scala/translating/GTIRBLoader.scala +++ b/src/main/scala/translating/GTIRBLoader.scala @@ -386,6 +386,7 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) { case "and_bits.0" => resolveBinaryOp(BVAND, function, 1, typeArgs, args, ctx.getText) case "eor_bits.0" => resolveBinaryOp(BVXOR, function, 1, typeArgs, args, ctx.getText) case "eq_bits.0" => resolveBinaryOp(EQ, function, 1, typeArgs, args, ctx.getText) + case "ne_bits.0" 
=> resolveBinaryOp(NEQ, function, 1, typeArgs, args, ctx.getText) case "add_bits.0" => resolveBinaryOp(BVADD, function, 1, typeArgs, args, ctx.getText) case "sub_bits.0" => resolveBinaryOp(BVSUB, function, 1, typeArgs, args, ctx.getText) case "mul_bits.0" => resolveBinaryOp(BVMUL, function, 1, typeArgs, args, ctx.getText) diff --git a/src/main/scala/translating/GTIRBToIR.scala b/src/main/scala/translating/GTIRBToIR.scala index 4ec4175e64..68b85aa1aa 100644 --- a/src/main/scala/translating/GTIRBToIR.scala +++ b/src/main/scala/translating/GTIRBToIR.scala @@ -346,7 +346,7 @@ class GTIRBToIR( block.address.foreach { addr => val pcCorrectExpr = BinaryExpr(EQ, Register("_PC", 64), BitVecLiteral(addr, 64)) - val assertPC = Assert(pcCorrectExpr, Some("pc-tracking"), Some("pc-tracking")) + val assertPC = Assert(pcCorrectExpr, Some("pc-tracking " + pcCorrectExpr), Some("pc-tracking")) block.statements.append(assertPC) } block @@ -723,11 +723,18 @@ class GTIRBToIR( ): (Option[Call], Jump) = { // TODO add assertion that target register is low val label = handlePCAssign(block) + val targetRegister = getPCTarget(block) val withinProcedureTargets = targets.collect { case t: Block if procedure.blocks.contains(t) => t } + val assertion = boolOr(targets.map(target => BinaryExpr(EQ, targetRegister, BitVecLiteral(target.address.get, 64)))) if (withinProcedureTargets.size == targets.size) { // all target blocks are within the calling procedure + for (target <- targets) { + val assume = Assume(BinaryExpr(EQ, targetRegister, BitVecLiteral(target.address.get, 64))) + target.statements.prepend(assume) + } + block.statements.append(Assert(assertion)) (None, GoTo(targets, label)) } else if (withinProcedureTargets.nonEmpty) { // TODO - only some target blocks are within the calling procedure - unclear how to handle @@ -748,7 +755,6 @@ class GTIRBToIR( } else { // indirect jump targeting multiple blocks outside this procedure val newBlocks = ArrayBuffer[Block]() - val targetRegister = 
getPCTarget(block) for (targetBlock <- targets) { if (!targetBlock.isEntry) { @@ -761,11 +767,13 @@ class GTIRBToIR( val resolvedCall = DirectCall(target) val assume = Assume(BinaryExpr(EQ, targetRegister, BitVecLiteral(target.address.get, 64))) + val label = block.label + "_" + target.name // unreachable because R30 is not set and any returns will need to be manually resolved through later analysis newBlocks.append(Block(label, None, ArrayBuffer(assume, resolvedCall), Unreachable())) } handlePCAssign(block) + block.statements.append(Assert(assertion)) procedure.addBlocks(newBlocks) (None, GoTo(newBlocks)) } @@ -804,6 +812,12 @@ class GTIRBToIR( } handlePCAssign(block) procedure.addBlocks(newBlocks) + val assertion = boolOr( + indirectCallTargets.map(target => + BinaryExpr(EQ, targetRegister, BitVecLiteral(entranceUUIDtoProcedure(target.targetUuid).address.get, 64)) + ) + ) + block.statements.append(Assert(assertion)) GoTo(newBlocks) } diff --git a/src/main/scala/translating/IRExpToSMT2.scala b/src/main/scala/translating/IRExpToSMT2.scala index 8d93c9841e..00e737ce87 100644 --- a/src/main/scala/translating/IRExpToSMT2.scala +++ b/src/main/scala/translating/IRExpToSMT2.scala @@ -2,7 +2,9 @@ package translating import ir.* import ir.cilvisitor.* -import util.{OnCrash, RingTrace} +import util.Logger + +import java.io.{BufferedWriter, File, FileWriter, Writer} trait BasilIR[Repr[+_]] extends BasilIRExp[Repr] { // def vstmt(s: Statement) : Repr[Statement] @@ -47,7 +49,7 @@ trait BasilIR[Repr[+_]] extends BasilIRExp[Repr] { case ZeroExtend(bits, arg) => vzeroextend(bits, vexpr(arg)) case SignExtend(bits, arg) => vsignextend(bits, vexpr(arg)) case BinaryExpr(op, arg, arg2) => vbinary_expr(op, vexpr(arg), vexpr(arg2)) - case b @ AssocExpr(op, arg) => vexpr(b.toBinaryExpr) + case b @ AssocExpr(op, args) => vassoc_expr(op, args.map(vexpr)) case UnaryExpr(op, arg) => vunary_expr(op, vexpr(arg)) case v: Variable => vrvar(v) case f @ FApplyExpr(n, params, rt, _) => 
vfapply_expr(n, params.map(vexpr)) @@ -114,6 +116,7 @@ trait BasilIR[Repr[+_]] extends BasilIRExp[Repr] { trait BasilIRExp[Repr[+_]] { def vexpr(e: Expr): Repr[Expr] + def vassoc_expr(o: BoolBinOp, es: List[Repr[Expr]]): Repr[Expr] def vextract(ed: Int, start: Int, a: Repr[Expr]): Repr[Expr] def vquantifier(q: QuantifierExpr): Repr[Expr] def vlambda(q: LambdaExpr): Repr[Expr] @@ -121,7 +124,6 @@ trait BasilIRExp[Repr[+_]] { def vzeroextend(bits: Int, b: Repr[Expr]): Repr[Expr] def vsignextend(bits: Int, b: Repr[Expr]): Repr[Expr] def vbinary_expr(e: BinOp, l: Repr[Expr], r: Repr[Expr]): Repr[Expr] - def vbool_expr(e: BoolBinOp, l: List[Repr[Expr]]): Repr[Expr] def vunary_expr(e: UnOp, arg: Repr[Expr]): Repr[Expr] def vliteral(l: Literal): Repr[Literal] = { l match { @@ -164,7 +166,7 @@ trait BasilIRExpWithVis[Repr[+_]] extends BasilIRExp[Repr] { } case UnaryExpr(op, arg) => vunary_expr(op, vexpr(arg)) case v: Variable => vrvar(v) - case b @ AssocExpr(op, args) => vbool_expr(op, args.map(vexpr)) + case b @ AssocExpr(op, args) => vassoc_expr(op, args.map(vexpr)) case r: SharedMemory => ??? case r: StackMemory => ??? 
case f @ FApplyExpr(n, params, rt, _) => vfapply_expr(n, params.map(vexpr)) @@ -178,11 +180,24 @@ trait BasilIRExpWithVis[Repr[+_]] extends BasilIRExp[Repr] { enum Sexp[+T] { case Symb(v: String) - case Slist(v: List[Sexp[T]]) + case Slist(v: Iterable[Sexp[T]]) } object Sexp { + def write[T](b: Writer, s: Sexp[T]): Unit = s match { + case Sexp.Symb(a) => b.append(a) + case Sexp.Slist(v) => { + b.append("(") + for (s <- v) { + write(b, s) + b.append(" ") + } + b.append(")") + } + + } + def print[T](s: Sexp[T]): String = s match { case Sexp.Symb(a) => a case Sexp.Slist(v) => "(" + v.map(print).mkString(" ") + ")" @@ -190,13 +205,10 @@ object Sexp { } def sym[T](l: String): Sexp[T] = Sexp.Symb[T](l) -def list[T](l: Sexp[T]*): Sexp[T] = Sexp.Slist(l.toList) +def list[T](l: Sexp[T]*): Sexp[T] = Sexp.Slist(l) -val dumpTrace = RingTrace[String](3, "BasilIRToSMT2") object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { - OnCrash.register(dumpTrace) - def vload(lhs: Sexp[Variable], mem: String, index: Sexp[Expr], endian: Endian, size: Int): Sexp[MemoryLoad] = ??? def vstore(mem: Memory, index: Sexp[Expr], value: Sexp[Expr], endian: Endian, size: Int): Sexp[MemoryStore] = ??? @@ -212,25 +224,40 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { def vintlit(b: BigInt): Sexp[IntLiteral] = ??? 
class SMTBuilder() { + + enum Cmd { + case AssertExp(e: Expr, n: Option[String]) + case Raw(e: Sexp[Expr]) + + def toSexp = this match { + case Cmd.Raw(e) => e + case Cmd.AssertExp(e, name) => { + val expr: Sexp[Expr] = BasilIRToSMT2.vexpr(e) + val inner: Sexp[Expr] = name.map(n => list(sym("!"), expr, sym(":named"), sym(n))).getOrElse(expr) + list(sym("assert"), inner) + } + } + } + var before = true - var exprs = Vector[Sexp[Expr]]() + var exprs = List[Cmd]() var exprsBefore = Vector[Sexp[Expr]]() var decls = Set[Sexp[Expr]]() var typedecls = Set[Sexp[Expr]]() - def addAssume(e: Expr) = { - before = false - val (t, d) = BasilIRToSMT2.extractDecls(e) - decls = decls ++ d - typedecls = typedecls ++ t - exprs = exprs ++ List(list(sym("assume"), BasilIRToSMT2.vexpr(e))) - } + // def addAssume(e: Expr) = { + // before = false + // val (t, d) = BasilIRToSMT2.extractDecls(e) + // decls = decls ++ d + // typedecls = typedecls ++ t + // exprs = exprs ++ List(list(sym("assume"), BasilIRToSMT2.vexpr(e))) + // } def addCommand(rawSexp: String*) = { if (before) { exprsBefore = exprsBefore.appended(list(rawSexp.map(sym[Expr](_)): _*)) } else { - exprs = exprs.appended(list(rawSexp.map(sym[Expr](_)): _*)) + exprs = Cmd.Raw(list(rawSexp.map(sym[Expr](_)): _*)) :: exprs } } @@ -239,18 +266,53 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { val (t, d) = BasilIRToSMT2.extractDecls(e) decls = decls ++ d typedecls = typedecls ++ t - val expr: Sexp[Expr] = BasilIRToSMT2.vexpr(e) - val inner: Sexp[Expr] = name.map(n => list(sym("!"), expr, sym(":named"), sym(n))).getOrElse(expr) + exprs = Cmd.AssertExp(e, name) :: exprs + } + + def writeCheckSat(b: Writer, getUnsatCore: Boolean = false) = { + val setUnsat = + if getUnsatCore then Seq(list(sym("set-option"), sym(":produce-unsat-cores"), sym("true"))) else Seq() + val getUnsat = if getUnsatCore then Seq(list(sym("get-unsat-core"))) else Seq() + + def psexp(p: Sexp[Expr]) = { + Sexp.write(b, p) + b.append("\n") + } - exprs = 
exprs ++ List(list(sym("assert"), inner)) + setUnsat.foreach(psexp) + exprsBefore.foreach(psexp) + typedecls.foreach(psexp) + decls.foreach(psexp) + exprs.foreach(e => psexp(e.toSexp)) + // b.append("(check-sat-using (then (repeat (then (repeat (then euf-completion simplify)) (par-or (try-for smt 5000) skip))) smt))") + psexp(list(sym("check-sat"))) + getUnsat.foreach(psexp) } - def getCheckSat() = { - (exprsBefore.toVector ++ typedecls ++ decls ++ exprs ++ List(list(sym("check-sat")))) - .map(Sexp.print) - .mkString("\n") + def writeCheckSatToFile(fname: File, getUnsatCore: Boolean = false): Unit = { + val fw = FileWriter(fname) + val f = BufferedWriter(fw) + try { + writeCheckSat(f, getUnsatCore) + } finally { + if (f != null) { + f.close() + } + if (fw != null) { + fw.close() + } + } } + // def getCheckSat(getUnsatCore: Boolean = false) = { + // val setUnsat = + // if getUnsatCore then Seq(list(sym("set-option"), sym(":produce-unsat-cores"), sym("true"))) else Seq() + // val getUnsat = if getUnsatCore then Seq(list(sym("get-unsat-core"))) else Seq() + + // setUnsat.iterator ++ exprsBefore.iterator ++ typedecls ++ decls ++ exprs.view.map(_()) ++ Seq( + // list(sym("check-sat")) + // ) ++ getUnsat + // } } /** Immediately invoke z3 and block until it returns a result. 
@@ -326,11 +388,9 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { override def vextract(ed: Int, start: Int, a: Sexp[Expr]): Sexp[Expr] = list(list(sym("_"), sym("extract"), int2smt(ed - 1), int2smt(start)), a) override def vbinary_expr(e: BinOp, l: Sexp[Expr], r: Sexp[Expr]): Sexp[Expr] = { - dumpTrace.add(e.toString + "(" + l + "," + r + ")") list(sym(opnameToFun(e)), l, r) } - override def vbool_expr(e: BoolBinOp, l: List[Sexp[Expr]]): Sexp[Expr] = - dumpTrace.add(e.toString + "(" + l.mkString(",") + ")") + override def vassoc_expr(e: BoolBinOp, l: List[Sexp[Expr]]): Sexp[Expr] = Sexp.Slist(sym(opnameToFun(e)) :: l) override def vunary_expr(e: UnOp, arg: Sexp[Expr]): Sexp[Expr] = list(sym(unaryOpnameToFun(e)), arg) @@ -341,14 +401,25 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { case FalseLiteral => sym("false") } + def mkIte(cases: List[Seq[Sexp[Expr]]]): Sexp[Expr] = { + cases match { + case Seq(cond, casev) :: Nil => casev + case Seq(cond, casev) :: tl => list(sym("ite"), (cond), (casev), mkIte(tl)) + } + } + def endianToBool(endian: Endian): Sexp[Expr] = { if endian == Endian.LittleEndian then vexpr(FalseLiteral) else vexpr(TrueLiteral) } override def vfapply_expr(name: String, args: Seq[Sexp[Expr]]): Sexp[Expr] = { - if (args.size == 1) { - list(sym(name), args.head) - } else { - list(sym(name), Sexp.Slist(args.toList)) + name match { + case "ite" => { + val cases = args.grouped(2) + mkIte(cases.toList) + } + case _ => { + list(sym(name) :: args.toList: _*) + } } } @@ -376,6 +447,7 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { def interpretFun(x: FApplyExpr): Option[Sexp[Expr]] = { x.name match { + case "ite" => None case "bool2bv1" => { Some(booltoBVDef) } @@ -390,6 +462,9 @@ object BasilIRToSMT2 extends BasilIRExpWithVis[Sexp] { ) ) } + case _ => + Logger.warn(s"Undeclared uninterp emitted : ${x.name}") + None } } diff --git a/src/main/scala/translating/PrettyPrinter.scala b/src/main/scala/translating/PrettyPrinter.scala 
index 4668b2e91b..4a2e1b89e9 100644 --- a/src/main/scala/translating/PrettyPrinter.scala +++ b/src/main/scala/translating/PrettyPrinter.scala @@ -542,6 +542,7 @@ class BasilIRPrettyPrinter( override def vindirect(target: PPProg[Variable]): PPProg[IndirectCall] = BST(s"indirect call ${target} ") override def vassert(body: Assert): PPProg[Assert] = { + BST(s"assert ${vexpr(body.body)}") } @@ -592,7 +593,7 @@ class BasilIRPrettyPrinter( val opn = e.getClass.getSimpleName.toLowerCase.stripSuffix("$") BST(s"$opn($l, $r)") } - override def vbool_expr(e: BoolBinOp, l: List[PPProg[Expr]]): PPProg[Expr] = { + override def vassoc_expr(e: BoolBinOp, l: List[PPProg[Expr]]): PPProg[Expr] = { val opn = e.getClass.getSimpleName.toLowerCase.stripSuffix("$") BST(s"$opn(${l.mkString(",")})") } diff --git a/src/main/scala/translating/offlineLifter/InstructionBuilder.scala b/src/main/scala/translating/offlineLifter/InstructionBuilder.scala index 540a61f6b8..0393fa3765 100644 --- a/src/main/scala/translating/offlineLifter/InstructionBuilder.scala +++ b/src/main/scala/translating/offlineLifter/InstructionBuilder.scala @@ -82,7 +82,12 @@ trait LifterIFace[L] extends LiftState[Expr, L, BitVecLiteral] { smt_bvlshr(arg0, BitVecLiteral(arg1.value, arg0.size)) def f_decl_bool(arg0: String): Expr = LocalVar(b.fresh_local + "_bool", BoolType) - def f_decl_bv(arg0: String, arg1: BigInt): Expr = LocalVar(b.fresh_local + "_bv" + arg1, BitVecType(arg1.toInt)) + def f_decl_bv(arg0: String, arg1: BigInt): Expr = { + val v = LocalVar(b.fresh_local + "_bv" + arg1, BitVecType(arg1.toInt)) + // FIXME: shouldn't need always clear + b.push_stmt(LocalAssign(v, BitVecLiteral(0, arg1.toInt))) + v + } def f_AtomicEnd(): Expr = LocalVar("ATOMICEND", BoolType) def f_AtomicStart(): Expr = LocalVar("ATOMICSTART", BoolType) diff --git a/src/main/scala/translating/offlineLifter/OfflineLifter.scala b/src/main/scala/translating/offlineLifter/OfflineLifter.scala index 96964dbbb8..e7f2c9f1c6 100644 --- 
a/src/main/scala/translating/offlineLifter/OfflineLifter.scala +++ b/src/main/scala/translating/offlineLifter/OfflineLifter.scala @@ -116,6 +116,7 @@ class StmtListLifter extends LifterIFace[Int] { object Lifter { def liftBlockBytes(ops: Iterable[Int], initialSp: BigInt): Seq[Seq[Statement]] = { + var sp = initialSp ops.toSeq.map { op => val ins = if (op == 0xd503201f.toInt) { @@ -124,10 +125,15 @@ object Lifter { Seq(Assert(FalseLiteral, Some(s"aarch64_system_exceptions_debug_breakpoint (0x$op)"))) } else { try { + val checker = ir.invariant.ReadUninitialised() val lift = StmtListLifter() lift.builder.pcValue = sp f_A64_decoder[Expr, Int, BitVecLiteral](lift, BitVecLiteral(BigInt(op), 32), BitVecLiteral(sp, 64)) - lift.extract.toSeq + val stmts = lift.extract.toSeq + checker.readUninitialised(stmts) + val r = checker.getResult() + if (r.isDefined) throw Exception(r.get + "\n" + stmts.mkString("\n")) + stmts } catch { case e => { val o = "%x".format(op) @@ -137,7 +143,8 @@ object Lifter { } } } - sp = sp + 32 + sp += 4 + ins.foreach(s => s.comment = s.comment.orElse(Some("op: " + "0x%x".format(op)))) ins } } diff --git a/src/main/scala/util/BASILConfig.scala b/src/main/scala/util/BASILConfig.scala index 49c7d0d8d5..99a8ece41e 100644 --- a/src/main/scala/util/BASILConfig.scala +++ b/src/main/scala/util/BASILConfig.scala @@ -10,6 +10,16 @@ enum PCTrackingOption { case None, Keep, Assert } +enum SimplifyMode { + case Disabled + case Simplify + case ValidatedSimplify( + verify: Option[util.SMT.Solver] = Some(util.SMT.Solver.Z3), + dumpSMT: Option[String] = None, + dryRun: Boolean = false + ) +} + case class BoogieGeneratorConfig( memoryFunctionType: BoogieMemoryAccessMode = BoogieMemoryAccessMode.SuccessiveStoreSelect, coalesceConstantMemory: Boolean = true, @@ -77,7 +87,7 @@ case class BASILConfig( context: Option[IRContext] = None, loading: ILLoadingConfig, runInterpret: Boolean = false, - simplify: Boolean = false, + simplify: SimplifyMode = SimplifyMode.Disabled, 
validateSimp: Boolean = false, tvSimp: Boolean = false, dsaConfig: Option[DSConfig] = None, diff --git a/src/main/scala/util/Logging.scala b/src/main/scala/util/Logging.scala index 2b1eeb07d1..e954067912 100644 --- a/src/main/scala/util/Logging.scala +++ b/src/main/scala/util/Logging.scala @@ -189,3 +189,5 @@ val SVALogger = DSALogger.deriveLogger("SVA").setLevel(LogLevel.OFF) val IntervalDSALogger = DSALogger.deriveLogger("SadDSA", Console.out).setLevel(LogLevel.OFF) val condPropDebugLogger = SimplifyLogger.deriveLogger("inlineCond") val StackLogger = Logger.deriveLogger("Stack").setLevel(LogLevel.OFF) +val tvLogger = Logger.deriveLogger("TranslationValidation").setLevel(LogLevel.INFO) +val tvEvalLogger = Logger.deriveLogger("tv-eval").setLevel(LogLevel.INFO) diff --git a/src/main/scala/util/PerformanceTimer.scala b/src/main/scala/util/PerformanceTimer.scala index 41d57f27db..a5a664f2bd 100644 --- a/src/main/scala/util/PerformanceTimer.scala +++ b/src/main/scala/util/PerformanceTimer.scala @@ -21,7 +21,11 @@ case class RegionTimer(name: String) { } } -case class PerformanceTimer(timerName: String = "", logLevel: LogLevel = LogLevel.DEBUG) { +case class PerformanceTimer( + timerName: String = "", + logLevel: LogLevel = LogLevel.DEBUG, + logger: GenericLogger = Logger +) { private var lastCheckpoint: Long = System.currentTimeMillis() private var end: Long = 0 private val checkpoints: mutable.Map[String, Long] = mutable.HashMap() @@ -35,10 +39,10 @@ case class PerformanceTimer(timerName: String = "", logLevel: LogLevel = LogLeve checkpoints.put(name, delta) trace.add((name, delta)) logLevel match { - case LogLevel.DEBUG => Logger.debug(s"timer:$timerName [$name]: ${delta}ms") - case LogLevel.INFO => Logger.info(s"timer:$timerName [$name]: ${delta}ms") - case LogLevel.WARN => Logger.warn(s"timer:$timerName [$name]: ${delta}ms") - case LogLevel.ERROR => Logger.error(s"timer:$timerName [$name]: ${delta}ms") + case LogLevel.DEBUG => logger.debug(s"timer:$timerName 
[$name]: ${delta}ms") + case LogLevel.INFO => logger.info(s"timer:$timerName [$name]: ${delta}ms") + case LogLevel.WARN => logger.warn(s"timer:$timerName [$name]: ${delta}ms") + case LogLevel.ERROR => logger.error(s"timer:$timerName [$name]: ${delta}ms") case _ => ??? } delta diff --git a/src/main/scala/util/RunUtils.scala b/src/main/scala/util/RunUtils.scala index c5963b2512..126ed306ac 100644 --- a/src/main/scala/util/RunUtils.scala +++ b/src/main/scala/util/RunUtils.scala @@ -54,23 +54,26 @@ object RunUtils { Logger.info("[!] Loading Program") val q = conf var ctx = q.context.getOrElse(IRLoading.load(q.loading)) + assert(invariant.readUninitialised(ctx.program)) postLoad(ctx) // allows extracting information from the original loaded program - assert(ir.invariant.checkTypeCorrect(ctx.program)) + assert(invariant.checkTypeCorrect(ctx.program)) assert(invariant.singleCallBlockEnd(ctx.program)) assert(invariant.cfgCorrect(ctx.program)) assert(invariant.blocksUniqueToEachProcedure(ctx.program)) + assert(invariant.readUninitialised(ctx.program)) val analysisManager = AnalysisManager(ctx.program) - if conf.simplify then doCleanupWithSimplify(ctx, analysisManager) + if conf.simplify != SimplifyMode.Disabled then doCleanupWithSimplify(ctx, analysisManager) else doCleanupWithoutSimplify(ctx, analysisManager) - assert(ir.invariant.programDiamondForm(ctx.program)) + assert(invariant.programDiamondForm(ctx.program)) + assert(invariant.readUninitialised(ctx.program)) transforms.inlinePLTLaunchpad(ctx, analysisManager) - assert(ir.invariant.programDiamondForm(ctx.program)) + assert(invariant.programDiamondForm(ctx.program)) if q.loading.trimEarly then getStripUnreachableFunctionsTransform(q.loading.procedureTrimDepth)(ctx, analysisManager) @@ -81,9 +84,10 @@ object RunUtils { ctx.program.procedures.foreach(transforms.RemoveUnreachableBlocks.apply) Logger.info(s"[!] 
Removed unreachable blocks") - if (q.loading.parameterForm && !q.simplify) { + if (q.loading.parameterForm && !(q.simplify != SimplifyMode.Disabled)) { ir.transforms.clearParams(ctx.program) ctx = ir.transforms.liftProcedureCallAbstraction(ctx) + assert(ir.invariant.readUninitialised(ctx.program)) if (conf.assertCalleeSaved) { transforms.CalleePreservedParam.transform(ctx.program) } @@ -108,23 +112,38 @@ object RunUtils { assert(ir.invariant.programDiamondForm(ctx.program)) ir.eval.SimplifyValidation.validate = conf.validateSimp - if (conf.simplify) { - - ir.transforms.clearParams(ctx.program) - - ir.transforms.liftIndirectCall(ctx.program) - transforms.liftSVCompNonDetEarlyIR(ctx.program) - - DebugDumpIRLogger.writeToFile(File("il-after-indirectcalllift.il"), pp_prog(ctx.program)) - ctx = ir.transforms.liftProcedureCallAbstraction(ctx) - DebugDumpIRLogger.writeToFile(File("il-after-proccalls.il"), pp_prog(ctx.program)) - if (conf.assertCalleeSaved) { - transforms.CalleePreservedParam.transform(ctx.program) + conf.simplify match { + case c: SimplifyMode.ValidatedSimplify => { + ir.transforms.clearParams(ctx.program) + ir.transforms.liftIndirectCall(ctx.program) + DebugDumpIRLogger.writeToFile(File("il-beforetvsimp.il"), pp_prog(ctx.program)) + val (tvres, nctx) = transforms.validate.validatedSimplifyPipeline(ctx, conf.simplify) + ctx = nctx } - - assert(ir.invariant.programDiamondForm(ctx.program)) - doSimplify(ctx, conf.staticAnalysis) + case SimplifyMode.Simplify => { + assert(ir.invariant.readUninitialised(ctx.program)) + ir.transforms.clearParams(ctx.program) + assert(ir.invariant.readUninitialised(ctx.program)) + + ir.transforms.liftIndirectCall(ctx.program) + transforms.liftSVCompNonDetEarlyIR(ctx.program) + + assert(ir.invariant.readUninitialised(ctx.program)) + DebugDumpIRLogger.writeToFile(File("il-after-indirectcalllift.il"), pp_prog(ctx.program)) + ctx = ir.transforms.liftProcedureCallAbstraction(ctx) + 
DebugDumpIRLogger.writeToFile(File("il-after-proccalls.il"), pp_prog(ctx.program)) + + assert(ir.invariant.readUninitialised(ctx.program)) + if (conf.assertCalleeSaved) { + transforms.CalleePreservedParam.transform(ctx.program) + } + + assert(ir.invariant.readUninitialised(ctx.program)) + assert(ir.invariant.programDiamondForm(ctx.program)) + doSimplify(ctx, conf.staticAnalysis) + } + case SimplifyMode.Disabled => () } assert(ir.invariant.programDiamondForm(ctx.program)) @@ -147,10 +166,14 @@ object RunUtils { val memTransferTimer = PerformanceTimer("Mem Transfer Timer", INFO) visit_prog(MemoryTransform(dsaResults.topDown, dsaResults.globals), ctx.program) memTransferTimer.checkPoint("Performed Memory Transform") + invariant.readUninitialised(ctx.program) } if q.summariseProcedures then - getGenerateProcedureSummariesTransform(q.loading.parameterForm || q.simplify)(ctx, analysisManager) + getGenerateProcedureSummariesTransform(q.loading.parameterForm || conf.simplify != SimplifyMode.Disabled)( + ctx, + analysisManager + ) if (!conf.staticAnalysis.exists(!_.irreducibleLoops) && conf.generateLoopInvariants) { if (!conf.staticAnalysis.exists(_.irreducibleLoops)) { diff --git a/src/main/scala/util/functional/List.scala b/src/main/scala/util/functional/List.scala index 61b686ddff..c678e609cc 100644 --- a/src/main/scala/util/functional/List.scala +++ b/src/main/scala/util/functional/List.scala @@ -14,11 +14,12 @@ import collection.{IterableOps, Factory} */ object Snoc { - def unapply[A, CC[_], C](x: IterableOps[A, CC, C]): Option[(C, A)] = { + def unapply[A, CC[_], C <: Iterable[A]](x: IterableOps[A, CC, C]): Option[(C, A)] = { if (x.isEmpty) then { None } else { - Some((x.init, x.last)) + val (init, last) = x.splitAt(x.size - 1) + Some((init, last.head)) } } diff --git a/src/main/scala/util/functional/Memo.scala b/src/main/scala/util/functional/Memo.scala new file mode 100644 index 0000000000..8c895ba764 --- /dev/null +++ b/src/main/scala/util/functional/Memo.scala @@ 
-0,0 +1,24 @@ +package util.functional +import scala.collection.mutable + +case class Stats(hits: Long, misses: Long) { + def hitRate: Double = hits.toDouble / (misses + hits) +} + +def memoised[T, P](f: T => P, withStats: Boolean = false): (T => P, () => Stats) = { + var hits: Long = 0 + var misses: Long = 0 + val cache = mutable.Map[T, P]() + def memoFn(arg: T): P = { + if (cache.contains(arg)) then { + hits += 1 + cache(arg) + } else { + misses += 1 + val r = f(arg) + cache(arg) = r + r + } + } + (memoFn, () => Stats(hits, misses)) +} diff --git a/src/main/scala/util/intrusive_list/IntrusiveList.scala b/src/main/scala/util/intrusive_list/IntrusiveList.scala index 9d393d7732..292bed0315 100644 --- a/src/main/scala/util/intrusive_list/IntrusiveList.scala +++ b/src/main/scala/util/intrusive_list/IntrusiveList.scala @@ -207,6 +207,8 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( debugAssert(newElem.unitary) onInsert(newElem) if (size > 0) { + + // HERE insertAfter(lastElem.get, newElem) } else { firstElem = Some(newElem) @@ -290,10 +292,10 @@ final class IntrusiveList[T <: IntrusiveListElement[T]] private ( */ def insertAfter(intrusiveListElement: T, newElem: T): T = { debugAssert(size >= 1) - debugAssert( - containsRef(intrusiveListElement), - "element is not a member of this list, insertAfter could mangle start and end tracking" - ) + // debugAssert( + // containsRef(intrusiveListElement), + // "element is not a member of this list, insertAfter could mangle start and end tracking" + // ) debugAssert(newElem.unitary) numElems += 1 if (intrusiveListElement == lastElem.get) { @@ -479,9 +481,15 @@ trait IntrusiveListElement[T <: IntrusiveListElement[T]]: first().insertBefore(elem) } - private[intrusive_list] final def getNext: T = next.get + private[intrusive_list] final def getNext: T = + val n = next.get + assert(n != this) + n - private[intrusive_list] final def getPrev: T = prev.get + private[intrusive_list] final def getPrev: T = + val p = 
prev.get + assert(p != this) + p private[intrusive_list] final def hasNext: Boolean = next.isDefined private[intrusive_list] final def hasPrev: Boolean = prev.isDefined diff --git a/src/main/scala/util/smt/SMT.scala b/src/main/scala/util/smt/SMT.scala index 435aff7305..ad6d121107 100644 --- a/src/main/scala/util/smt/SMT.scala +++ b/src/main/scala/util/smt/SMT.scala @@ -9,11 +9,15 @@ import org.sosy_lab.java_smt.SolverContextFactory import org.sosy_lab.java_smt.api.{ BitvectorFormula, BooleanFormula, + Evaluator, + Formula, FormulaManager, FormulaType, FunctionDeclaration, + ProverEnvironment, SolverContext } +import util.functional.Snoc import scala.collection.mutable import scala.jdk.CollectionConverters.{SeqHasAsJava, SetHasAsJava} @@ -29,6 +33,11 @@ enum SatResult { case Unknown(s: String) } +enum Solver { + case Z3 + case CVC5 +} + /** A wrapper around an SMT solver. * * (!!) It is very important (!!) to close the solver with [[close]] once you are done with it to prevent memory leaks! @@ -42,7 +51,7 @@ enum SatResult { * * Models can be obtained by requesting for them in smt query method calls. 
*/ -class SMTSolver(var defaultTimeoutMillis: Option[Int] = None) { +class SMTSolver(var defaultTimeoutMillis: Option[Int] = None, solver: Solver = Solver.Z3) { /** Create solver with timeout * @@ -52,48 +61,145 @@ class SMTSolver(var defaultTimeoutMillis: Option[Int] = None) { val shutdownManager = ShutdownManager.create() - val solverContext = { - val config = Configuration.defaultConfiguration() + val solverContext: SolverContext = { + val builder = Configuration.builder() + builder.copyFrom(Configuration.defaultConfiguration()) + (solver, defaultTimeoutMillis) match { + case (Solver.CVC5, Some(tl)) => + builder.setOption("solver.cvc5.furtherOptions", s"tlimit-per=${tl}") + case _ => () + } val logger = LogManager.createNullLogManager() val shutdown = shutdownManager.getNotifier() - SolverContextFactory.createSolverContext(config, logger, shutdown, SolverContextFactory.Solvers.Z3) + SolverContextFactory.createSolverContext( + builder.build(), + logger, + shutdown, + solver match { + case Solver.Z3 => SolverContextFactory.Solvers.Z3 + case Solver.CVC5 => SolverContextFactory.Solvers.CVC5 + } + ) } val formulaConverter = FormulaConverter(solverContext.getFormulaManager()) + def getProver(obtainModel: Boolean = false): SMTProver = { + val prover = + if obtainModel + then solverContext.newProverEnvironment(SolverContext.ProverOptions.GENERATE_MODELS) + else solverContext.newProverEnvironment() + + SMTProver(solverContext, shutdownManager, formulaConverter, prover) + } + private def sat(f: BooleanFormula, timeoutMillis: Option[Int], obtainModel: Boolean = false): SatResult = { + val env = getProver(obtainModel) + try { + env.addConstraint(f) + env.checkSat(timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) + } finally env.close() // always release the one-shot prover, even if addConstraint/checkSat throws + } + + /** Run solver on a [[analysis.Predicate]] */ + def predSat(p: Predicate, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { + sat(formulaConverter.convertPredicate(p),
timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) + } + + /** Run solver on a boolean typed BASIL [[ir.Expr]] */ + def exprSat(p: Expr, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { + sat(formulaConverter.convertBoolExpr(p), timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) + } + + /** Run solver on a predicate given as an SMT2 string */ + def smt2Sat(s: String, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { + sat(solverContext.getFormulaManager().parse(s), timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) + } + + /** Close the solver to prevent a memory leak when done. */ + def close() = { + solverContext.close() + } + +} + +class SMTEvaluator(formulaConverter: FormulaConverter, eval: Evaluator) { + + def evalExpr(e: Expr): Option[Literal] = { + e.getType match { + case BoolType => + evalBoolExpr(e).map { + case true => TrueLiteral + case false => FalseLiteral + } + case _: BitVecType => evalBVExpr(e) + case _ => throw Exception(s"Model eval not supported for expr : ${e.getType}") + } + } + + def evalBoolExpr(e: Expr): Option[Boolean] = { + Option(eval.evaluate(formulaConverter.convertBoolExpr(e))) + } + + def evalBVExpr(e: Expr): Option[BitVecLiteral] = { + val width = e.getType match { + case BitVecType(s) => s + case _ => throw Exception("not a bv formula") + } + Option(eval.evaluate(formulaConverter.convertBVExpr(e))).map(v => BitVecLiteral(v, width)) + } + +} + +class SMTProver( + val solverContext: SolverContext, + val shutdownManager: ShutdownManager, + val formulaConverter: FormulaConverter, + val prover: ProverEnvironment +) { + + def addConstraint(e: BooleanFormula) = { + prover.addConstraint(e) + } + + def addConstraint(e: Expr) = { + prover.addConstraint(formulaConverter.convertBoolExpr(e)) + } + + def addConstraint(e: Predicate) = { + prover.addConstraint(formulaConverter.convertPredicate(e)) + } + + def close() = { + prover.close() + } + + def 
checkSat(timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { // To handle timeouts, we must create a thread that sends a shutdown request after an amount of milliseconds val thread = timeoutMillis.map(m => { new Thread(new Runnable() { - def run() = { + override def run() = { try { Thread.sleep(m) shutdownManager.requestShutdown("Timeout") - } catch { _ => {} } + } catch { case _: InterruptedException => () } // expected: checkSat's finally interrupts this sleeper once the query is done } }) }) - val env = - if obtainModel - then solverContext.newProverEnvironment(SolverContext.ProverOptions.GENERATE_MODELS) - else solverContext.newProverEnvironment() - try { - env.push(f) thread.map(_.start) val res = - if env.isUnsat() then SatResult.UNSAT + if prover.isUnsat() then SatResult.UNSAT else SatResult.SAT(obtainModel match { - case true => Some(Model(env.getModel)) + case true => Some(Model(prover.getModel)) case false => None }) res } catch { e => SatResult.Unknown(e.toString()) } finally { - env.close() thread.map(t => { t.interrupt() t.join() @@ -101,31 +207,17 @@ class SMTSolver(var defaultTimeoutMillis: Option[Int] = None) { } } - /** Run solver on a [[analysis.Predicate]] */ - def predSat(p: Predicate, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { - sat(formulaConverter.convertPredicate(p), timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) - } - - /** Run solver on a boolean typed BASIL [[ir.Expr]] */ - def exprSat(p: Expr, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { - sat(formulaConverter.convertBoolExpr(p), timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) - } - - /** Run solver on a predicate given as an SMT2 string */ - def smt2Sat(s: String, timeoutMillis: Option[Int] = None, obtainModel: Boolean = false): SatResult = { - sat(solverContext.getFormulaManager().parse(s), timeoutMillis.orElse(defaultTimeoutMillis), obtainModel) - } - - /** Close the solver to prevent a memory leak when done.
*/ - def close() = { - solverContext.close() + def getEvaluator() = { + SMTEvaluator(formulaConverter, prover.getEvaluator()) } } class FormulaConverter(formulaManager: FormulaManager) { + lazy val ufFormulaMAnager = formulaManager.getUFManager() lazy val bitvectorFormulaManager = formulaManager.getBitvectorFormulaManager() lazy val booleanFormulaManager = formulaManager.getBooleanFormulaManager() + lazy val integerFormulaManager = formulaManager.getIntegerFormulaManager() lazy val uninterpretedFunctionManager = formulaManager.getUFManager() var uninterpretedFunctions: mutable.Map[String, FunctionDeclaration[BitvectorFormula]] = mutable.Map() @@ -198,11 +290,43 @@ class FormulaConverter(formulaManager: FormulaManager) { // Convert IR expressions + def convertExpr(e: Expr): Formula = { + val () = e match { + case FApplyExpr("ite", iteParams, retType, true) => + val itePairs = iteParams.map(convertExpr).grouped(2).toList + return itePairs match { + // NOTE: treats the last value in the ite (cond, value) pair list as + // being an infallible fallback value. 
+ case Snoc(init, Seq(_, lastValue)) => + init.foldRight(lastValue) { case (Seq(cond, value), rest) => + booleanFormulaManager.ifThenElse(cond.asInstanceOf[BooleanFormula], value, rest) + } + case _ => throw new Exception("unrecognised ite argument structure in: " + e) + } + case _ => () + } + + e.getType match { + case BoolType => convertBoolExpr(e) + case _: BitVecType => convertBVExpr(e) + case IntType => + e match { + case IntLiteral(v) => + bitvectorFormulaManager.makeBitvector(v.bitLength + 2, v.bigInteger) + case _ => throw Exception(s"integer formulas are not supported ${e.getType}: $e") + } + case _ => throw Exception(s"unsupported expr type ${e.getType}: $e") + } + } + def convertBoolExpr(e: Expr): BooleanFormula = { assert(e.getType == BoolType) e match { + case FApplyExpr("ite", _, _, _) => convertExpr(e).asInstanceOf[BooleanFormula] case TrueLiteral => booleanFormulaManager.makeTrue() case FalseLiteral => booleanFormulaManager.makeFalse() + case AssocExpr(BoolAND, args) => booleanFormulaManager.and(args.map(convertBoolExpr).asJava) + case AssocExpr(BoolOR, args) => booleanFormulaManager.or(args.map(convertBoolExpr).asJava) case BinaryExpr(op, arg, arg2) => op match { case op: BoolBinOp => convertBoolBinOp(op, convertBoolExpr(arg), convertBoolExpr(arg2)) @@ -233,13 +357,14 @@ class FormulaConverter(formulaManager: FormulaManager) { } case v: Variable => booleanFormulaManager.makeVariable(v.name) case r: OldExpr => ??? 
- case _ => throw Exception("Non boolean expression was attempted to be converted") + case e => throw Exception(s"Non boolean expression was attempted to be converted: $e") } } def convertBVExpr(e: Expr): BitvectorFormula = { assert(e.getType.isInstanceOf[BitVecType]) e match { + case FApplyExpr("ite", _, _, _) => convertExpr(e).asInstanceOf[BitvectorFormula] case BitVecLiteral(value, size) => bitvectorFormulaManager.makeBitvector(size, value.bigInteger) case Extract(end, start, arg) => bitvectorFormulaManager.extract(convertBVExpr(arg), end - 1, start) case Repeat(repeats, arg) => { @@ -266,7 +391,10 @@ class FormulaConverter(formulaManager: FormulaManager) { } case v: Variable => convertBVVar(v.irType, v.name) case r: OldExpr => ??? - case _ => throw Exception("Non bitvector expression was attempted to be converted") + case FApplyExpr(n, p, rt @ BitVecType(sz), _) => + val t: FormulaType[BitvectorFormula] = FormulaType.getBitvectorTypeWithSize(sz) + ufFormulaMAnager.declareAndCallUF(n, t, p.map(convertExpr).toList.asJava) + case e => throw Exception(s"Non bitvector expression was attempted to be converted: $e") } } diff --git a/src/test/scala/ConditionLiftingTests.scala b/src/test/scala/ConditionLiftingTests.scala index c2c25d2815..d8273ee210 100644 --- a/src/test/scala/ConditionLiftingTests.scala +++ b/src/test/scala/ConditionLiftingTests.scala @@ -2,6 +2,7 @@ import analysis.AnalysisManager import ir.* import ir.dsl.* import org.scalatest.funsuite.AnyFunSuite +import util.SimplifyMode @test_util.tags.UnitTest class ConditionLiftingRegressionTest extends AnyFunSuite with test_util.CaptureOutput { @@ -815,6 +816,31 @@ class ConditionLiftingRegressionTest extends AnyFunSuite with test_util.CaptureO proc("strtoul") ) + test("conds inline test tvsimp") { + + var ctx = ir.IRLoading.load(testProgram) + ir.transforms.doCleanupWithSimplify(ctx, analysis.AnalysisManager(ctx.program)) + ir.transforms.clearParams(ctx.program) + + 
ir.transforms.validate.validatedSimplifyPipeline(ctx, SimplifyMode.Simplify) + for (p <- ctx.program.procedures) { + p.normaliseBlockNames() + } + + (ctx.program).foreach { + case a: Assume => { + assert(!(a.body.variables.exists(v => { + v.name.startsWith("ZF") + || v.name.startsWith("CF") + || v.name.startsWith("VF") + || v.name.startsWith("NF") + || v.name.startsWith("Cse") + }))) + } + case _ => () + } + } + test("conds inline test") { var ctx = ir.IRLoading.load(testProgram) diff --git a/src/test/scala/IndirectCallTests.scala b/src/test/scala/IndirectCallTests.scala index 2a02a81cc1..ff1f143d9b 100644 --- a/src/test/scala/IndirectCallTests.scala +++ b/src/test/scala/IndirectCallTests.scala @@ -2,7 +2,7 @@ import analysis.data_structure_analysis.{DSAContext, *} import ir.* import org.scalatest.funsuite.* import test_util.{BASILTest, CaptureOutput, TestConfig, TestCustomisation} -import util.{BASILResult, DSConfig, LogLevel, Logger, StaticAnalysisConfig} +import util.{BASILResult, DSConfig, LogLevel, Logger, SimplifyMode, StaticAnalysisConfig} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -68,7 +68,7 @@ class IndirectCallTests extends AnyFunSuite, CaptureOutput, BASILTest, TestCusto BPLPath, staticAnalysisConf, dsa = Some(DSConfig()), - simplify = true, + simplify = SimplifyMode.Simplify, postLoad = ctx => { indircalls = getIndirectCalls(ctx.program); } ) (basilresult, indircalls.map(_.label.get)) diff --git a/src/test/scala/IntervalDSATest.scala b/src/test/scala/IntervalDSATest.scala index 2d723037b5..8d24c02bb0 100644 --- a/src/test/scala/IntervalDSATest.scala +++ b/src/test/scala/IntervalDSATest.scala @@ -235,7 +235,7 @@ class IntervalDSATest extends AnyFunSuite with test_util.CaptureOutput { trimEarly = main.isDefined, mainProcedureName = main.getOrElse("main") ), - simplify = true, + simplify = SimplifyMode.Simplify, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", @@ -249,7 
+249,7 @@ class IntervalDSATest extends AnyFunSuite with test_util.CaptureOutput { BASILConfig( context = Some(context), loading = ILLoadingConfig(inputFile = "", relfFile = None), - simplify = true, + simplify = SimplifyMode.Simplify, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", @@ -626,6 +626,8 @@ class IntervalDSATest extends AnyFunSuite with test_util.CaptureOutput { val locals = res.dsa.get.local assert(locals.values.forall(_.glIntervals.size == 1)) + println(locals.values.filterNot(IntervalDSA.checksStackMaintained).map(_.proc.procName).toSet) + assert( locals.values.filterNot(g => stackCollapsed.contains(g.proc.procName)).forall(IntervalDSA.checksStackMaintained) ) diff --git a/src/test/scala/LoopInvariantTests.scala b/src/test/scala/LoopInvariantTests.scala index c63871a3c9..ea586b5174 100644 --- a/src/test/scala/LoopInvariantTests.scala +++ b/src/test/scala/LoopInvariantTests.scala @@ -6,7 +6,7 @@ import org.scalatest.funsuite.AnyFunSuite import test_util.BASILTest.programToContext import test_util.CaptureOutput import util.SMT.{SMTSolver, SatResult} -import util.{BASILConfig, BASILResult, BoogieGeneratorConfig, ILLoadingConfig, RunUtils} +import util.{BASILConfig, BASILResult, BoogieGeneratorConfig, ILLoadingConfig, RunUtils, SimplifyMode} @test_util.tags.UnitTest class LoopInvariantTests extends AnyFunSuite, CaptureOutput { @@ -43,7 +43,7 @@ class LoopInvariantTests extends AnyFunSuite, CaptureOutput { BASILConfig( context = Some(context), loading = ILLoadingConfig(inputFile = "", relfFile = None), - simplify = true, + simplify = SimplifyMode.Simplify, generateLoopInvariants = true, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), diff --git a/src/test/scala/MemoryTransformTests.scala b/src/test/scala/MemoryTransformTests.scala index 63c08b2d8f..1e62ec3c9a 100644 --- a/src/test/scala/MemoryTransformTests.scala +++ b/src/test/scala/MemoryTransformTests.scala @@ -15,7 +15,7 @@ class 
MemoryTransformTests extends AnyFunSuite with CaptureOutput { RunUtils.loadAndTranslate( BASILConfig( loading = ILLoadingConfig(inputFile = path + ".adt", relfFile = Some(path + ".relf")), - simplify = true, + simplify = SimplifyMode.Simplify, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", @@ -30,7 +30,7 @@ class MemoryTransformTests extends AnyFunSuite with CaptureOutput { BASILConfig( context = Some(context), loading = ILLoadingConfig(inputFile = "", relfFile = None), - simplify = true, + simplify = SimplifyMode.Simplify, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", @@ -214,7 +214,7 @@ class MemoryTransformTests extends AnyFunSuite with CaptureOutput { MemoryStore(mem, xAddress, R31, LittleEndian, 64, Some("01")), MemoryLoad(R0, mem, xAddress, LittleEndian, 64, Some("02")), MemoryStore(mem, zAddress, R0, LittleEndian, 64, Some("03")), - goto("k") + goto("k", "dummy") ), block("h", goto("k")), block("k", ret), diff --git a/src/test/scala/PCTrackingTest.scala b/src/test/scala/PCTrackingTest.scala index 9f65bd82a7..e2417bbeb1 100644 --- a/src/test/scala/PCTrackingTest.scala +++ b/src/test/scala/PCTrackingTest.scala @@ -1,7 +1,7 @@ import ir.{IRContext, *} import org.scalatest.funsuite.AnyFunSuite import test_util.{BASILTest, CaptureOutput} -import util.{BASILConfig, BoogieGeneratorConfig, ILLoadingConfig, PCTrackingOption, StaticAnalysisConfig} +import util.{BASILConfig, BoogieGeneratorConfig, ILLoadingConfig, PCTrackingOption, SimplifyMode, StaticAnalysisConfig} @test_util.tags.UnitTest class PCTrackingTest extends AnyFunSuite with CaptureOutput { @@ -19,7 +19,7 @@ class PCTrackingTest extends AnyFunSuite with CaptureOutput { staticAnalysis = Some(StaticAnalysisConfig(None)), boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", - simplify = simplify + simplify = if simplify then SimplifyMode.Simplify else SimplifyMode.Disabled ) ) } diff --git 
a/src/test/scala/SVATest.scala b/src/test/scala/SVATest.scala index 2f8b91b5e8..c3df165539 100644 --- a/src/test/scala/SVATest.scala +++ b/src/test/scala/SVATest.scala @@ -16,7 +16,7 @@ class SVATest extends AnyFunSuite with CaptureOutput { BASILConfig( context = Some(context), loading = ILLoadingConfig(inputFile = "", relfFile = None), - simplify = true, + simplify = SimplifyMode.Simplify, staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out", diff --git a/src/test/scala/SystemTests.scala b/src/test/scala/SystemTests.scala index 1f664cdfd4..6d5f75362e 100644 --- a/src/test/scala/SystemTests.scala +++ b/src/test/scala/SystemTests.scala @@ -3,7 +3,16 @@ import test_util.BASILTest.* import test_util.{BASILTest, CaptureOutput, Histogram, TestConfig, TestCustomisation} import util.DSAPhase.TD import util.boogie_interaction.* -import util.{DSConfig, DebugDumpIRLogger, LogLevel, Logger, MemoryRegionsMode, PerformanceTimer, StaticAnalysisConfig} +import util.{ + DSConfig, + DebugDumpIRLogger, + LogLevel, + Logger, + MemoryRegionsMode, + PerformanceTimer, + SimplifyMode, + StaticAnalysisConfig +} import java.io.File import scala.collection.immutable.ListMap @@ -283,6 +292,51 @@ class SystemTestsGTIRB extends SystemTests { } } +@test_util.tags.TVSystemTest +class SystemTestsGTIRBSimplifyTV extends SystemTests { + // array theory solver seems to do better with these difficult specs + val simplify = SimplifyMode.ValidatedSimplify(Some(util.SMT.Solver.Z3), None) + private val timeout = 60 + + override def customiseTestsByName(name: String) = super.customiseTestsByName(name).orElse { + name match { + case x if (!x.endsWith("gcc_O2:GTIRB") || x.endsWith("clang_O2:GTIRB")) => // NOTE(review): the clang_O2 disjunct is subsumed by the !gcc_O2 test, so only gcc_O2 cases run - confirm whether !(gcc_O2 || clang_O2) was intended + // "correct/functionpointer/clang:GTIRB" | "correct/functionpointer/clang_pic:GTIRB" | + // "correct/malloc_with_local/clang:GTIRB" | "correct/malloc_with_local2/clang:GTIRB" | + // "correct/malloc_with_local2/gcc:GTIRB" | "correct/malloc_with_local3/clang:GTIRB" |
+ // "correct/malloc_with_local3/gcc:GTIRB" | "correct/functionpointer/gcc_pic:GTIRB" | + // "correct/functionpointer/gcc:GTIRB" => + Mode.Disabled("disable unoptimised examples for performance") + case _ => Mode.Normal + } + } + + runTests( + "correct", + TestConfig( + baseBoogieFlags = Seq("/proverOpt:O:smt.array.extensional=false"), + timeout = timeout, + useBAPFrontend = false, + expectVerify = true, + checkExpected = false, + logResults = true, + simplify = simplify + ) + ) + runTests( + "incorrect", + TestConfig( + baseBoogieFlags = Seq("/proverOpt:O:smt.array.extensional=false"), + timeout = timeout, + useBAPFrontend = false, + expectVerify = false, + checkExpected = false, + logResults = true, + simplify = simplify + ) + ) +} + @test_util.tags.StandardSystemTest class SystemTestsGTIRBOfflineLifter extends SystemTests { runTests( @@ -375,10 +429,22 @@ class ExtraSpecTests extends SystemTests { @test_util.tags.DisabledTest class NoSimplifySystemTests extends SystemTests { - runTests("correct", TestConfig(simplify = false, useBAPFrontend = true, expectVerify = true, logResults = true)) - runTests("incorrect", TestConfig(simplify = false, useBAPFrontend = true, expectVerify = false, logResults = true)) - runTests("correct", TestConfig(simplify = false, useBAPFrontend = false, expectVerify = true, logResults = true)) - runTests("incorrect", TestConfig(simplify = false, useBAPFrontend = false, expectVerify = false, logResults = true)) + runTests( + "correct", + TestConfig(simplify = SimplifyMode.Disabled, useBAPFrontend = true, expectVerify = true, logResults = true) + ) + runTests( + "incorrect", + TestConfig(simplify = SimplifyMode.Disabled, useBAPFrontend = true, expectVerify = false, logResults = true) + ) + runTests( + "correct", + TestConfig(simplify = SimplifyMode.Disabled, useBAPFrontend = false, expectVerify = true, logResults = true) + ) + runTests( + "incorrect", + TestConfig(simplify = SimplifyMode.Disabled, useBAPFrontend = false, expectVerify = 
false, logResults = true) + ) test("summary-nosimplify") { summary("nosimplify") } @@ -387,10 +453,22 @@ class NoSimplifySystemTests extends SystemTests { @test_util.tags.AnalysisSystemTest2 @test_util.tags.AnalysisSystemTest class SimplifySystemTests extends SystemTests { - runTests("correct", TestConfig(simplify = true, useBAPFrontend = true, expectVerify = true, logResults = true)) - runTests("incorrect", TestConfig(simplify = true, useBAPFrontend = true, expectVerify = false, logResults = true)) - runTests("correct", TestConfig(simplify = true, useBAPFrontend = false, expectVerify = true, logResults = true)) - runTests("incorrect", TestConfig(simplify = true, useBAPFrontend = false, expectVerify = false, logResults = true)) + runTests( + "correct", + TestConfig(simplify = SimplifyMode.Simplify, useBAPFrontend = true, expectVerify = true, logResults = true) + ) + runTests( + "incorrect", + TestConfig(simplify = SimplifyMode.Simplify, useBAPFrontend = true, expectVerify = false, logResults = true) + ) + runTests( + "correct", + TestConfig(simplify = SimplifyMode.Simplify, useBAPFrontend = false, expectVerify = true, logResults = true) + ) + runTests( + "incorrect", + TestConfig(simplify = SimplifyMode.Simplify, useBAPFrontend = false, expectVerify = false, logResults = true) + ) test("summary-simplify") { summary("simplify") } @@ -415,7 +493,7 @@ class SimplifyMemorySystemTests extends SystemTests { runTests( "correct", TestConfig( - simplify = true, + simplify = SimplifyMode.Simplify, useBAPFrontend = true, expectVerify = true, logResults = true, @@ -425,7 +503,7 @@ class SimplifyMemorySystemTests extends SystemTests { runTests( "incorrect", TestConfig( - simplify = true, + simplify = SimplifyMode.Simplify, useBAPFrontend = true, expectVerify = false, logResults = true, @@ -435,7 +513,7 @@ class SimplifyMemorySystemTests extends SystemTests { runTests( "correct", TestConfig( - simplify = true, + simplify = SimplifyMode.Simplify, useBAPFrontend = false, 
expectVerify = true, logResults = true, @@ -445,7 +523,7 @@ class SimplifyMemorySystemTests extends SystemTests { runTests( "incorrect", TestConfig( - simplify = true, + simplify = SimplifyMode.Simplify, useBAPFrontend = false, expectVerify = false, logResults = true, @@ -615,11 +693,16 @@ class MemoryRegionTestsNoRegion extends SystemTests { class ProcedureSummaryTests extends SystemTests { runTests( "procedure_summaries", - TestConfig(summariseProcedures = true, simplify = true, useBAPFrontend = true, expectVerify = true) + TestConfig(summariseProcedures = true, simplify = SimplifyMode.Simplify, useBAPFrontend = true, expectVerify = true) ) runTests( "procedure_summaries", - TestConfig(summariseProcedures = true, simplify = true, useBAPFrontend = false, expectVerify = true) + TestConfig( + summariseProcedures = true, + simplify = SimplifyMode.Simplify, + useBAPFrontend = false, + expectVerify = true + ) ) } @@ -633,10 +716,13 @@ class UnimplementedTests extends SystemTests { @test_util.tags.AnalysisSystemTest4 @test_util.tags.AnalysisSystemTest class IntervalDSASystemTests extends SystemTests { - runTests("correct", TestConfig(useBAPFrontend = false, expectVerify = true, simplify = true, dsa = Some(DSConfig()))) + runTests( + "correct", + TestConfig(useBAPFrontend = false, expectVerify = true, simplify = SimplifyMode.Simplify, dsa = Some(DSConfig())) + ) runTests( "incorrect", - TestConfig(useBAPFrontend = false, expectVerify = false, simplify = true, dsa = Some(DSConfig())) + TestConfig(useBAPFrontend = false, expectVerify = false, simplify = SimplifyMode.Simplify, dsa = Some(DSConfig())) ) } @@ -644,16 +730,31 @@ class IntervalDSASystemTests extends SystemTests { class IntervalDSASystemTestsSplitGlobals extends SystemTests { runTests( "correct", - TestConfig(useBAPFrontend = false, expectVerify = true, simplify = true, dsa = Some(DSConfig(TD, true, true))) + TestConfig( + useBAPFrontend = false, + expectVerify = true, + simplify = SimplifyMode.Simplify, + dsa = 
Some(DSConfig(TD, true, true)) + ) ) runTests( "dsa/correct", - TestConfig(useBAPFrontend = false, expectVerify = true, simplify = true, dsa = Some(DSConfig(TD, true, true))) + TestConfig( + useBAPFrontend = false, + expectVerify = true, + simplify = SimplifyMode.Simplify, + dsa = Some(DSConfig(TD, true, true)) + ) ) runTests( "incorrect", - TestConfig(useBAPFrontend = false, expectVerify = false, simplify = true, dsa = Some(DSConfig(TD, true, true))) + TestConfig( + useBAPFrontend = false, + expectVerify = false, + simplify = SimplifyMode.Simplify, + dsa = Some(DSConfig(TD, true, true)) + ) ) } @@ -661,14 +762,19 @@ class IntervalDSASystemTestsSplitGlobals extends SystemTests { class IntervalDSASystemTestsEqClasses extends SystemTests { runTests( "correct", - TestConfig(useBAPFrontend = false, expectVerify = true, simplify = true, dsa = Some(DSConfig(TD, eqClasses = true))) + TestConfig( + useBAPFrontend = false, + expectVerify = true, + simplify = SimplifyMode.Simplify, + dsa = Some(DSConfig(TD, eqClasses = true)) + ) ) runTests( "incorrect", TestConfig( useBAPFrontend = false, expectVerify = false, - simplify = true, + simplify = SimplifyMode.Simplify, dsa = Some(DSConfig(TD, eqClasses = true)) ) ) @@ -681,7 +787,7 @@ class MemoryTransformSystemTests extends SystemTests { TestConfig( useBAPFrontend = false, expectVerify = false, - simplify = true, + simplify = SimplifyMode.Simplify, dsa = Some(DSConfig()), memoryTransform = true ) @@ -692,7 +798,7 @@ class MemoryTransformSystemTests extends SystemTests { TestConfig( useBAPFrontend = false, expectVerify = false, - simplify = true, + simplify = SimplifyMode.Simplify, dsa = Some(DSConfig()), memoryTransform = true ) diff --git a/src/test/scala/ir/IRToDSLTest.scala b/src/test/scala/ir/IRToDSLTest.scala index aff9d4450c..076b309df6 100644 --- a/src/test/scala/ir/IRToDSLTest.scala +++ b/src/test/scala/ir/IRToDSLTest.scala @@ -6,7 +6,7 @@ import org.scalactic.* import org.scalatest.funsuite.AnyFunSuite import 
test_util.{BASILTest, CaptureOutput} import translating.PrettyPrinter.* -import util.{BASILConfig, BoogieGeneratorConfig, ILLoadingConfig, LogLevel, Logger} +import util.{BASILConfig, BoogieGeneratorConfig, ILLoadingConfig, LogLevel, Logger, SimplifyMode} import scala.collection.immutable.* @@ -143,7 +143,7 @@ class IRToDSLTest extends AnyFunSuite with CaptureOutput { staticAnalysis = None, boogieTranslation = BoogieGeneratorConfig(), outputPrefix = "boogie_out.bpl", - simplify = true + simplify = SimplifyMode.Simplify ) ) diff --git a/src/test/scala/test_util/BASILTest.scala b/src/test/scala/test_util/BASILTest.scala index 74602f3833..dff125d6d3 100644 --- a/src/test/scala/test_util/BASILTest.scala +++ b/src/test/scala/test_util/BASILTest.scala @@ -14,6 +14,7 @@ import util.{ ILLoadingConfig, Logger, RunUtils, + SimplifyMode, StaticAnalysisConfig } @@ -29,7 +30,7 @@ case class TestConfig( expectVerify: Boolean, checkExpected: Boolean = false, logResults: Boolean = false, - simplify: Boolean = false, + simplify: SimplifyMode = SimplifyMode.Disabled, summariseProcedures: Boolean = false, dsa: Option[DSConfig] = None, memoryTransform: Boolean = false, @@ -49,7 +50,7 @@ trait BASILTest { specPath: Option[String], BPLPath: String, staticAnalysisConf: Option[StaticAnalysisConfig], - simplify: Boolean = false, + simplify: SimplifyMode = SimplifyMode.Disabled, summariseProcedures: Boolean = false, dsa: Option[DSConfig] = None, memoryTransform: Boolean = false, diff --git a/src/test/scala/test_util/tags/TVSystemTest.java b/src/test/scala/test_util/tags/TVSystemTest.java new file mode 100644 index 0000000000..ed147317ba --- /dev/null +++ b/src/test/scala/test_util/tags/TVSystemTest.java @@ -0,0 +1,9 @@ +package test_util.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@Inherited +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +public @interface TVSystemTest {}