Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,50 +178,14 @@ class AstCreator(
}

private def astForGlobalStmt(stmt: PhpGlobalStmt): List[Ast] = {
val surroundingIter = scope.getSurroundingMethods.drop(1).iterator // drop first to ignore global method
val surroundingMethods = scope.getSurroundingMethods.dropRight(1) // drop last to ignore innermost method

surroundingMethods.foreach { currentMethod =>
val innerMethodScope = surroundingIter.next()
val innerMethodNode = innerMethodScope.methodNode
val innerMethodRef = innerMethodScope.methodRefNode
innerMethodRef match {
case Some(methodRef) =>
scope.getMethodRef(innerMethodNode.fullName) match {
case None =>
diffGraph.addNode(methodRef)
diffGraph.addEdge(currentMethod.bodyNode, methodRef, EdgeTypes.AST)
scope.addMethodRef(innerMethodNode.fullName, methodRef)
case _ =>
}

stmt.vars.foreach {
case PhpVariable(name: PhpNameExpr, _) =>
val closureBindingId = s"$relativeFileName:${innerMethodNode.fullName}:${name.name}"
val closureLocal = localNode(stmt, name.name, name.name, Defines.Any, Option(closureBindingId))

val closureBindingNode = NewClosureBinding()
.closureBindingId(closureBindingId)
.evaluationStrategy(EvaluationStrategies.BY_SHARING)

scope.lookupVariable(name.name) match {
case Some(refLocal) => diffGraph.addEdge(closureBindingNode, refLocal, EdgeTypes.REF)
case _ => // do nothing
}

scope.addVariableToMethodScope(closureLocal.name, closureLocal, innerMethodNode.fullName) match {
case Some(ms) => diffGraph.addEdge(ms.bodyNode, closureLocal, EdgeTypes.AST)
case _ => // do nothing
}

diffGraph.addNode(closureBindingNode)
diffGraph.addEdge(methodRef, closureBindingNode, EdgeTypes.CAPTURE)
case x =>
logger.warn(s"Unexpected variable type ${x.getClass} found")
}
case None =>
logger.warn(s"No methodRef found for capturing global variable in method ${innerMethodNode.fullName}")
}
stmt.vars.foreach {
case PhpVariable(name: PhpNameExpr, _) =>
val surroundingIter = scope.getSurroundingMethods.drop(1).iterator // drop first to ignore global method
val surroundingMethods = scope.getSurroundingMethods.dropRight(1) // drop last to ignore innermost method

createClosureCaptureForNode(stmt, name.name, surroundingIter, surroundingMethods)
case x =>
logger.warn(s"Unexpected variable type ${x.getClass} found")
}

stmt.vars.map(astForExpr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@ import io.joern.php2cpg.passes.SymbolSummaryPass.PhpFunction
import io.joern.x2cpg.Defines.UnresolvedNamespace
import io.joern.x2cpg.utils.AstPropertiesUtil.RootProperties
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, ModifierTypes}
import io.shiftleft.codepropertygraph.generated.{EdgeTypes, EvaluationStrategies, ModifierTypes}
import io.shiftleft.codepropertygraph.generated.nodes.{
MethodParameterIn,
NewBlock,
NewClosureBinding,
NewIdentifier,
NewLiteral,
NewLocal,
NewMethod,
NewMethodParameterIn,
NewMethodRef,
NewModifier,
NewNamespaceBlock,
NewNode
}
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal

import scala.collection.mutable
import java.nio.charset.StandardCharsets

trait AstCreatorHelper(disableFileContent: Boolean)(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>
Expand Down Expand Up @@ -147,16 +152,84 @@ trait AstCreatorHelper(disableFileContent: Boolean)(implicit withSchemaValidatio
}

scope.addToScope(name, local) match {
case BlockScope(block, _) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case MethodScope(_, block, _, _, _) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case _ => // do nothing
case BlockScope(block, _) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case MethodScope(_, block, _, _, _, _) => diffGraph.addEdge(block, local, EdgeTypes.AST)
case _ => // do nothing
}

local
case Some(local: NewLocal)
if scope.isSurroundedByArrowClosure && local.closureBindingId.exists(_.contains("<lambda>")) =>
local // the contains check ensures that we can capture global variables into an arrow closure
case Some(param: NewMethodParameterIn)
if scope.isSurroundedByArrowClosure && !scope.surroundingMethodParams.contains(param.name) =>
createClosureBindingsForArrowClosure(expr, name)
case Some(_: NewLocal) if scope.isSurroundedByArrowClosure =>
createClosureBindingsForArrowClosure(expr, name)
case Some(local) => local
}
}

def createClosureCaptureForNode(
expr: PhpNode,
name: String,
innerMethodsIterator: Iterator[MethodScope],
surroundingMethods: List[MethodScope]
): Unit = {
surroundingMethods.foreach { currentMethod =>
val innerMethodScope = innerMethodsIterator.next()
val innerMethodNode = innerMethodScope.methodNode
val innerMethodRef = innerMethodScope.methodRefNode
innerMethodRef match {
case Some(methodRef) =>
scope.getMethodRef(innerMethodNode.fullName) match {
case None =>
diffGraph.addNode(methodRef)
diffGraph.addEdge(currentMethod.bodyNode, methodRef, EdgeTypes.AST)
scope.addMethodRef(innerMethodNode.fullName, methodRef)
case _ =>
}

val closureBindingId = if (innerMethodNode.fullName.contains(NamespaceTraversal.globalNamespaceName)) {
s"${innerMethodNode.fullName}:${name}"
} else {
s"$relativeFileName:${innerMethodNode.fullName}:${name}"
}

val closureLocal = localNode(expr, name, name, Defines.Any, Option(closureBindingId))

val closureBindingNode = createClosureBinding(closureBindingId)

scope.lookupVariable(name) match {
case Some(refLocal) =>
diffGraph.addEdge(closureBindingNode, refLocal, EdgeTypes.REF)
case _ => // do nothing
}

scope.addVariableToMethodScope(closureLocal.name, closureLocal, innerMethodNode.fullName) match {
case Some(ms) => diffGraph.addEdge(ms.bodyNode, closureLocal, EdgeTypes.AST)
case _ => // do nothing
}

diffGraph.addNode(closureBindingNode)
diffGraph.addEdge(methodRef, closureBindingNode, EdgeTypes.CAPTURE)
case None =>
logger.warn(s"No methodRef found for capturing global variable in method ${innerMethodNode.fullName}")
}
}
}

private def createClosureBindingsForArrowClosure(expr: PhpNode, name: String): NewNode = {
val surroundingIter = scope.getSurroundingMethodsForArrowClosure.drop(1).iterator
val surroundingMethods = scope.getSurroundingMethodsForArrowClosure.dropRight(1)
val localNodes = mutable.ArrayBuffer[NewLocal]()
createClosureCaptureForNode(expr, name, surroundingIter, surroundingMethods)
scope.lookupVariable(name).get
}

protected def createClosureBinding(closureBindingId: String): NewClosureBinding =
NewClosureBinding().closureBindingId(closureBindingId).evaluationStrategy(EvaluationStrategies.BY_SHARING)

protected def staticInitMethodAst(node: PhpNode, methodNode: NewMethod, body: Ast, returnType: String): Ast = {
val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC)
val methodReturn = methodReturnNode(node, returnType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
val local = localNode(variable, name.name, name.name, Defines.Any)

val node = scope.addToScope(name.name, local) match {
case NamespaceScope(namespaceNode, _) => namespaceNode
case TypeScope(typeDeclNode, _) => typeDeclNode
case MethodScope(methodNode, _, _, _, _) => methodNode
case BlockScope(blockNode, _) => blockNode
case NamespaceScope(namespaceNode, _) => namespaceNode
case TypeScope(typeDeclNode, _) => typeDeclNode
case MethodScope(methodNode, _, _, _, _, _) => methodNode
case BlockScope(blockNode, _) => blockNode
}

diffGraph.addEdge(node, local, EdgeTypes.AST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package io.joern.php2cpg.astcreation
import io.joern.php2cpg.astcreation.AstCreator.{NameConstants, TypeConstants}
import io.joern.php2cpg.parser.Domain.*
import io.joern.php2cpg.parser.Domain.PhpModifiers.containsAccessModifier
import io.joern.php2cpg.utils.MethodScope
import io.joern.php2cpg.utils.{BlockScope, MethodScope}
import io.joern.x2cpg.Defines.UnresolvedSignature
import io.joern.x2cpg.utils.AstPropertiesUtil.RootProperties
import io.joern.x2cpg.{Ast, Defines, ValidationMode}
import io.shiftleft.codepropertygraph.generated.nodes.*
Expand Down Expand Up @@ -45,9 +46,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

local.closureBindingId(closureBindingId)

val closureBindingNode = NewClosureBinding()
.closureBindingId(closureBindingId)
.evaluationStrategy(EvaluationStrategies.BY_SHARING)
val closureBindingNode = createClosureBinding(closureBindingId)

scope.lookupVariable(local.name) match {
case Some(refLocal) =>
Expand Down Expand Up @@ -84,7 +83,14 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
}

val methodAst =
astForMethodDecl(methodDecl, localsForUses.map(Ast(_)), Option(methodName), usesCode = Option(usesCode))
astForMethodDecl(
methodDecl,
localsForUses.map(Ast(_)),
Option(methodName),
usesCode = Option(usesCode),
isArrowClosure = closureExpr.isArrowFunc,
closureMethodRef = Option(methodRef)
)

// Add method to scope to be attached to typeDecl later
scope.addAnonymousMethod(methodAst)
Expand All @@ -97,7 +103,9 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
bodyPrefixAsts: List[Ast] = Nil,
fullNameOverride: Option[String] = None,
isConstructor: Boolean = false,
usesCode: Option[String] = None
usesCode: Option[String] = None,
isArrowClosure: Boolean = false,
closureMethodRef: Option[NewMethodRef] = None
): Ast = {
val isStatic = decl.modifiers.contains(ModifierTypes.STATIC)
val thisParam = if (decl.isClassMethod && !isStatic) {
Expand Down Expand Up @@ -136,7 +144,9 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

val methodBodyNode = blockNode(decl)

scope.pushNewScope(MethodScope(method, methodBodyNode, method.fullName, methodRef))
scope.pushNewScope(
MethodScope(method, methodBodyNode, method.fullName, decl.params.map(_.name), methodRef, isArrowClosure)
)
scope.useFunctionDecl(methodName, fullName)

val returnType = decl.returnType.map(_.name).getOrElse(Defines.Any)
Expand All @@ -151,6 +161,28 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
val attributeAsts = decl.attributeGroups.flatMap(astForAttributeGroup)
val methodBody = blockAst(methodBodyNode, methodBodyStmts)

if (isArrowClosure) {
scope.getAndClearClosureBindings.foreach { (variableLocal, variableClosureBinding) =>
val closureBindingId = s"$methodName:${variableLocal.name}"
val closureBinding = createClosureBinding(closureBindingId)
val localNode_ =
localNode(decl, variableLocal.name, variableLocal.name, variableLocal.typeFullName, Option(closureBindingId))

scope.addClosureBinding(closureBinding, localNode_)

diffGraph.addNode(closureBinding)
diffGraph.addEdge(variableClosureBinding, localNode_, EdgeTypes.REF)

scope.addToScope(localNode_.name, localNode_) match {
case BlockScope(block, _) => diffGraph.addEdge(block, localNode_, EdgeTypes.AST)
case MethodScope(_, block, _, _, _, _) => diffGraph.addEdge(block, localNode_, EdgeTypes.AST)
case _ => // do nothing
}

diffGraph.addEdge(closureMethodRef.get, closureBinding, EdgeTypes.CAPTURE)
}
}

scope.popScope()
val methodAst = methodAstWithAnnotations(method, parameters, methodBody, methodReturn, modifiers, attributeAsts)

Expand Down Expand Up @@ -207,7 +239,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
scope.surroundingScopeFullName.map(method.astParentFullName(_))
scope.surroundingAstLabel.map(method.astParentType(_))

scope.pushNewScope(MethodScope(method, methodBodyBlock, method.fullName, Option(methodRef)))
scope.pushNewScope(MethodScope(method, methodBodyBlock, method.fullName, methodRefNode = Option(methodRef)))

val methodBody = blockAst(methodBodyBlock, initAsts)

Expand Down Expand Up @@ -270,7 +302,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

val methodBlock = NewBlock()

scope.pushNewScope(MethodScope(methodNode_, methodBlock, fullName, Option(methodRef)))
scope.pushNewScope(MethodScope(methodNode_, methodBlock, fullName, methodRefNode = Option(methodRef)))

val assignmentAsts = inits.map { init =>
astForMemberAssignment(init.originNode, init.memberNode, init.value, isField = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Scope(summary: Map[String, Seq[SymbolSummary]] = Map.empty, closureNameFn:
private var tmpClassCounter = 0
private var importedSymbols = Map.empty[String, SymbolSummary]
private val methodRefsInAst = mutable.HashMap[String, NewMethodRef]()
private val capturedVariableClosureBindings = mutable.ArrayBuffer[(NewLocal, NewClosureBinding)]()

override def pushNewScope(scopeNode: TypedScopeElement): Unit = {
val mappedNode = scopeNode match {
Expand Down Expand Up @@ -142,6 +143,14 @@ class Scope(summary: Map[String, Seq[SymbolSummary]] = Map.empty, closureNameFn:
def addMethodRef(methodRefName: String, methodRef: NewMethodRef): Unit = methodRefsInAst.put(methodRefName, methodRef)
def getMethodRef(methodRefName: String): Option[NewMethodRef] = methodRefsInAst.get(methodRefName)

def addClosureBinding(closureBinding: NewClosureBinding, localNode: NewLocal): Unit =
capturedVariableClosureBindings.addOne((localNode, closureBinding))
def getAndClearClosureBindings: List[(NewLocal, NewClosureBinding)] = {
val capturedClosureBindings = capturedVariableClosureBindings.toList
capturedVariableClosureBindings.clear()
capturedClosureBindings
}

def addVariableToMethodScope(identifier: String, variable: NewNode, methodFullName: String): Option[MethodScope] = {
stack.collectFirst {
case el @ ScopeElement(methodScope: MethodScope, _) if methodScope.fullName == methodFullName =>
Expand Down Expand Up @@ -171,6 +180,9 @@ class Scope(summary: Map[String, Seq[SymbolSummary]] = Map.empty, closureNameFn:
.collectFirst { case TypeScope(td, _) => td }
.exists(_.name.endsWith(MetaTypeDeclExtension))

def isSurroundedByArrowClosure: Boolean =
stack.map(_.scopeNode).collectFirst { case nm: MethodScope if nm.isArrowFunc => nm }.isDefined

def isTopLevel: Boolean =
getEnclosingTypeDeclTypeName.forall(_ == NamespaceTraversal.globalNamespaceName)

Expand All @@ -192,7 +204,7 @@ class Scope(summary: Map[String, Seq[SymbolSummary]] = Map.empty, closureNameFn:
.collectFirst {
case NamespaceScope(nm, _) if nm.name != NamespaceTraversal.globalNamespaceName => s"${nm.name}\\$methodName"
case TypeScope(td, _) if td.name != NamespaceTraversal.globalNamespaceName => s"${td.fullName}.$methodName"
case MethodScope(nm, _, _, _, _) if nm.name != NamespaceTraversal.globalNamespaceName =>
case MethodScope(nm, _, _, _, _, _) if nm.name != NamespaceTraversal.globalNamespaceName =>
if (namespaces.isEmpty) {
s"${nm.fullName}.$methodName"
} else {
Expand All @@ -202,9 +214,36 @@ class Scope(summary: Map[String, Seq[SymbolSummary]] = Map.empty, closureNameFn:
.getOrElse(methodName)
}

def getSurroundingArrowClosureMethodRef: Option[NewMethodRef] =
stack.map(_.scopeNode).collectFirst { case ms: MethodScope if ms.isArrowFunc => ms }.flatMap(_.methodRefNode)

def getSurroundingMethods: List[MethodScope] =
stack.map(_.scopeNode).collect { case nm: MethodScope => nm }.reverse

def getSurroundingMethodsForArrowClosure: List[MethodScope] = {
val methods = mutable.ArrayBuffer[MethodScope]()
stack
.collect { case scopeEl @ ScopeElement(_: MethodScope, _) =>
scopeEl
}
.takeWhile {
case ScopeElement(ms: MethodScope, _) if ms.isArrowFunc =>
methods.addOne(ms)
true
case ScopeElement(ms: MethodScope, _) =>
methods.addOne(ms)
false
}

methods.toList.reverse
}

def getSurroundingArrowClosures: List[MethodScope] =
stack.map(_.scopeNode).collect { case nm: MethodScope if nm.isArrowFunc => nm }.reverse

def surroundingMethodParams: List[String] =
stack.map(_.scopeNode).collectFirst { case ms: MethodScope => ms }.map(_.parameterNames.toList).get

def getConstAndStaticInits: List[PhpInit] = {
getInits(constAndStaticInits)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ case class MethodScope(
methodNode: NewMethod,
bodyNode: NewBlock,
fullName: String,
parameterNames: Seq[String] = Seq.empty,
methodRefNode: Option[NewMethodRef] = None,
isArrowFunc: Boolean = false
) extends MethodLikeScope
Expand Down
Loading
Loading