Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/main/scala/analysis/KnownBits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,10 @@ case class TNum(value: BitVecLiteral, mask: BitVecLiteral) {
}

// Get smallest possible unsigned value of the TNum (e.g. Min value of TT0 is 000)
def minUnsigned = (this.value & ~this.mask).value
def minUnsigned = mustBits.value

// Get largest possible unsigned value of the TNum (e.g. Max value of TT0 is 110)
def maxUnsigned = (this.value | this.mask).value
def maxUnsigned = mayBits.value

def mustBits = (this.value & ~this.mask)
def mustNotBits = ((~this.value) & ~this.mask)
Expand Down Expand Up @@ -706,7 +706,7 @@ class TNumDomain extends AbstractDomain[Map[Variable, TNum]] {

// s is the abstract state from previous command/block
override def transfer(s: Map[Variable, TNum], b: Command): Map[Variable, TNum] = {
b match {
val r = b match {
// Assign variable to variable (e.g. x = y)
case LocalAssign(lhs: Variable, rhs: Expr, _) =>
s.updated(lhs, evaluateExprToTNum(s, rhs))
Expand All @@ -716,9 +716,22 @@ class TNumDomain extends AbstractDomain[Map[Variable, TNum]] {
// Overapproxiate memory values with Top
s.updated(lhs, TNum.top(size))

case i: IndirectCall => Map()
case a: Assign => s ++ a.assignees.map(l => l -> TNum.top(sizeBits(l.irType)))
// Default case
case _ => s
case _: NOP => s
case _: Assert => s
case _: Assume => s
case _: GoTo => s
case _: Return => s
case _: Unreachable => s
case _: MemoryStore => s
}
r
}

override def join(left: Map[Variable, TNum], right: Map[Variable, TNum], pos: Block): Map[Variable, TNum] = {
join(left, right)
}

/**
Expand All @@ -729,7 +742,7 @@ class TNumDomain extends AbstractDomain[Map[Variable, TNum]] {
* x = 1111 => value = 1111, mask = 0000
* Joined x = 1111 => value = 1111, mask = 0000
*/
override def join(left: Map[Variable, TNum], right: Map[Variable, TNum], pos: Block): Map[Variable, TNum] = {
def join(left: Map[Variable, TNum], right: Map[Variable, TNum]): Map[Variable, TNum] = {
(left.keySet ++ right.keySet).map { key =>
val width = sizeBits(key.getType)
val leftTNum = left.getOrElse(key, TNum.top(width))
Expand Down
224 changes: 224 additions & 0 deletions src/main/scala/ir/transforms/ExtractExtendZeroBits.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package ir.transforms

import ir.transforms.interprocSummaryFixpointSolver
import analysis.*
import ir.*
import ir.cilvisitor.*
import collection.immutable.SortedMap

object ExtractExtendZeroBits {

def resultToTransform(result: Map[Variable, TNum]): Map[Variable, Variable] = {
val d = TNumDomain()

def bitSet(b: BitVecLiteral, n: Int) = {
((b.value >> n) & 1) == 1
}

def highBitsZero(n: TNum): Option[Int] = {
var num: Option[Int] = None

val maxSet = n.mayBits

var delt = false
for (r <- (maxSet.size - 1) to 0 by -1) {
if (bitSet(maxSet, r)) {
delt = true
} else {
num = Some(r)
}
}
num match {
case Some(n) if n < maxSet.size => Some(n)
case _ => None
}
}

val varReplace = result.flatMap {
case (v: Variable, tn: TNum) => {
highBitsZero(tn) match {
case None => None
case Some(mostSigZeroBit) => {
v match {
case l @ LocalVar(n, BitVecType(sz), i) =>
Some(v -> l.copy(irType = BitVecType(sz - mostSigZeroBit)))
case l @ Register(n, sz) =>
Some(v -> l.copy(size = sz - mostSigZeroBit))
case _ => None
}
}
}
}
}
varReplace
}

class ReplaceSigns(procReplacements: Map[Procedure, Map[Variable, Variable]]) extends CILVisitor {

var replacement: Map[Variable, Variable] = Map()

override def vproc(proc: Procedure) = {
replacement = procReplacements.get(proc).getOrElse(Map())

// change formal param function signature
val newIn = proc.formalInParam.foreach {
case p if replacement.contains(p) => {
proc.formalInParam.remove(p)
proc.formalInParam.add(replacement(p).asInstanceOf[LocalVar])
}
case _ => ()
}
val newOut = proc.formalOutParam.foreach {
case p if replacement.contains(p) => {
proc.formalOutParam.remove(p)
proc.formalOutParam.add(replacement(p).asInstanceOf[LocalVar])
}
case _ => ()
}

// fixup calls to this function for new signature
for (call <- proc.incomingCalls()) {
call.outParams = SortedMap.from(call.outParams.map {
case (p, v) if replacement.contains(p) && replacement.contains(v) =>
val l = replacement.getOrElse(p, p).asInstanceOf[LocalVar]
val r = replacement.getOrElse(v, v)
l -> r
case p => p
})

call.actualParams = SortedMap.from(call.actualParams.map {
case (p, e0) if replacement.contains(p) => {
val e = visit_expr(this, e0)
val l = replacement.getOrElse(p, p).asInstanceOf[LocalVar]
val r = if size(e).get != size(l).get then Extract(size(l).get, 0, e) else e
l -> r
}
case p => p
})

}
DoChildren()
}

override def vjump(j: Jump) = j match {
case r: Return => {
ChangeDoChildrenPost(
r,
_ match {
case r: Return => {
r.outParams = SortedMap.from(r.outParams.map {
case (v, e) if replacement.contains(v) =>
val repl: LocalVar = replacement(v).asInstanceOf[LocalVar]
val rhs = if size(e).get != size(repl).get then Extract(size(repl).get, 0, e) else e
repl -> rhs
case o => o
})
r
}
case o => o
}
)
}
case _ => DoChildren()
}

override def vstmt(s: Statement) = {
s match {
case l @ LocalAssign(lhs, rhs, _) if replacement.contains(lhs) => {
l.lhs = replacement(lhs)
l.rhs = Extract(size(replacement(lhs)).get, 0, rhs)
}
case l @ MemoryAssign(lhs, rhs, _) if replacement.contains(lhs) => {
l.lhs = replacement(lhs)
l.rhs = Extract(size(replacement(lhs)).get, 0, rhs)
}
case m: MemoryLoad => {
// not possible to replace lhs
()
}
// case s: SimulAssign => {
// s.assignments = s.assignments.map {
// case (lhs, rhs) if replacement.contains(lhs) => {
// replacement(lhs) -> Extract(size(replacement(lhs)).get, 0, rhs)
// }
// case o => o
// }
// }
case _ => ()
}
DoChildren()
}

override def vexpr(e: Expr) = e match {
case v: Variable if replacement.contains(v) => {
val r = replacement(v)
ChangeTo(ZeroExtend(size(v).get - size(r).get, r))
}
case _ => DoChildren()
}
}

def doTransform(p: Program) = {
val d = TNumDomain()

class SummaryGen extends ProcedureSummaryGenerator[Map[Variable, TNum], Map[Variable, TNum]] {

val dom = TNumDomain()
override def bot = Map[Variable, TNum]()
override def top = Map[Variable, TNum]()
override def join(l: Map[Variable, TNum], r: Map[Variable, TNum], pos: Procedure) = dom.join(l, r)
// overrided in analysis
override def transfer(l: Map[Variable, TNum], b: Procedure) = ???

def localTransferCall(
l: Map[Variable, TNum],
summaryForTarget: Map[Variable, TNum],
p: DirectCall
): Map[Variable, TNum] = {
val joined = p.outParams.map {
case (formal, lhs) => {
val joined = (summaryForTarget.get(formal), l.get(lhs)) match {
case (Some(a), Some(b)) => a.join(b)
case (Some(a), None) => a
case (None, Some(a)) => a
case (None, None) => TNum.top(d.sizeBits(formal.irType))
}
lhs -> joined
}
}

l ++ joined.toMap
}
def updateSummary(
prevSummary: Map[Variable, TNum],
p: Procedure,
resBefore: Map[Block, Map[Variable, TNum]],
resAfter: Map[Block, Map[Variable, TNum]]
): Map[Variable, TNum] = {
p.returnBlock.flatMap(resAfter.get(_)).getOrElse(Map())
}
}
applyRPO(p)

val solver = interprocSummaryFixpointSolver(d, SummaryGen())

val result = solver.solveProgInterProc(p)

val tx = p.procedures
.filter(result.contains)
.map(proc => {
val r = result(proc)
val tx = resultToTransform(r)
proc -> tx
})
.toMap

val vis = ReplaceSigns(tx)
visit_prog(vis, p)

visit_prog(ir.eval.SimpExpr(ir.eval.simplifyPaddingAndSlicingExprFixpoint), p)

assert(invariant.correctCalls(p))
}

}
2 changes: 2 additions & 0 deletions src/main/scala/ir/transforms/Simp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ def copypropTransform(

visit_proc(CleanupAssignments(), p)
t.checkPoint("redundant assignments")
t.checkPoint("cleanup extract extend zerobits ")
// SimplifyLogger.info(s" ${p.name} after dead var cleanup expr complexity ${ExprComplexity()(p)}")

AlgebraicSimplifications(p)
Expand Down Expand Up @@ -1068,6 +1069,7 @@ def doCopyPropTransform(p: Program, rela: Map[BigInt, BigInt]) = {

// cleanup
visit_prog(CleanupAssignments(), p)
ExtractExtendZeroBits.doTransform(p)

SimplifyLogger.info("[!] Simplify :: Merge empty blocks")
cleanupBlocks(p)
Expand Down
Loading