Skip to content

Commit 5710607

Browse files
committed
[ruby] capture self variable
1 parent 2af87d7 commit 5710607

File tree

4 files changed

+48
-25
lines changed

4 files changed

+48
-25
lines changed

joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForFunctionsCreator.scala

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
100100

101101
// Consider which variables are captured from the outer scope
102102
val stmtBlockAst = if (isClosure || isSingletonObjectMethod) {
103+
// create closure local used for capturing
104+
createClosureBindingInformation(scope.lookupSelfInOuterScope.toSet)
105+
.collect { case (_, name, _, Some(closureBindingId)) =>
106+
val capturingLocal = localNode(node.body, name, name, Defines.Any, closureBindingId = Option(closureBindingId))
107+
scope.addToScope(capturingLocal.name, capturingLocal)
108+
}
103109
val baseStmtBlockAst = astForMethodBody(node.body, optionalStatementList)
104110
transformAsClosureBody(node.body, refs, baseStmtBlockAst)
105111
} else {
@@ -214,22 +220,25 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
214220
private def transformAsClosureBody(originNode: RubyExpression, refs: List[Ast], baseStmtBlockAst: Ast) = {
215221
// Determine which locals are captured
216222
val capturedLocalNodes = baseStmtBlockAst.nodes
217-
.collect { case x: NewIdentifier => x }
223+
.collect { case x: NewIdentifier if x.name != Defines.Self => x }
218224
.distinctBy(_.name)
219225
.map {
220-
case i if i.name == "self" => scope.lookupSelfInOuterScope // we only need to bind to the closest self in scope
221-
case i => scope.lookupVariableInOuterScope(i.name)
226+
case i if i.name == Defines.Self => scope.lookupSelfInOuterScope
227+
case i => scope.lookupVariableInOuterScope(i.name)
222228
}
223229
.filter(_.iterator.nonEmpty)
224230
.flatten
225231
.toSet
226232

233+
val selfLocal = scope.lookupSelfInOuterScope.toSet
234+
val capturedNodes = capturedLocalNodes ++ selfLocal
235+
227236
val capturedIdentifiers = baseStmtBlockAst.nodes.collect {
228-
case i: NewIdentifier if capturedLocalNodes.map(_.name).contains(i.name) => i
237+
case i: NewIdentifier if capturedNodes.map(_.name).contains(i.name) => i
229238
}
230239
// Copy AST block detaching the REF nodes between parent locals/params and identifiers, with the closures' one
231240
val capturedBlockAst = baseStmtBlockAst.copy(refEdges = baseStmtBlockAst.refEdges.filterNot {
232-
case AstEdge(_: NewIdentifier, dst: DeclarationNew) => capturedLocalNodes.contains(dst)
241+
case AstEdge(_: NewIdentifier, dst: DeclarationNew) => capturedNodes.contains(dst)
233242
case _ => false
234243
})
235244

@@ -238,15 +247,8 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
238247
val astChildren = mutable.Buffer.empty[NewNode]
239248
val refEdges = mutable.Buffer.empty[(NewNode, NewNode)]
240249
val captureEdges = mutable.Buffer.empty[(NewNode, NewNode)]
241-
capturedLocalNodes
242-
.collect {
243-
case local: NewLocal =>
244-
val closureBindingId = scope.variableScopeFullName(local.name).map(x => s"$x.${local.name}")
245-
(local, local.name, local.code, closureBindingId)
246-
case param: NewMethodParameterIn =>
247-
val closureBindingId = scope.variableScopeFullName(param.name).map(x => s"$x.${param.name}")
248-
(param, param.name, param.code, closureBindingId)
249-
}
250+
251+
createClosureBindingInformation(capturedNodes)
250252
.collect { case (capturedLocal, name, code, Some(closureBindingId)) =>
251253
val capturingLocal =
252254
localNode(originNode, name, name, Defines.Any, closureBindingId = Option(closureBindingId))
@@ -624,4 +626,16 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
624626
case _ => false
625627
}
626628
}
629+
630+
private def createClosureBindingInformation(capturedNodes: Set[DeclarationNew]): Set[(DeclarationNew, String, String, Option[String])] = {
631+
capturedNodes
632+
.collect {
633+
case local: NewLocal =>
634+
val closureBindingId = scope.variableScopeFullName(local.name).map(x => s"$x.${local.name}")
635+
(local, local.name, local.code, closureBindingId)
636+
case param: NewMethodParameterIn =>
637+
val closureBindingId = scope.variableScopeFullName(param.name).map(x => s"$x.${param.name}")
638+
(param, param.name, param.code, closureBindingId)
639+
}
640+
}
627641
}

joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/DoBlockTests.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
6868

6969
"have the return node under the closure (returning the literal)" in {
7070
inside(cpg.method("<lambda>0").block.astChildren.l) {
71-
case ret :: Nil =>
71+
case ret :: _ :: Nil =>
7272
ret.code shouldBe "\"world!\""
7373
case xs => fail(s"Expected the closure to have a single call, instead got [${xs.code.mkString(", ")}]")
7474
}
@@ -229,9 +229,9 @@ class DoBlockTests extends RubyCode2CpgFixture {
229229
}
230230
}
231231

232-
"annotate the nodes via CAPTURE bindings" in {
232+
"nnotate the nodes via CAPTURE bindings" in {
233233
cpg.closureBinding.l match {
234-
case myValue :: Nil =>
234+
case myValue :: _ :: Nil =>
235235
inside(myValue._localViaRefOut) {
236236
case Some(local) =>
237237
local.name shouldBe "myValue"
@@ -400,17 +400,19 @@ class DoBlockTests extends RubyCode2CpgFixture {
400400
|""".stripMargin)
401401

402402
inside(cpg.local.nameNot(".*<tmp-\\d>").l) {
403-
case jfsOutsideLocal :: schedules :: hashInsideLocal :: jfsCapturedLocal :: Nil =>
403+
case jfsOutsideLocal :: _ :: hashInsideLocal :: jfsCapturedLocal :: selfCapturedLocal :: Nil =>
404404
jfsOutsideLocal.closureBindingId shouldBe None
405405
hashInsideLocal.closureBindingId shouldBe None
406406
jfsCapturedLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
407+
selfCapturedLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.<lambda>0.self")
407408
case xs => fail(s"Expected 6 locals, got ${xs.code.mkString(",")}")
408409
}
409410

410411
inside(cpg.method.isLambda.local.l) {
411-
case hashLocal :: _ :: jfsLocal :: Nil =>
412+
case hashLocal :: _ :: jfsLocal :: selfLocal :: Nil =>
412413
hashLocal.closureBindingId shouldBe None
413414
jfsLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.jfs")
415+
selfLocal.closureBindingId shouldBe Some("Test0.rb:<main>.get_pto_schedule.<lambda>0.self")
414416
case xs => fail(s"Expected 3 locals in lambda, got ${xs.code.mkString(",")}")
415417
}
416418
}
@@ -459,7 +461,7 @@ class DoBlockTests extends RubyCode2CpgFixture {
459461
gSplatLocal.code shouldBe "g"
460462

461463
selfLocal.code shouldBe "self"
462-
selfLocal.closureBindingId shouldBe Some("Test0.rb:<global>.self")
464+
selfLocal.closureBindingId shouldBe Some("Test0.rb:<main>.<lambda>0.self")
463465
case xs => fail(s"Expected 4 locals, got [${xs.name.mkString(", ")}]")
464466
}
465467
}

joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ErbTests.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class ErbTests extends RubyCode2CpgFixture {
247247
lambdaParamForm.code shouldBe "form"
248248

249249
inside(lambdaMethod.body.astChildren.l) {
250-
case _ :: _ :: (appendCallTemplate: Call) :: _ :: Nil =>
250+
case _ :: _ :: (appendCallTemplate: Call) :: _ :: _ :: Nil =>
251251
appendCallTemplate.code shouldBe "joern__inner_buffer << <%= form.text_field :name %>"
252252
val List(appendCallArgOne, appendCallArgTwo: Call) = appendCallTemplate.argument.l: @unchecked
253253

@@ -390,15 +390,15 @@ class ErbTests extends RubyCode2CpgFixture {
390390
inside(cpg.method.name("<main>").parameter.name("self").l) {
391391
case selfMain :: Nil =>
392392
selfMain._closureBindingViaRefIn.map(_.closureBindingId).toList shouldBe List(
393-
Some("index.html.erb:<global>.self")
393+
Some("index.html.erb:<main>.<lambda>0.self")
394394
)
395395
selfMain.closureBindingId shouldBe None
396396
case xs => fail(s"Expected one local in global, got ${xs.name.mkString("[", ",", "]")}")
397397
}
398398

399399
inside(cpg.method.isLambda.local.name("self").l) {
400400
case selfLocal :: Nil =>
401-
selfLocal.closureBindingId shouldBe Some("index.html.erb:<global>.self")
401+
selfLocal.closureBindingId shouldBe Some("index.html.erb:<main>.<lambda>0.self")
402402
case xs => fail(s"Expected one self local, got ${xs.name.mkString("[", ",", "]")}")
403403
}
404404
}
@@ -452,7 +452,14 @@ class ErbTests extends RubyCode2CpgFixture {
452452
cpg.method.fullNameExact("Test0.rb:<main>.Admin.UsersController.show.<lambda>0.<lambda>0").local.name("self").l
453453
) {
454454
case selfLocal :: Nil =>
455-
selfLocal.closureBindingId shouldBe Some("Test0.rb:<main>.Admin.UsersController.show.self")
455+
selfLocal.closureBindingId shouldBe Some("Test0.rb:<main>.Admin.UsersController.show.<lambda>0.<lambda>0.self")
456+
case xs => fail(s"Expected one local for self, got ${xs.name.mkString("[", ",", "]")}")
457+
}
458+
459+
inside(
460+
cpg.method.fullNameExact("Test0.rb:<main>.Admin.UsersController.show.<lambda>0").local.name("self").l) {
461+
case selfLocal :: Nil =>
462+
selfLocal.closureBindingId shouldBe Some("Test0.rb:<main>.Admin.UsersController.show.<lambda>0.self")
456463
case xs => fail(s"Expected one local for self, got ${xs.name.mkString("[", ",", "]")}")
457464
}
458465
}

joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/MethodReturnTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ class MethodReturnTests extends RubyCode2CpgFixture {
412412

413413
"have the return node under the closure (returning the literal)" in {
414414
inside(cpg.method("<lambda>0").block.astChildren.l) {
415-
case ret :: Nil =>
415+
case ret :: _ :: Nil =>
416416
ret.code shouldBe "\"hello\""
417417
case xs => fail(s"Expected the closure to have a single call, instead got [${xs.code.mkString(", ")}]")
418418
}

0 commit comments

Comments
 (0)