From db6d3e8c378781a7eab838c7de0a70db879bd1d6 Mon Sep 17 00:00:00 2001 From: zhenghe3119 Date: Wed, 21 Aug 2019 15:17:57 +0800 Subject: [PATCH] support nested case in func_body.go --- func_body.go | 62 ++++++++++++++++++++----------- func_body_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 21 deletions(-) diff --git a/func_body.go b/func_body.go index 458c811..d839c5f 100644 --- a/func_body.go +++ b/func_body.go @@ -7,56 +7,76 @@ import ( // HasStmtInsideFuncBody checks if the body of function has given statement func HasStmtInsideFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (ret bool) { + var inside bool + pre := func(c *dstutil.Cursor) bool { node := c.Node() switch node.(type) { case *dst.FuncDecl: if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName { - funcBody := nn.Body - for _, ss := range funcBody.List { - if nodesEqual(ss, stmt) { - ret = true - } - } - return false + inside = true + } + case dst.Stmt: + if inside && nodesEqual(node, stmt) { + ret = true } } return true } - dstutil.Apply(df, pre, nil) + post := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncDecl: + if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName && inside { + inside = false + } + } + + return true + } + + dstutil.Apply(df, pre, post) return } // DeleteStmtFromFuncBody deletes any statement, inside the body of function, // that is semantically equal to the given statement. func DeleteStmtFromFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) { + var inside bool + pre := func(c *dstutil.Cursor) bool { node := c.Node() switch node.(type) { case *dst.FuncDecl: if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName { - funcBody := nn.Body - - var newList []dst.Stmt - for _, ss := range funcBody.List { - if !nodesEqual(ss, stmt) { - newList = append(newList, ss) - } else { - modified = true - } - } + inside = true + } + case dst.Stmt: + if inside && nodesEqual(node, stmt) { + c.Delete() + modified = true + } + } + return true + } - funcBody.List = newList - return false + post := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncDecl: + if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName && inside { + inside = false } } return true } - dstutil.Apply(df, pre, nil) + dstutil.Apply(df, pre, post) return } diff --git a/func_body_test.go b/func_body_test.go index 5d66df3..d30f84a 100644 --- a/func_body_test.go +++ b/func_body_test.go @@ -15,6 +15,10 @@ func TestHasStmtInsideFuncBody(t *testing.T) { import "fmt" + func B() { + c := 1 + } + func main() { a := 1 b := 1 @@ -64,6 +68,17 @@ func TestHasStmtInsideFuncBody(t *testing.T) { Value: "2", }}, }, false}, + { + &dst.AssignStmt{ + Lhs: []dst.Expr{dst.NewIdent("c")}, + Tok: token.DEFINE, + Rhs: []dst.Expr{&dst.BasicLit{ + Kind: token.INT, + Value: "1", + }}, + Decs: dst.AssignStmtDecorations{}, + }, false, + }, } df, _ := ParseSrcFileFromBytes([]byte(src)) @@ -72,6 +87,36 @@ func TestHasStmtInsideFuncBody(t *testing.T) { assert.Equal(t, c.expected, HasStmtInsideFuncBody(df, "main", c.stmt)) } }) + + t.Run("nested", func(t *testing.T) { + var src = ` + package main + + import "fmt" + + func main() { + a := 1 + b := 1 + defer func() { + c := 3 + fmt.Println(c) + }() + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + stmt := &dst.AssignStmt{ + Lhs: []dst.Expr{dst.NewIdent("c")}, + Tok: token.DEFINE, + Rhs: []dst.Expr{&dst.BasicLit{ + Kind: token.INT, + Value: "3", + }}, + Decs: dst.AssignStmtDecorations{}, + } + + assert.Equal(t, true, HasStmtInsideFuncBody(df, "main", stmt)) + }) } func TestDeleteStmtFromFuncBody(t *testing.T) { @@ -144,6 +189,54 @@ func TestDeleteStmtFromFuncBody(t *testing.T) { } }) + t.Run("nested", func(t *testing.T) { + var src = ` + package main + + import ( + "fmt" + ) + + func main() { + a := 1 + b := 1 + defer func() { + c := 1 + fmt.Println("hello world") + }() + } + ` + var expected = ` + package main + + import ( + "fmt" + ) + + func main() { + a := 1 + b := 1 + defer func() { + fmt.Println("hello world") + }() + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + stmt := &dst.AssignStmt{ + Lhs: []dst.Expr{dst.NewIdent("c")}, + Tok: token.DEFINE, + Rhs: []dst.Expr{&dst.BasicLit{ + Kind: token.INT, + Value: "1", + }}, + Decs: dst.AssignStmtDecorations{}, + } + assert.Equal(t, true, DeleteStmtFromFuncBody(df, "main", stmt)) + buf := printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) + t.Run("multiple", func(t *testing.T) { var src = ` package main