Skip to content

Commit

Permalink
support AddStmtToFuncLitBody and SetMethodOnReceiver
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghe3119 committed Nov 5, 2019
1 parent db6d3e8 commit 3f5164e
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ FprintFile(out io.Writer, df *dst.File) error
HasStmtInsideFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (ret bool)
DeleteStmtFromFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (modified bool)
AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (modified bool)
AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool)
AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool)
AddStmtToFuncBodyEnd(df *dst.File, funcName string, stmt dst.Stmt) (modified bool)
AddStmtToFuncBodyBefore(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (modified bool)
Expand All @@ -53,6 +54,7 @@ AddStmtToFuncBodyAfter(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (m
HasArgInCallExpr(df *dst.File, funcName string, arg dst.Expr) (ret bool)
DeleteArgFromCallExpr(df *dst.File, funcName string, arg dst.Expr) (modified bool)
AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (modified bool)
SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool)
```

### function declaration utilities
Expand Down
143 changes: 143 additions & 0 deletions examples/rpc_adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package main

import (
"github.com/ZhengHe-MD/gorefactor"
"github.com/dave/dst"
"go/token"
"log"
"os"
)

func main() {
var src = `
package main
func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error {
return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error {
ct, ok := c.(*AccountServiceClient)
if ok {
return fn(ct)
} else {
return fmt.Errorf("reflect client thrift error")
}
})
}
func RegisterGuest(ctx context.Context, req *RegisterGuestReq) (r *RegisterGuestRes) {
tctx := trace.NewThriftUtilContextFromContext(ctx)
err := rpc(ctx, strconv.Itoa(rand.Int()), time.Millisecond*1000,
func(c *AccountServiceClient) (er error) {
r, er = c.RegisterGuest(req, tctx)
return er
},
)
if err != nil {
r = &RegisterGuestRes{
Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:RegisterGuest err:%s", service_key, proc_thrift, err)),
}
}
return
}
`

//var expected = `
//package main
//
//func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error {
// return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(fctx context.Context, c interface{}) error {
// ct, ok := c.(*AccountServiceClient)
// if ok {
// return fn(fctx, ct)
// } else {
// return fmt.Errorf("reflect client thrift error")
// }
// })
//}
//
//func RegisterGuest(ctx context.Context, req *RegisterGuestReq) (r *RegisterGuestRes) {
// err := rpc(ctx, strconv.Itoa(rand.Int()), time.Millisecond*1000,
// func(fctx context.Context, c *AccountServiceClient) (er error) {
// tctx := trace.NewThriftUtilContextFromContext(fctx)
// r, er = c.RegisterGuest(req, tctx)
// return er
// },
// )
// if err != nil {
// r = &RegisterGuestRes{
// Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:RegisterGuest err:%s", service_key, proc_thrift, err)),
// }
// }
// return
//}
//`

df, err := gorefactor.ParseSrcFileFromBytes([]byte(src))
if err != nil {
log.Println(err)
return
}

gorefactor.SetMethodOnReceiver(df, "clientThrift", "RpcWithContext", "RpcWithContextV2")
gorefactor.AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0)
gorefactor.AddFieldToFuncLitParams(df, &dst.Field{
Names: []*dst.Ident{dst.NewIdent("fctx")},
Type: dst.NewIdent("context.Context"),
}, 0)
gorefactor.DeleteStmtFromFuncBody(df, "RegisterGuest", &dst.AssignStmt{
Lhs: []dst.Expr{
dst.NewIdent("tctx"),
},
Tok: token.DEFINE,
Rhs: []dst.Expr{
&dst.CallExpr{
Fun: &dst.SelectorExpr{
X: dst.NewIdent("trace"),
Sel: dst.NewIdent("NewThriftUtilContextFromContext"),
},
Args: []dst.Expr{
dst.NewIdent("ctx"),
},
},
},
})
gorefactor.AddStmtToFuncLitBody(df, &dst.AssignStmt{
Lhs: []dst.Expr{
dst.NewIdent("tctx"),
},
Tok: token.DEFINE,
Rhs: []dst.Expr{
&dst.CallExpr{
Fun: &dst.SelectorExpr{
X: dst.NewIdent("trace"),
Sel: dst.NewIdent("NewThriftUtilContextFromContext"),
},
Args: []dst.Expr{
dst.NewIdent("tctx"),
},
},
},
}, 0)
gorefactor.DeleteStmtFromFuncBody(df, "rpc", &dst.AssignStmt{
Lhs: []dst.Expr{
dst.NewIdent("tctx"),
},
Tok: token.DEFINE,
Rhs: []dst.Expr{
&dst.CallExpr{
Fun: &dst.SelectorExpr{
X: dst.NewIdent("trace"),
Sel: dst.NewIdent("NewThriftUtilContextFromContext"),
},
Args: []dst.Expr{
dst.NewIdent("tctx"),
},
},
},
})

err = gorefactor.FprintFile(os.Stdout, df)
if err != nil {
log.Println(err)
return
}
}
22 changes: 22 additions & 0 deletions func_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,28 @@ func AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (m
return
}

func AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool) {
pre := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncLit:
nn := node.(*dst.FuncLit)
stmtList := nn.Body.List
pos = normalizePos(pos, len(stmtList))

nn.Body.List = append(
stmtList[:pos],
append([]dst.Stmt{dst.Clone(stmt).(dst.Stmt)}, stmtList[pos:]...)...)
modified = true
}
return true
}

dstutil.Apply(df, pre, nil)
return
}

// AddStmtToFuncBodyStart adds given statement, to the start of function body
func AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) {
return AddStmtToFuncBody(df, funcName, stmt, 0)
Expand Down
39 changes: 39 additions & 0 deletions func_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,42 @@ func AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (mod
dstutil.Apply(df, pre, nil)
return
}

func SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool) {
pre := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.CallExpr:
nn := node.(*dst.CallExpr)

switch nn.Fun.(type) {
case *dst.Ident:
si := nn.Fun.(*dst.Ident)
if si.Path == receiver && si.Name == oldMethod {
si.Name = newMethod
modified = true
}
case *dst.SelectorExpr:
se := nn.Fun.(*dst.SelectorExpr)

// TODO: deal with other se.X type
switch se.X.(type) {
case *dst.Ident:
xi := se.X.(*dst.Ident)
if xi.Name != receiver {
return true
}
}
if se.Sel.Name == oldMethod {
se.Sel.Name = newMethod
modified = true
}
}
}
return true
}

dstutil.Apply(df, pre, nil)
return
}
38 changes: 38 additions & 0 deletions func_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,4 +505,42 @@ func TestAddArgToCallExpr(t *testing.T) {
assertCodesEqual(t, expected, buf.String())
}
})

t.Run("complex case 1", func(t *testing.T) {
var src = `
package main
func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error {
return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error {
ct, ok := c.(*AccountServiceClient)
if ok {
return fn(ct)
} else {
return fmt.Errorf("reflect client thrift error")
}
})
}
`

var expected = `
package main
func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error {
return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error {
ct, ok := c.(*AccountServiceClient)
if ok {
return fn(fctx, ct)
} else {
return fmt.Errorf("reflect client thrift error")
}
})
}
`

df, _ := ParseSrcFileFromBytes([]byte(src))
var buf *bytes.Buffer
assert.True(t, true, AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0))
buf = printToBuf(df)
assertCodesEqual(t, expected, buf.String())
})
}
25 changes: 24 additions & 1 deletion func_decl.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func AddFieldToFuncDeclParams(df *dst.File, funcName string, field *dst.Field, p

switch node.(type) {
case *dst.FuncDecl:
if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName {
nn := node.(*dst.FuncDecl)
if nn.Name.Name == funcName {
funcType := nn.Type
fieldList := funcType.Params.List
pos = normalizePos(pos, len(fieldList))
Expand All @@ -82,3 +83,25 @@ func AddFieldToFuncDeclParams(df *dst.File, funcName string, field *dst.Field, p
dstutil.Apply(df, pre, nil)
return
}

func AddFieldToFuncLitParams(df *dst.File, field *dst.Field, pos int) (modified bool) {
pre := func(c *dstutil.Cursor) bool {
node := c.Node()

switch node.(type) {
case *dst.FuncLit:
nn := node.(*dst.FuncLit)

fieldList := nn.Type.Params.List
pos = normalizePos(pos, len(fieldList))
nn.Type.Params.List = append(
fieldList[:pos],
append([]*dst.Field{dst.Clone(field).(*dst.Field)}, fieldList[pos:]...)...)
modified = true
}
return true
}

dstutil.Apply(df, pre, nil)
return
}
Loading

0 comments on commit 3f5164e

Please sign in to comment.