Skip to content

Commit

Permalink
support nested case in func_body.go
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghe3119 committed Aug 21, 2019
1 parent d4f7852 commit db6d3e8
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 21 deletions.
62 changes: 41 additions & 21 deletions func_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,56 +7,76 @@ import (

// HasStmtInsideFuncBody checks if the body of function has given statement
func HasStmtInsideFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (ret bool) {
var inside bool

pre := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncDecl:
if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName {
funcBody := nn.Body
for _, ss := range funcBody.List {
if nodesEqual(ss, stmt) {
ret = true
}
}
return false
inside = true
}
case dst.Stmt:
if inside && nodesEqual(node, stmt) {
ret = true
}
}
return true
}

dstutil.Apply(df, pre, nil)
post := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncDecl:
if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName && inside {
inside = false
}
}

return true
}

dstutil.Apply(df, pre, post)
return
}

// DeleteStmtFromFuncBody deletes any statement, inside the body of function,
// that is semantically equal to the given statement.
func DeleteStmtFromFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) {
var inside bool

pre := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncDecl:
if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName {
funcBody := nn.Body

var newList []dst.Stmt
for _, ss := range funcBody.List {
if !nodesEqual(ss, stmt) {
newList = append(newList, ss)
} else {
modified = true
}
}
inside = true
}
case dst.Stmt:
if inside && nodesEqual(node, stmt) {
c.Delete()
modified = true
}
}
return true
}

funcBody.List = newList
return false
post := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncDecl:
if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName && inside {
inside = false
}
}
return true
}

dstutil.Apply(df, pre, nil)
dstutil.Apply(df, pre, post)
return
}

Expand Down
93 changes: 93 additions & 0 deletions func_body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ func TestHasStmtInsideFuncBody(t *testing.T) {
import "fmt"
func B() {
c := 1
}
func main() {
a := 1
b := 1
Expand Down Expand Up @@ -64,6 +68,17 @@ func TestHasStmtInsideFuncBody(t *testing.T) {
Value: "2",
}},
}, false},
{
&dst.AssignStmt{
Lhs: []dst.Expr{dst.NewIdent("c")},
Tok: token.DEFINE,
Rhs: []dst.Expr{&dst.BasicLit{
Kind: token.INT,
Value: "1",
}},
Decs: dst.AssignStmtDecorations{},
}, false,
},
}

df, _ := ParseSrcFileFromBytes([]byte(src))
Expand All @@ -72,6 +87,36 @@ func TestHasStmtInsideFuncBody(t *testing.T) {
assert.Equal(t, c.expected, HasStmtInsideFuncBody(df, "main", c.stmt))
}
})

t.Run("nested", func(t *testing.T) {
var src = `
package main
import "fmt"
func main() {
a := 1
b := 1
defer func() {
c := 3
fmt.Println(c)
}()
}
`

df, _ := ParseSrcFileFromBytes([]byte(src))
stmt := &dst.AssignStmt{
Lhs: []dst.Expr{dst.NewIdent("c")},
Tok: token.DEFINE,
Rhs: []dst.Expr{&dst.BasicLit{
Kind: token.INT,
Value: "3",
}},
Decs: dst.AssignStmtDecorations{},
}

assert.Equal(t, true, HasStmtInsideFuncBody(df, "main", stmt))
})
}

func TestDeleteStmtFromFuncBody(t *testing.T) {
Expand Down Expand Up @@ -144,6 +189,54 @@ func TestDeleteStmtFromFuncBody(t *testing.T) {
}
})

t.Run("nested", func(t *testing.T) {
var src = `
package main
import (
"fmt"
)
func main() {
a := 1
b := 1
defer func() {
c := 1
fmt.Println("hello world")
}()
}
`
var expected = `
package main
import (
"fmt"
)
func main() {
a := 1
b := 1
defer func() {
fmt.Println("hello world")
}()
}
`

df, _ := ParseSrcFileFromBytes([]byte(src))
stmt := &dst.AssignStmt{
Lhs: []dst.Expr{dst.NewIdent("c")},
Tok: token.DEFINE,
Rhs: []dst.Expr{&dst.BasicLit{
Kind: token.INT,
Value: "1",
}},
Decs: dst.AssignStmtDecorations{},
}
assert.Equal(t, true, DeleteStmtFromFuncBody(df, "main", stmt))
buf := printToBuf(df)
assertCodesEqual(t, expected, buf.String())
})

t.Run("multiple", func(t *testing.T) {
var src = `
package main
Expand Down

0 comments on commit db6d3e8

Please sign in to comment.