From ec29b98b1350a673ced250c8159f69ef164607c8 Mon Sep 17 00:00:00 2001 From: David Baker Effendi Date: Mon, 20 May 2024 21:32:46 +0200 Subject: [PATCH] [ruby] Handle Rescue Exception Lists (#4575) --- .../AstForExpressionsCreator.scala | 29 +++++++++-------- .../astcreation/AstForStatementsCreator.scala | 4 +-- .../astcreation/RubyIntermediateAst.scala | 12 +++---- .../rubysrc2cpg/parser/RubyNodeCreator.scala | 17 +++++----- .../rubysrc2cpg/querying/ClassTests.scala | 31 +++++++++++-------- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala index a32145bd4fd1..a035c18de463 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -604,20 +604,23 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { protected def astForRescueExpression(node: RescueExpression): Ast = { val tryAst = astForStatementList(node.body.asStatementList) val rescueAsts = node.rescueClauses - .map { - case x: RescueClause => - // TODO: add exception assignment - astForStatementList(x.thenClause.asStatementList) - case x => astForUnknown(x) + .map { x => + val classes = + x.exceptionClassList.map(e => scope.tryResolveTypeReference(e.text).map(_.name).getOrElse(e.text)).toSeq + val variables = x.variables + .flatMap { v => + handleVariableOccurrence(v) + scope.lookupVariable(v.text) + } + .collect { + case x: NewLocal => Ast(x.dynamicTypeHintFullName(classes)) + case x: NewMethodParameterIn => Ast(x.dynamicTypeHintFullName(classes)) + } + .toList + astForStatementList(x.thenClause.asStatementList).withChildren(variables) } - val elseAst = node.elseClause.map { - case x: ElseClause => astForStatementList(x.thenClause.asStatementList) - case x => astForUnknown(x) - } - val ensureAst = node.ensureClause.map { - case x: EnsureClause => astForStatementList(x.thenClause.asStatementList) - case x => astForUnknown(x) - } + val elseAst = node.elseClause.map { x => astForStatementList(x.thenClause.asStatementList) } + val ensureAst = node.ensureClause.map { x => astForStatementList(x.thenClause.asStatementList) } tryCatchAst( NewControlStructure() .controlStructureType(ControlStructureTypes.TRY) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 67ca26dcaad6..4312b93f883b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -382,8 +382,8 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t // Ensure never returns a value, only the main body, rescue & else clauses RescueExpression( transform(body), - rescueClauses.map(transform), - elseClause.map(transform).orElse(defaultElseBranch(node.span)), + rescueClauses.map(transform).collect { case x: RescueClause => x }, + elseClause.map(transform).orElse(defaultElseBranch(node.span)).collect { case x: ElseClause => x }, ensureClause )(node.span) case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala index 642295423eb5..0dc17603bcdc 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/RubyIntermediateAst.scala @@ -30,10 +30,10 @@ object RubyIntermediateAst { } implicit class RubyNodeHelper(node: RubyNode) { - def asStatementList = node match + def asStatementList: StatementList = node match { case stmtList: StatementList => stmtList case _ => StatementList(List(node))(node.span) - + } } final case class Unknown()(span: TextSpan) extends RubyNode(span) @@ -150,16 +150,16 @@ object RubyIntermediateAst { final case class RescueExpression( body: RubyNode, - rescueClauses: List[RubyNode], - elseClause: Option[RubyNode], - ensureClause: Option[RubyNode] + rescueClauses: List[RescueClause], + elseClause: Option[ElseClause], + ensureClause: Option[EnsureClause] )(span: TextSpan) extends RubyNode(span) with ControlFlowExpression final case class RescueClause( exceptionClassList: Option[RubyNode], - assignment: Option[RubyNode], + variables: Option[RubyNode], thenClause: RubyNode )(span: TextSpan) extends RubyNode(span) diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala index 244c94ccecf1..5113385d8e2c 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/parser/RubyNodeCreator.scala @@ -1083,10 +1083,11 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] { } override def visitBodyStatement(ctx: RubyParser.BodyStatementContext): RubyNode = { - val body = visit(ctx.compoundStatement()) - val rescueClauses = Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList) - val elseClause = Option(ctx.elseClause).map(visit) - val ensureClause = Option(ctx.ensureClause).map(visit) + val body = visit(ctx.compoundStatement()) + val rescueClauses = + Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList).collect { case x: RescueClause => x } + val elseClause = Option(ctx.elseClause).map(visit).collect { case x: ElseClause => x } + val ensureClause = Option(ctx.ensureClause).map(visit).collect { case x: EnsureClause => x } if (rescueClauses.isEmpty && elseClause.isEmpty && ensureClause.isEmpty) { visit(ctx.compoundStatement()) @@ -1096,16 +1097,14 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] { } override def visitExceptionClassList(ctx: RubyParser.ExceptionClassListContext): RubyNode = { - // Requires implementing multiple rhs with splatting - logger.warn(s"Exception class lists are not handled: '${ctx.toTextSpan}'") - Unknown()(ctx.toTextSpan) + Option(ctx.multipleRightHandSide()).map(visitMultipleRightHandSide).getOrElse(visit(ctx.operatorExpression())) } override def visitRescueClause(ctx: RubyParser.RescueClauseContext): RubyNode = { val exceptionClassList = Option(ctx.exceptionClassList).map(visit) - val elseClause = Option(ctx.exceptionVariableAssignment).map(visit) + val variables = Option(ctx.exceptionVariableAssignment).map(visit) val thenClause = visit(ctx.thenClause) - RescueClause(exceptionClassList, elseClause, thenClause)(ctx.toTextSpan) + RescueClause(exceptionClassList, variables, thenClause)(ctx.toTextSpan) } override def visitEnsureClause(ctx: RubyParser.EnsureClauseContext): RubyNode = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala index e19d7938185c..4f358d0eba49 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/querying/ClassTests.scala @@ -575,27 +575,32 @@ class ClassTests extends RubyCode2CpgFixture { "Bodies that aren't StatementList" should { val cpg = code(""" | class EventWebhook - | # * *Args* : - | # - +public_key+ -> elliptic curve public key - | # - +payload+ -> event payload in the request body - | # - +signature+ -> signature value obtained from the 'X-Twilio-Email-Event-Webhook-Signature' header - | # - +timestamp+ -> timestamp value obtained from the 'X-Twilio-Email-Event-Webhook-Timestamp' header + | ERRORS = [CustomErrorA, CustomErrorB] + | | def verify_signature(public_key, payload, signature, timestamp) | verify_engine - | timestamped_playload = "#{timestamp}#{payload}" - | payload_digest = Digest::SHA256.digest(timestamped_playload) + | timestamped_payload = "#{timestamp}#{payload}" + | payload_digest = Digest::SHA256.digest(timestamped_payload) | decoded_signature = Base64.decode64(signature) | public_key.dsa_verify_asn1(payload_digest, decoded_signature) - | rescue StandardError + | rescue *ERRORS => splat_errors + | false + | rescue StandardError => some_variable | false | end | end |""".stripMargin) - "not throw an execption" in { - inside(cpg.method.name("verify_signature").l) { - case verifySigMethod :: Nil => // Passing case - case _ => fail("Expected method for verify_sginature") - } + + "successfully parse and create the method" in { + cpg.method.nameExact("verify_signature").nonEmpty shouldBe true + } + + "create the `StandardError` local variable" in { + cpg.local.nameExact("some_variable").dynamicTypeHintFullName.toList shouldBe List("__builtin.StandardError") + } + + "create the splatted error local variable" in { + cpg.local.nameExact("splat_errors").size shouldBe 1 } }