Skip to content

Commit

Permalink
[gosrc2cpg] - Package level TypeDecl constructor handling (joernio#3899)
Browse files Browse the repository at this point in the history
1. Added the pending package-level TypeDecl constructor handling, which initialises global variables with their respective right-hand-side (RHS) values. Handling this was a little tricky because global variables may be defined in different files. It is handled by caching each package-level global variable assignment statement AST, mapped against the package TypeDecl full name. The package-level constructor Method node is then created in a separate pass using this cached information.
2. Added simple data flow test as well.

TODO:
1. Need to add more unit tests around package level variable Lambda expression initialisation as well as data flow tests for the same.
2. More unit tests around the general package-level variable declaration.
3. Handling of closure use case.
4. More data flow tests around lambda expression handling.
  • Loading branch information
pandurangpatil authored Dec 7, 2023
1 parent 7682b74 commit cdcaf26
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package io.joern.gosrc2cpg

import better.files.File
import io.joern.gosrc2cpg.datastructures.GoGlobal
import io.joern.gosrc2cpg.model.GoModHelper
import io.joern.gosrc2cpg.parser.GoAstJsonParser
import io.joern.gosrc2cpg.passes.{AstCreationPass, DownloadDependenciesPass, MethodAndTypeCacheBuilderPass}
import io.joern.gosrc2cpg.passes.{
AstCreationPass,
DownloadDependenciesPass,
MethodAndTypeCacheBuilderPass,
PackageCtorCreationPass
}
import io.joern.gosrc2cpg.utils.AstGenRunner
import io.joern.gosrc2cpg.utils.AstGenRunner.GoAstGenRunnerResult
import io.joern.x2cpg.X2Cpg.withNewEmptyCpg
Expand Down Expand Up @@ -32,7 +38,8 @@ class GoSrc2Cpg extends X2CpgFrontend[Config] {
val astCreators =
new MethodAndTypeCacheBuilderPass(Some(cpg), astGenResult.parsedFiles, config, goMod).process()
new AstCreationPass(cpg, astCreators, config, report).createAndApply()
// TypeNodePass.withRegisteredTypes(GoGlobal.typesSeen(), cpg).createAndApply()
if GoGlobal.pkgLevelVarAndConstantAstMap.size() > 0 then
new PackageCtorCreationPass(cpg, config).createAndApply()
report.print()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
val typeInfo = createParserNodeInfo(x(ParserKeys.Type))
val (typeFullName, typeFullNameForcode, isVariadic, _) = processTypeInfo(typeInfo, genericTypeMethodMap)
x(ParserKeys.Names).arrOpt
/*
While generating the signature for a function structure
func test(a, b int, c string) int {
}
it works as there is no situation where parameter name will not be there.
As we reuse the same function for generating the signature for lambda types as below
type Operation func(int, int) int
In this case no parameter names exist. To handle this situation we add an empty string as the default value,
which only facilitates adding the parameter type to the list.
*/
.getOrElse(List(""))
.map(_ => {
// We are returning same type from x object for each name in the names array.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
.flatMap { parserNode =>
val localParserNode = createParserNodeInfo(parserNode)
if globalStatements then {
astForGlobalVarAndConstants(typeFullName.getOrElse(Defines.anyTypeName), localParserNode)
if !recordVar then
astForGlobalVarAndConstants(typeFullName.getOrElse(Defines.anyTypeName), localParserNode)
Seq.empty
} else {
Seq(astForLocalNode(localParserNode, typeFullName, recordVar)) ++: astForNode(localParserNode)
Expand All @@ -96,7 +97,7 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
): (Ast, Ast) = {
val rhsAst = astForBooleanLiteral(rhsParserNode)
val rhsTypeFullName = typeFullName.getOrElse(getTypeFullNameFromAstNode(rhsAst))
if (globalStatements) {
if (globalStatements && !recordVar) {
astForGlobalVarAndConstants(rhsTypeFullName, lhsParserNode, Some(rhsAst))
(Ast(), Ast())
} else {
Expand Down Expand Up @@ -128,6 +129,24 @@ trait AstForGenDeclarationCreator(implicit withSchemaValidation: ValidationMode)
.astParentFullName(fullyQualifiedPackage)
)
Ast.storeInDiffGraph(memberAst, diffGraph)
rhsAst match
case Some(rhsSeqAst) =>
// Only in case rhs ast is present then the respective variable or constant will be added as part
// of package level initializer/constructor statement
val lhsAst = astForPackageGlobalFieldAccess(typeFullName, name, lhsParserNode)
val arguments = Seq(lhsAst) ++: rhsSeqAst
val cNode = callNode(
lhsParserNode,
lhsParserNode.code,
Operators.assignment,
Operators.assignment,
DispatchTypes.STATIC_DISPATCH,
None,
Some(typeFullName)
)
GoGlobal.recordPkgLevelVarAndConstantAst(fullyQualifiedPackage, callAst(cNode, arguments), relPathFileName)
case _ =>

}

protected def astForLocalNode(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.joern.gosrc2cpg.astcreation

import io.joern.x2cpg.{Ast, AstCreatorBase, ValidationMode, Defines as XDefines}
import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, NodeTypes}
import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewMethod, NewMethodReturn}
import org.apache.commons.lang.StringUtils
import overflowdb.BatchedUpdate.DiffGraphBuilder

import scala.collection.immutable.Set

/** Builds the synthetic package-level constructor (static initializer) method for a single Go package.
  *
  * @param pacakgePath
  *   key of the package the constructor is generated for; it is used as the constructor's AST parent full name, as
  *   its filename and as the prefix of its name
  * @param statements
  *   cached `(assignment AST, file path)` pairs that become the body of the constructor
  */
class AstForPackageConstructorCreator(val pacakgePath: String, statements: Set[(Ast, String)])(implicit
  withSchemaValidation: ValidationMode
) extends AstCreatorBase(pacakgePath) {

  /** Creates the constructor method AST, stores it in this creator's diff graph and returns that diff graph. */
  override def createAst(): DiffGraphBuilder = {
    val ctorName = StringUtils.normalizeSpace(s"$pacakgePath${XDefines.StaticInitMethodName}")

    // The constructor has no real source location, so every position is pinned to 0.
    val ctorMethod = NewMethod()
      .name(ctorName)
      .fullName(ctorName)
      .code(ctorName)
      .filename(pacakgePath)
      .isExternal(false)
      .astParentType(NodeTypes.TYPE_DECL)
      .astParentFullName(pacakgePath)
      .lineNumber(0)
      .lineNumberEnd(0)
      .columnNumber(0)
      .columnNumberEnd(0)

    val ctorReturn = NewMethodReturn()
      .code("RET")
      .typeFullName(Defines.voidTypeName)
      .evaluationStrategy(EvaluationStrategies.BY_VALUE)
      .lineNumber(0)
      .columnNumber(0)

    val ctorBody = NewBlock()
      .typeFullName(Defines.voidTypeName)
      .code(ctorName)
      .lineNumber(0)
      .columnNumber(0)

    // Only the AST half of each cached pair is needed; the file path is dropped here.
    val assignmentAsts = statements.map { case (assignmentAst, _) => assignmentAst }.toList
    setArgumentIndices(assignmentAsts)

    Ast.storeInDiffGraph(
      methodAst(ctorMethod, Seq.empty, blockAst(ctorBody, assignmentAsts), ctorReturn),
      diffGraph
    )
    diffGraph
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode) { t
}
}

private def astForPackageGlobalFieldAccess(
protected def astForPackageGlobalFieldAccess(
fieldTypeFullName: String,
identifierName: String,
ident: ParserNodeInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ trait AstForTypeDeclCreator(implicit withSchemaValidation: ValidationMode) { thi
)
(identifierAsts, fieldTypeFullName)
}

protected def astForFieldAccess(info: ParserNodeInfo): Seq[Ast] = {
val (identifierAsts, fieldTypeFullName) = processReceiver(info)
val fieldIdentifier = info.json(ParserKeys.Sel)(ParserKeys.Name).str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.joern.gosrc2cpg.datastructures

import io.joern.gosrc2cpg.astcreation.Defines
import io.joern.x2cpg.datastructures.Global

import io.joern.x2cpg.Ast
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters.EnumerationHasAsScala

Expand All @@ -25,6 +25,8 @@ object GoGlobal extends Global {

val lambdaSignatureToLambdaTypeMap: ConcurrentHashMap[String, Set[(String, String)]] = new ConcurrentHashMap()

val pkgLevelVarAndConstantAstMap: ConcurrentHashMap[String, Set[(Ast, String)]] = new ConcurrentHashMap()

// Mapping method fullname to its return type and signature
val methodFullNameReturnTypeMap: ConcurrentHashMap[String, (String, String)] = new ConcurrentHashMap()

Expand Down Expand Up @@ -59,6 +61,17 @@ object GoGlobal extends Global {
methodFullNameReturnTypeMap.putIfAbsent(methodFullName, (returnType, signature))
}

/** Caches the AST of a package-level variable/constant assignment so that it can later be placed into the
  * package-level constructor.
  *
  * @param pkg
  *   key under which the assignment is cached (the fully qualified package name at the call site)
  * @param ast
  *   AST of the assignment statement
  * @param filePath
  *   path of the file the assignment was found in
  */
def recordPkgLevelVarAndConstantAst(pkg: String, ast: Ast, filePath: String): Unit = {
  val entry = (ast, filePath)
  // `compute` performs the read-modify-write atomically for the key, so no additional lock on the enclosing
  // object is required and unrelated synchronized sections are not blocked.
  pkgLevelVarAndConstantAstMap.compute(
    pkg,
    // `existing` is null when nothing has been recorded for this package yet.
    (_, existing) => Option(existing).fold(Set(entry))(_ + entry)
  )
}

def recordLambdaSigntureToLambdaType(
signature: String,
lambdaStructTypeFullName: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.joern.gosrc2cpg.passes

import io.joern.gosrc2cpg.Config
import io.joern.gosrc2cpg.astcreation.AstForPackageConstructorCreator
import io.joern.gosrc2cpg.datastructures.GoGlobal
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.passes.ConcurrentWriterCpgPass

import scala.jdk.CollectionConverters.*

/** Pass that emits one package-level constructor method per package recorded in
  * `GoGlobal.pkgLevelVarAndConstantAstMap`.
  */
class PackageCtorCreationPass(cpg: Cpg, config: Config)
    extends ConcurrentWriterCpgPass[(String, Set[(Ast, String)])](cpg) {

  /** Produces one part per cached package key, paired with the assignment ASTs recorded for it. */
  override def generateParts(): Array[(String, Set[(Ast, String)])] = {
    val cache = GoGlobal.pkgLevelVarAndConstantAstMap
    val parts = for (pkg <- cache.keys().asScala) yield (pkg, cache.get(pkg))
    parts.toArray
  }

  /** Builds the constructor AST for the given package and merges the result into the pass' diff graph. */
  override def runOnPart(diffGraph: DiffGraphBuilder, part: (String, Set[(Ast, String)])): Unit =
    part match {
      case (pkgKey, assignmentAsts) =>
        val ctorCreator = new AstForPackageConstructorCreator(pkgKey, assignmentAsts)(config.schemaValidation)
        diffGraph.absorb(ctorCreator.createAst())
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.joern.go2cpg.dataflow

import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.semanticcpg.language.*
import io.joern.dataflowengineoss.language.*
import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.semanticcpg.language.*
import java.io.File
/** Data flow tests for package-level (global) constants and variables. */
class GlobalVariableDataflowTests extends GoCodeToCpgSuite(withOssDataflow = true) {

  "Global variable declaration check" should {
    // Fixture: one global constant that is printed and one global variable that is never read.
    val cpg = code("""
|package main
|const (
| FooConst = "Test"
|)
|var (
| BarVar = 100
|)
|func main() {
| println(FooConst)
|}
|""".stripMargin)

    "get the data flow from global variable to println sink" in {
      val printlnCalls = cpg.call("println")
      val testLiteral  = cpg.literal("\"Test\"")
      val flows        = printlnCalls.reachableByFlows(testLiteral)
      flows.size shouldBe 1
    }
  }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.joern.go2cpg.passes.ast

import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.joern.x2cpg.Defines
import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.semanticcpg.language.*

Expand Down Expand Up @@ -37,8 +38,21 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
}

"Be correct for Field Access CALL Node for Global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
val List(a, b, c) = cpg.call(Operators.fieldAccess).l
a.lineNumber shouldBe Some(10)
b.lineNumber shouldBe Some(4)
c.lineNumber shouldBe Some(7)
}

"Create Constructor method for Package level global variable initialisation" in {
val List(x) = cpg.method(s".*${Defines.StaticInitMethodName}").l
x.fullName shouldBe s"main${Defines.StaticInitMethodName}"
}

"Be correct for Literal nodes " in {
val List(a, b) = cpg.literal.l
a.code shouldBe "\"Test\""
b.code shouldBe "100"
}
}

Expand Down Expand Up @@ -71,8 +85,9 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
)

"Be correct for Field Access CALL Node for Global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
val List(x, y) = cpg.call(Operators.fieldAccess).l
x.code shouldBe "lib1.SchemeHTTP"
y.code shouldBe "SchemeHTTP"
}

"Check methodfullname of variable imported from other package " in {
Expand Down Expand Up @@ -110,8 +125,9 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
)

"Be correct for Field Access CALL Node for Global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
val List(x, y) = cpg.call(Operators.fieldAccess).l
x.code shouldBe "lib1.SchemeHTTP"
y.code shouldBe "SchemeHTTP"
}

"Check methodfullname of variable imported from other package " in {
Expand Down Expand Up @@ -148,8 +164,9 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
)

"Be correct for Field Access CALL Node for Global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
val List(x, y) = cpg.call(Operators.fieldAccess).l
x.code shouldBe "lib1.SchemeHTTP"
y.code shouldBe "SchemeHTTP"
}

"Check methodfullname of constant imported from other package " in {
Expand Down Expand Up @@ -189,9 +206,10 @@ class GlobalVariableAndConstantTests extends GoCodeToCpgSuite {
)

"Be correct for Field Access CALL Node for Global variable access" in {
val List(a, b) = cpg.call(Operators.fieldAccess).l
val List(a, b, c) = cpg.call(Operators.fieldAccess).l
a.typeFullName shouldBe "string"
b.typeFullName shouldBe "joern.io/sample/lib2.SchemeHTTP.<FieldAccess>.<unknown>"
c.code shouldBe "SchemeHTTP"
}

"Check methodfullname of constant imported from other package " in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1607,8 +1607,12 @@ class MethodTests extends GoCodeToCpgSuite {
}

"Check fieldAccess node for global variable access" in {
val List(x) = cpg.call(Operators.fieldAccess).l
x.typeFullName shouldBe "string"
val List(a, b, c) = cpg.call(Operators.fieldAccess).l
a.typeFullName shouldBe "string"
b.typeFullName shouldBe "main.Name"
b.code shouldBe "person"
c.typeFullName shouldBe "string"
c.code shouldBe "personName"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,7 @@ class GoCodeToCpgSuite(fileSuffix: String = ".go", withOssDataflow: Boolean = fa
GoGlobal.methodFullNameReturnTypeMap.clear()
GoGlobal.aliasToNameSpaceMapping.clear()
GoGlobal.structTypeMemberTypeMapping.clear()
GoGlobal.lambdaSignatureToLambdaTypeMap.clear()
GoGlobal.pkgLevelVarAndConstantAstMap.clear()
}
}

0 comments on commit cdcaf26

Please sign in to comment.