diff --git a/README.md b/README.md index 0cb54ff..12d5ced 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ FprintFile(out io.Writer, df *dst.File) error HasStmtInsideFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (ret bool) DeleteStmtFromFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (modified bool) +AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool) AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBodyEnd(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBodyBefore(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (modified bool) @@ -53,6 +54,7 @@ AddStmtToFuncBodyAfter(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (m HasArgInCallExpr(df *dst.File, funcName string, arg dst.Expr) (ret bool) DeleteArgFromCallExpr(df *dst.File, funcName string, arg dst.Expr) (modified bool) AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (modified bool) +SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool) ``` ### function declaration utilities diff --git a/examples/rpc_adapter.go b/examples/rpc_adapter.go new file mode 100644 index 0000000..da89fa6 --- /dev/null +++ b/examples/rpc_adapter.go @@ -0,0 +1,143 @@ +package main + +import ( + "github.com/ZhengHe-MD/gorefactor" + "github.com/dave/dst" + "go/token" + "log" + "os" +) + +func main() { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + + func RegisterGuest(ctx context.Context, req *RegisterGuestReq) (r *RegisterGuestRes) { + tctx := trace.NewThriftUtilContextFromContext(ctx) + err := rpc(ctx, strconv.Itoa(rand.Int()), time.Millisecond*1000, + func(c *AccountServiceClient) (er error) { + r, er = c.RegisterGuest(req, tctx) + return er + }, + ) + if err != nil { + r = &RegisterGuestRes{ + Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:RegisterGuest err:%s", service_key, proc_thrift, err)), + } + } + return + } + ` + + //var expected = ` + //package main + // + //func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + // return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(fctx context.Context, c interface{}) error { + // ct, ok := c.(*AccountServiceClient) + // if ok { + // return fn(fctx, ct) + // } else { + // return fmt.Errorf("reflect client thrift error") + // } + // }) + //} + // + //func RegisterGuest(ctx context.Context, req *RegisterGuestReq) (r *RegisterGuestRes) { + // err := rpc(ctx, strconv.Itoa(rand.Int()), time.Millisecond*1000, + // func(fctx context.Context, c *AccountServiceClient) (er error) { + // tctx := trace.NewThriftUtilContextFromContext(fctx) + // r, er = c.RegisterGuest(req, tctx) + // return er + // }, + // ) + // if err != nil { + // r = &RegisterGuestRes{ + // Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:RegisterGuest err:%s", service_key, proc_thrift, err)), + // } + // } + // return + //} + //` + + df, err := gorefactor.ParseSrcFileFromBytes([]byte(src)) + if err != nil { + log.Println(err) + return + } + + gorefactor.SetMethodOnReceiver(df, "clientThrift", "RpcWithContext", "RpcWithContextV2") + gorefactor.AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0) + gorefactor.AddFieldToFuncLitParams(df, &dst.Field{ + Names: []*dst.Ident{dst.NewIdent("fctx")}, + Type: dst.NewIdent("context.Context"), + }, 0) + gorefactor.DeleteStmtFromFuncBody(df, "RegisterGuest", &dst.AssignStmt{ + Lhs: []dst.Expr{ + dst.NewIdent("tctx"), + }, + Tok: token.DEFINE, + Rhs: []dst.Expr{ + &dst.CallExpr{ + Fun: &dst.SelectorExpr{ + X: dst.NewIdent("trace"), + Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + }, + Args: []dst.Expr{ + dst.NewIdent("ctx"), + }, + }, + }, + }) + gorefactor.AddStmtToFuncLitBody(df, &dst.AssignStmt{ + Lhs: []dst.Expr{ + dst.NewIdent("tctx"), + }, + Tok: token.DEFINE, + Rhs: []dst.Expr{ + &dst.CallExpr{ + Fun: &dst.SelectorExpr{ + X: dst.NewIdent("trace"), + Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + }, + Args: []dst.Expr{ + dst.NewIdent("tctx"), + }, + }, + }, + }, 0) + gorefactor.DeleteStmtFromFuncBody(df, "rpc", &dst.AssignStmt{ + Lhs: []dst.Expr{ + dst.NewIdent("tctx"), + }, + Tok: token.DEFINE, + Rhs: []dst.Expr{ + &dst.CallExpr{ + Fun: &dst.SelectorExpr{ + X: dst.NewIdent("trace"), + Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + }, + Args: []dst.Expr{ + dst.NewIdent("tctx"), + }, + }, + }, + }) + + err = gorefactor.FprintFile(os.Stdout, df) + if err != nil { + log.Println(err) + return + } +} \ No newline at end of file diff --git a/func_body.go b/func_body.go index d839c5f..1e4c854 100644 --- a/func_body.go +++ b/func_body.go @@ -149,6 +149,28 @@ func AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (m return } +func AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool) { + pre := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncLit: + nn := node.(*dst.FuncLit) + stmtList := nn.Body.List + pos = normalizePos(pos, len(stmtList)) + + nn.Body.List = append( + stmtList[:pos], + append([]dst.Stmt{dst.Clone(stmt).(dst.Stmt)}, stmtList[pos:]...)...) + modified = true + } + return true + } + + dstutil.Apply(df, pre, nil) + return +} + // AddStmtToFuncBodyStart adds given statement, to the start of function body func AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) { return AddStmtToFuncBody(df, funcName, stmt, 0) diff --git a/func_call.go b/func_call.go index f1018f6..cd6c9c7 100644 --- a/func_call.go +++ b/func_call.go @@ -111,3 +111,42 @@ func AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (mod dstutil.Apply(df, pre, nil) return } + +func SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool) { + pre := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.CallExpr: + nn := node.(*dst.CallExpr) + + switch nn.Fun.(type) { + case *dst.Ident: + si := nn.Fun.(*dst.Ident) + if si.Path == receiver && si.Name == oldMethod { + si.Name = newMethod + modified = true + } + case *dst.SelectorExpr: + se := nn.Fun.(*dst.SelectorExpr) + + // TODO: deal with other se.X type + switch se.X.(type) { + case *dst.Ident: + xi := se.X.(*dst.Ident) + if xi.Name != receiver { + return true + } + } + if se.Sel.Name == oldMethod { + se.Sel.Name = newMethod + modified = true + } + } + } + return true + } + + dstutil.Apply(df, pre, nil) + return +} \ No newline at end of file diff --git a/func_call_test.go b/func_call_test.go index 48d982a..319a56a 100644 --- a/func_call_test.go +++ b/func_call_test.go @@ -505,4 +505,42 @@ func TestAddArgToCallExpr(t *testing.T) { assertCodesEqual(t, expected, buf.String()) } }) + + t.Run("complex case 1", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + var expected = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(fctx, ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + var buf *bytes.Buffer + assert.True(t, true, AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0)) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) } diff --git a/func_decl.go b/func_decl.go index 0d802bd..2504eed 100644 --- a/func_decl.go +++ b/func_decl.go @@ -65,7 +65,8 @@ func AddFieldToFuncDeclParams(df *dst.File, funcName string, field *dst.Field, p switch node.(type) { case *dst.FuncDecl: - if nn := node.(*dst.FuncDecl); nn.Name.Name == funcName { + nn := node.(*dst.FuncDecl) + if nn.Name.Name == funcName { funcType := nn.Type fieldList := funcType.Params.List pos = normalizePos(pos, len(fieldList)) @@ -82,3 +83,25 @@ func AddFieldToFuncDeclParams(df *dst.File, funcName string, field *dst.Field, p dstutil.Apply(df, pre, nil) return } + +func AddFieldToFuncLitParams(df *dst.File, field *dst.Field, pos int) (modified bool) { + pre := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncLit: + nn := node.(*dst.FuncLit) + + fieldList := nn.Type.Params.List + pos = normalizePos(pos, len(fieldList)) + nn.Type.Params.List = append( + fieldList[:pos], + append([]*dst.Field{dst.Clone(field).(*dst.Field)}, fieldList[pos:]...)...) + modified = true + } + return true + } + + dstutil.Apply(df, pre, nil) + return +} diff --git a/func_decl_test.go b/func_decl_test.go index 0d4f7f9..19ee6ce 100644 --- a/func_decl_test.go +++ b/func_decl_test.go @@ -225,6 +225,65 @@ func TestDeleteFieldFromFuncDeclParams(t *testing.T) { assertCodesEqual(t, fmt.Sprintf(expectedTemplate, c.expectedArgs), buf.String()) } }) + + t.Run("decl inside decl2", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + var expected = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + fn := &dst.Field{ + Names: []*dst.Ident{dst.NewIdent("fn")}, + Type: &dst.FuncType{ + Params: &dst.FieldList{ + List: []*dst.Field{ + { + Type: &dst.StarExpr{ + X: dst.NewIdent("AccountServiceClient"), + }, + }, + }, + }, + Results: &dst.FieldList{ + List: []*dst.Field{ + { + Type: dst.NewIdent("error"), + }, + }, + }, + }, + } + modified := DeleteFieldFromFuncDeclParams(df, "rpc", fn) + assert.Equal(t, true, modified) + var buf *bytes.Buffer + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) } func TestAddFieldToFuncDeclParams(t *testing.T) { @@ -303,3 +362,142 @@ func TestAddFieldToFuncDeclParams(t *testing.T) { } }) } + +func TestSetMethodOnReceiver(t *testing.T) { + t.Run("case1: fmt", func(t *testing.T) { + var src = ` + package main + + import "fmt" + + func main() { + fmt.Println() + } + ` + + var expected = ` + package main + + import "fmt" + + func main() { + fmt.Panicln() + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + var buf *bytes.Buffer + assert.True(t, SetMethodOnReceiver(df, "fmt", "Println", "Panicln")) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) + + t.Run("case2", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + var expected = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + var buf *bytes.Buffer + assert.True(t, SetMethodOnReceiver(df, "clientThrift", "RpcWithContext", "RpcWithContextV2")) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) +} + +func TestAddFieldToFuncLitParams(t *testing.T) { + t.Run("case 1", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + + func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { + tctx := trace.NewThriftUtilContextFromContext(ctx) + if req == nil { + return nil, fmt.Errorf("req nil") + } + err = rpc(ctx, req.Item, time.Millisecond*3000, + func(c *AccountServiceClient) error { + r, err = c.UpdateUserStrInfo(req, tctx) + return err + }, + ) + return + } + ` + + var expected = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(fctx context.Context, c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + + func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { + tctx := trace.NewThriftUtilContextFromContext(ctx) + if req == nil { + return nil, fmt.Errorf("req nil") + } + err = rpc(ctx, req.Item, time.Millisecond*3000, + func(fctx context.Context, c *AccountServiceClient) error { + r, err = c.UpdateUserStrInfo(req, tctx) + return err + }, + ) + return + } + ` + + var buf *bytes.Buffer + df, _ := ParseSrcFileFromBytes([]byte(src)) + assert.True(t, AddFieldToFuncLitParams(df, &dst.Field{ + Names: []*dst.Ident{dst.NewIdent("fctx")}, + Type: dst.NewIdent("context.Context"), + }, 0)) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) +} diff --git a/util.go b/util.go index 330fdff..aba6071 100644 --- a/util.go +++ b/util.go @@ -66,7 +66,19 @@ func nodesEqual(a, b dst.Node) (ret bool) { return false } + if len(na.Names) == 0 { + return ok && nodesEqual(na.Type, nb.Type) + } return ok && nodesEqual(na.Names[0], nb.Names[0]) && nodesEqual(na.Type, nb.Type) + + case *dst.FuncType: + na := a.(*dst.FuncType) + nb, ok := b.(*dst.FuncType) + + if (len(na.Params.List) != len(nb.Params.List)) || (len(na.Results.List) != len(nb.Results.List)) { + return false + } + return ok && fieldListsEqual(na.Params, nb.Params) && fieldListsEqual(na.Results, nb.Results) case *dst.StarExpr: na := a.(*dst.StarExpr) nb, ok := b.(*dst.StarExpr)