diff --git a/README.md b/README.md index 8d1e4f5..96d317d 100644 --- a/README.md +++ b/README.md @@ -41,20 +41,26 @@ FprintFile(out io.Writer, df *dst.File) error HasStmtInsideFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (ret bool) DeleteStmtFromFuncBody(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (modified bool) -AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool) AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBodyEnd(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) AddStmtToFuncBodyBefore(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (modified bool) AddStmtToFuncBodyAfter(df *dst.File, funcName string, stmt, refStmt dst.Stmt) (modified bool) ``` +### function lit utilities + +``` +AddStmtToFuncLitBody(df *dst.File, scope Scope, stmt dst.Stmt, pos int) (modified bool) +AddFieldToFuncLitParams(df *dst.File, scope Scope, field *dst.Field, pos int) (modified bool) +``` + ### function call utilities ``` -HasArgInCallExpr(df *dst.File, funcName string, arg dst.Expr) (ret bool) -DeleteArgFromCallExpr(df *dst.File, funcName string, arg dst.Expr) (modified bool) -AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (modified bool) -SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool) +HasArgInCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr) (ret bool) +DeleteArgFromCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr) (modified bool) +AddArgToCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr, pos int) (modified bool) +SetMethodOnReceiver(df *dst.File, scope Scope, receiver, oldMethod, newMethod string) (modified bool) ``` ### function declaration utilities diff --git a/examples/rpc_adapter.go b/examples/rpc_adapter.go index da89fa6..f19cae5 100644 --- a/examples/rpc_adapter.go +++ b/examples/rpc_adapter.go @@ -3,7 +3,6 @@ package main import ( "github.com/ZhengHe-MD/gorefactor" "github.com/dave/dst" - "go/token" "log" "os" ) @@ -12,9 +11,9 @@ func main() { var src = ` package main - func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { - return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { - ct, ok := c.(*AccountServiceClient) + func rpc(haskkey string, timeout time.Duration, fn func(*AppServiceClient) error) error { + return clientThrift.Rpc(haskkey, timeout, func(c interface{}) error { + ct, ok := c.(*AppServiceClient) if ok { return fn(ct) } else { @@ -23,17 +22,18 @@ func main() { }) } - func RegisterGuest(ctx context.Context, req *RegisterGuestReq) (r *RegisterGuestRes) { + func VersionCheck(ctx context.Context, req *VersionCheckReq) (r *VersionCheckRes) { tctx := trace.NewThriftUtilContextFromContext(ctx) - err := rpc(ctx, strconv.Itoa(rand.Int()), time.Millisecond*1000, - func(c *AccountServiceClient) (er error) { - r, er = c.RegisterGuest(req, tctx) + err := rpc(strconv.FormatInt(req.Uid, 10), time.Millisecond*3000, + func(c *AppServiceClient) (er error) { + r, er = c.VersionCheck(req, tctx) return er }, ) + if err != nil { - r = &RegisterGuestRes{ - Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:RegisterGuest err:%s", service_key, proc_thrift, err)), + r = &VersionCheckRes{ + Errinfo: tiperr.NewInternalErr(-1001, fmt.Sprintf("call serice:%s proc:%s m:VersionCheck err:%s", service_key, proc_thrift, err)), } } return @@ -77,67 +77,73 @@ func main() { return } - gorefactor.SetMethodOnReceiver(df, "clientThrift", "RpcWithContext", "RpcWithContextV2") - gorefactor.AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0) - gorefactor.AddFieldToFuncLitParams(df, &dst.Field{ - Names: []*dst.Ident{dst.NewIdent("fctx")}, - Type: dst.NewIdent("context.Context"), - }, 0) - gorefactor.DeleteStmtFromFuncBody(df, "RegisterGuest", &dst.AssignStmt{ - Lhs: []dst.Expr{ - dst.NewIdent("tctx"), - }, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.SelectorExpr{ - X: dst.NewIdent("trace"), - Sel: dst.NewIdent("NewThriftUtilContextFromContext"), - }, - Args: []dst.Expr{ - dst.NewIdent("ctx"), - }, - }, - }, - }) - gorefactor.AddStmtToFuncLitBody(df, &dst.AssignStmt{ - Lhs: []dst.Expr{ - dst.NewIdent("tctx"), - }, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.SelectorExpr{ - X: dst.NewIdent("trace"), - Sel: dst.NewIdent("NewThriftUtilContextFromContext"), - }, - Args: []dst.Expr{ - dst.NewIdent("tctx"), - }, - }, - }, - }, 0) - gorefactor.DeleteStmtFromFuncBody(df, "rpc", &dst.AssignStmt{ - Lhs: []dst.Expr{ - dst.NewIdent("tctx"), - }, - Tok: token.DEFINE, - Rhs: []dst.Expr{ - &dst.CallExpr{ - Fun: &dst.SelectorExpr{ - X: dst.NewIdent("trace"), - Sel: dst.NewIdent("NewThriftUtilContextFromContext"), - }, - Args: []dst.Expr{ - dst.NewIdent("tctx"), - }, - }, - }, - }) + //gorefactor.AddFieldToFuncDeclParams(df, "rpc", &dst.Field{ + // Names: []*dst.Ident{dst.NewIdent("ctx")}, + // Type: dst.NewIdent("context.Contex"), + //}, 0) + gorefactor.AddFieldToFuncDeclParams(df, "fn", &dst.Field{Type: dst.NewIdent("context.Context")}, 0) + //gorefactor.SetMethodOnReceiver(df, "clientThrift", "Rpc", "RpcWithContextV2") + //gorefactor.AddArgToCallExpr(df, "RpcWithContextV2", dst.NewIdent("ctx"), 0) + //gorefactor.AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0) + //gorefactor.AddFieldToFuncLitParams(df, &dst.Field{ + // Names: []*dst.Ident{dst.NewIdent("fctx")}, + // Type: dst.NewIdent("context.Context"), + //}, 0) + //gorefactor.DeleteStmtFromFuncBody(df, "VersionCheck", &dst.AssignStmt{ + // Lhs: []dst.Expr{ + // dst.NewIdent("tctx"), + // }, + // Tok: token.DEFINE, + // Rhs: []dst.Expr{ + // &dst.CallExpr{ + // Fun: &dst.SelectorExpr{ + // X: dst.NewIdent("trace"), + // Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + // }, + // Args: []dst.Expr{ + // dst.NewIdent("ctx"), + // }, + // }, + // }, + //}) + //gorefactor.AddStmtToFuncLitBody(df, &dst.AssignStmt{ + // Lhs: []dst.Expr{ + // dst.NewIdent("tctx"), + // }, + // Tok: token.DEFINE, + // Rhs: []dst.Expr{ + // &dst.CallExpr{ + // Fun: &dst.SelectorExpr{ + // X: dst.NewIdent("trace"), + // Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + // }, + // Args: []dst.Expr{ + // dst.NewIdent("fctx"), + // }, + // }, + // }, + //}, 0) + //gorefactor.DeleteStmtFromFuncBody(df, "rpc", &dst.AssignStmt{ + // Lhs: []dst.Expr{ + // dst.NewIdent("tctx"), + // }, + // Tok: token.DEFINE, + // Rhs: []dst.Expr{ + // &dst.CallExpr{ + // Fun: &dst.SelectorExpr{ + // X: dst.NewIdent("trace"), + // Sel: dst.NewIdent("NewThriftUtilContextFromContext"), + // }, + // Args: []dst.Expr{ + // dst.NewIdent("tctx"), + // }, + // }, + // }, + //}) err = gorefactor.FprintFile(os.Stdout, df) if err != nil { log.Println(err) return } -} \ No newline at end of file +} diff --git a/func_body.go b/func_body.go index 1e4c854..d839c5f 100644 --- a/func_body.go +++ b/func_body.go @@ -149,28 +149,6 @@ func AddStmtToFuncBody(df *dst.File, funcName string, stmt dst.Stmt, pos int) (m return } -func AddStmtToFuncLitBody(df *dst.File, stmt dst.Stmt, pos int) (modified bool) { - pre := func(c *dstutil.Cursor) bool { - node := c.Node() - - switch node.(type) { - case *dst.FuncLit: - nn := node.(*dst.FuncLit) - stmtList := nn.Body.List - pos = normalizePos(pos, len(stmtList)) - - nn.Body.List = append( - stmtList[:pos], - append([]dst.Stmt{dst.Clone(stmt).(dst.Stmt)}, stmtList[pos:]...)...) - modified = true - } - return true - } - - dstutil.Apply(df, pre, nil) - return -} - // AddStmtToFuncBodyStart adds given statement, to the start of function body func AddStmtToFuncBodyStart(df *dst.File, funcName string, stmt dst.Stmt) (modified bool) { return AddStmtToFuncBody(df, funcName, stmt, 0) diff --git a/func_call.go b/func_call.go index bf8396d..8ae4ea8 100644 --- a/func_call.go +++ b/func_call.go @@ -51,12 +51,16 @@ func HasArgInCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr) // DeleteArgFromCallExpr deletes any arg, in the function call's argument list, // that is semantically equal to the given arg. -func DeleteArgFromCallExpr(df *dst.File, funcName string, arg dst.Expr) (modified bool) { +func DeleteArgFromCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr) (modified bool) { pre := func(c *dstutil.Cursor) bool { node := c.Node() switch node.(type) { case *dst.CallExpr: + if !scope.IsInScope() { + return true + } + var found bool nn := node.(*dst.CallExpr) if ie, ok := nn.Fun.(*dst.Ident); ok && ie.Name == funcName { @@ -79,23 +83,33 @@ func DeleteArgFromCallExpr(df *dst.File, funcName string, arg dst.Expr) (modifie nn.Args = newArgs return false } + default: + scope.TryEnterScope(node) } return true } - dstutil.Apply(df, pre, nil) + post := func(c *dstutil.Cursor) bool { + scope.TryLeaveScope(c.Node()) + return true + } + + dstutil.Apply(df, pre, post) return } // AddArgToCallExpr adds given arg, to the function call's argument list, in the given position -func AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (modified bool) { +func AddArgToCallExpr(df *dst.File, scope Scope, funcName string, arg dst.Expr, pos int) (modified bool) { pre := func(c *dstutil.Cursor) bool { node := c.Node() switch node.(type) { case *dst.CallExpr: - nn := node.(*dst.CallExpr) + if !scope.IsInScope() { + return true + } + nn := node.(*dst.CallExpr) var ce *dst.CallExpr if ie, ok := nn.Fun.(*dst.Ident); ok && ie.Name == funcName { @@ -115,20 +129,30 @@ func AddArgToCallExpr(df *dst.File, funcName string, arg dst.Expr, pos int) (mod modified = true return false } + default: + scope.TryEnterScope(node) } return true } - dstutil.Apply(df, pre, nil) + post := func(c *dstutil.Cursor) bool { + scope.TryLeaveScope(c.Node()) + return true + } + + dstutil.Apply(df, pre, post) return } -func SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (modified bool) { +func SetMethodOnReceiver(df *dst.File, scope Scope, receiver, oldMethod, newMethod string) (modified bool) { pre := func(c *dstutil.Cursor) bool { node := c.Node() switch node.(type) { case *dst.CallExpr: + if !scope.IsInScope() { + return true + } nn := node.(*dst.CallExpr) switch nn.Fun.(type) { @@ -154,10 +178,17 @@ func SetMethodOnReceiver(df *dst.File, receiver, oldMethod, newMethod string) (m modified = true } } + default: + scope.TryEnterScope(node) } return true } - dstutil.Apply(df, pre, nil) + post := func(c *dstutil.Cursor) bool { + scope.TryLeaveScope(c.Node()) + return true + } + + dstutil.Apply(df, pre, post) return } diff --git a/func_call_test.go b/func_call_test.go index e4e9c0c..606a305 100644 --- a/func_call_test.go +++ b/func_call_test.go @@ -223,6 +223,48 @@ func TestHasArgInCallExpr(t *testing.T) { } func TestDeleteArgFromCallExpr(t *testing.T) { + t.Run("delete within scope", func(t *testing.T) { + var src = ` + package main + + func A() { + fmt.Println(1) + } + + func B() { + fmt.Println(1) + } + ` + + var expectedTemplate = ` + package main + + func A() { + fmt.Println(%s) + } + + func B() { + fmt.Println(%s) + } + ` + + cases := []struct{ + funcName string + scope Scope + expectedFuncAArg string + expectedFuncBArg string + } { + {"Println", Scope{FuncName: "A"}, "", "1"}, + {"Println", Scope{FuncName: "B"}, "1", ""}, + } + + for _, c := range cases { + df, _ := ParseSrcFileFromBytes([]byte(src)) + assert.True(t, DeleteArgFromCallExpr(df, c.scope, c.funcName, &dst.BasicLit{Kind: token.INT, Value: "1"})) + buf := printToBuf(df) + assertCodesEqual(t, fmt.Sprintf(expectedTemplate, c.expectedFuncAArg, c.expectedFuncBArg), buf.String()) + } + }) t.Run("delete multiple args", func(t *testing.T) { var src = ` package main @@ -276,16 +318,16 @@ func TestDeleteArgFromCallExpr(t *testing.T) { var buf *bytes.Buffer - assert.True(t, DeleteArgFromCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "3"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "3"})) buf = printToBuf(df) assertCodesEqual(t, expected1, buf.String()) - assert.False(t, DeleteArgFromCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "3"})) + assert.False(t, DeleteArgFromCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "3"})) buf = printToBuf(df) assertCodesEqual(t, expected1, buf.String()) - assert.True(t, DeleteArgFromCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "2"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "2"})) buf = printToBuf(df) assertCodesEqual(t, expected2, buf.String()) - assert.True(t, DeleteArgFromCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"})) buf = printToBuf(df) assertCodesEqual(t, expected3, buf.String()) }) @@ -318,7 +360,7 @@ func TestDeleteArgFromCallExpr(t *testing.T) { df, _ := ParseSrcFileFromBytes([]byte(src)) var buf *bytes.Buffer - assert.True(t, DeleteArgFromCallExpr(df, "hello", &dst.BasicLit{Kind: token.INT, Value: "1"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "hello", &dst.BasicLit{Kind: token.INT, Value: "1"})) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) }) @@ -355,7 +397,7 @@ func TestDeleteArgFromCallExpr(t *testing.T) { df, _ := ParseSrcFileFromBytes([]byte(src)) var buf *bytes.Buffer - assert.True(t, DeleteArgFromCallExpr(df, "hello", &dst.BasicLit{Kind: token.INT, Value: "1"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "hello", &dst.BasicLit{Kind: token.INT, Value: "1"})) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) }) @@ -398,13 +440,55 @@ func TestDeleteArgFromCallExpr(t *testing.T) { df, _ := ParseSrcFileFromBytes([]byte(src)) var buf *bytes.Buffer - assert.True(t, DeleteArgFromCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"})) + assert.True(t, DeleteArgFromCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"})) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) }) } func TestAddArgToCallExpr(t *testing.T) { + t.Run("add within scope", func(t *testing.T) { + var src = ` + package main + + func A() { + fmt.Println() + } + + func B() { + fmt.Println() + } + ` + + var expectedTemplate = ` + package main + + func A() { + fmt.Println(%s) + } + + func B() { + fmt.Println(%s) + } + ` + + cases := []struct{ + scope Scope + expectedFuncAArg string + expectedFuncBArg string + } { + {Scope{FuncName: "A"}, "1", ""}, + {Scope{FuncName: "B"}, "", "1"}, + } + + for _, c := range cases { + df, _ := ParseSrcFileFromBytes([]byte(src)) + assert.True(t, AddArgToCallExpr(df, c.scope, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"}, 0)) + buf := printToBuf(df) + assertCodesEqual(t, fmt.Sprintf(expectedTemplate, c.expectedFuncAArg, c.expectedFuncBArg), buf.String()) + } + }) + t.Run("add param to empty call", func(t *testing.T) { var src = ` package main @@ -444,7 +528,7 @@ func TestAddArgToCallExpr(t *testing.T) { df, _ := ParseSrcFileFromBytes([]byte(src)) var buf *bytes.Buffer - assert.True(t, AddArgToCallExpr(df, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"}, c.pos)) + assert.True(t, AddArgToCallExpr(df, EmptyScope, "Println", &dst.BasicLit{Kind: token.INT, Value: "1"}, c.pos)) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) } @@ -494,7 +578,7 @@ func TestAddArgToCallExpr(t *testing.T) { expected := fmt.Sprintf(expectedTemplate, c.expectedArgs) df, _ := ParseSrcFileFromBytes([]byte(src)) - assert.True(t, AddArgToCallExpr(df, "Println", c.arg, c.pos)) + assert.True(t, AddArgToCallExpr(df, EmptyScope, "Println", c.arg, c.pos)) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) } @@ -545,7 +629,7 @@ func TestAddArgToCallExpr(t *testing.T) { expected := fmt.Sprintf(expectedTemplate, c.expectedArgs) df, _ := ParseSrcFileFromBytes([]byte(src)) - assert.True(t, AddArgToCallExpr(df, "Println", c.arg, c.pos)) + assert.True(t, AddArgToCallExpr(df, EmptyScope, "Println", c.arg, c.pos)) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) } @@ -584,8 +668,77 @@ func TestAddArgToCallExpr(t *testing.T) { df, _ := ParseSrcFileFromBytes([]byte(src)) var buf *bytes.Buffer - assert.True(t, true, AddArgToCallExpr(df, "fn", dst.NewIdent("fctx"), 0)) + assert.True(t, true, AddArgToCallExpr(df, EmptyScope, "fn", dst.NewIdent("fctx"), 0)) buf = printToBuf(df) assertCodesEqual(t, expected, buf.String()) }) } + + +func TestSetMethodOnReceiver(t *testing.T) { + t.Run("case1: fmt", func(t *testing.T) { + var src = ` + package main + + import "fmt" + + func main() { + fmt.Println() + } + ` + + var expected = ` + package main + + import "fmt" + + func main() { + fmt.Panicln() + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + var buf *bytes.Buffer + assert.True(t, SetMethodOnReceiver(df, EmptyScope, "fmt", "Println", "Panicln")) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) + + t.Run("case2", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + var expected = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + ` + + df, _ := ParseSrcFileFromBytes([]byte(src)) + var buf *bytes.Buffer + assert.True(t, SetMethodOnReceiver(df, EmptyScope, "clientThrift", "RpcWithContext", "RpcWithContextV2")) + buf = printToBuf(df) + assertCodesEqual(t, expected, buf.String()) + }) +} \ No newline at end of file diff --git a/func_decl.go b/func_decl.go index 2504eed..c4b4f54 100644 --- a/func_decl.go +++ b/func_decl.go @@ -84,24 +84,4 @@ func AddFieldToFuncDeclParams(df *dst.File, funcName string, field *dst.Field, p return } -func AddFieldToFuncLitParams(df *dst.File, field *dst.Field, pos int) (modified bool) { - pre := func(c *dstutil.Cursor) bool { - node := c.Node() - - switch node.(type) { - case *dst.FuncLit: - nn := node.(*dst.FuncLit) - fieldList := nn.Type.Params.List - pos = normalizePos(pos, len(fieldList)) - nn.Type.Params.List = append( - fieldList[:pos], - append([]*dst.Field{dst.Clone(field).(*dst.Field)}, fieldList[pos:]...)...) - modified = true - } - return true - } - - dstutil.Apply(df, pre, nil) - return -} diff --git a/func_decl_test.go b/func_decl_test.go index 19ee6ce..f731be8 100644 --- a/func_decl_test.go +++ b/func_decl_test.go @@ -363,141 +363,6 @@ func TestAddFieldToFuncDeclParams(t *testing.T) { }) } -func TestSetMethodOnReceiver(t *testing.T) { - t.Run("case1: fmt", func(t *testing.T) { - var src = ` - package main - - import "fmt" - - func main() { - fmt.Println() - } - ` - - var expected = ` - package main - - import "fmt" - - func main() { - fmt.Panicln() - } - ` - - df, _ := ParseSrcFileFromBytes([]byte(src)) - var buf *bytes.Buffer - assert.True(t, SetMethodOnReceiver(df, "fmt", "Println", "Panicln")) - buf = printToBuf(df) - assertCodesEqual(t, expected, buf.String()) - }) - - t.Run("case2", func(t *testing.T) { - var src = ` - package main - - func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { - return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { - ct, ok := c.(*AccountServiceClient) - if ok { - return fn(ct) - } else { - return fmt.Errorf("reflect client thrift error") - } - }) - } - ` - - var expected = ` - package main - - func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { - return clientThrift.RpcWithContextV2(ctx, hashKey, timeout, func(c interface{}) error { - ct, ok := c.(*AccountServiceClient) - if ok { - return fn(ct) - } else { - return fmt.Errorf("reflect client thrift error") - } - }) - } - ` - - df, _ := ParseSrcFileFromBytes([]byte(src)) - var buf *bytes.Buffer - assert.True(t, SetMethodOnReceiver(df, "clientThrift", "RpcWithContext", "RpcWithContextV2")) - buf = printToBuf(df) - assertCodesEqual(t, expected, buf.String()) - }) -} -func TestAddFieldToFuncLitParams(t *testing.T) { - t.Run("case 1", func(t *testing.T) { - var src = ` - package main - func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { - return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { - ct, ok := c.(*AccountServiceClient) - if ok { - return fn(ct) - } else { - return fmt.Errorf("reflect client thrift error") - } - }) - } - - func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { - tctx := trace.NewThriftUtilContextFromContext(ctx) - if req == nil { - return nil, fmt.Errorf("req nil") - } - err = rpc(ctx, req.Item, time.Millisecond*3000, - func(c *AccountServiceClient) error { - r, err = c.UpdateUserStrInfo(req, tctx) - return err - }, - ) - return - } - ` - var expected = ` - package main - - func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { - return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(fctx context.Context, c interface{}) error { - ct, ok := c.(*AccountServiceClient) - if ok { - return fn(ct) - } else { - return fmt.Errorf("reflect client thrift error") - } - }) - } - - func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { - tctx := trace.NewThriftUtilContextFromContext(ctx) - if req == nil { - return nil, fmt.Errorf("req nil") - } - err = rpc(ctx, req.Item, time.Millisecond*3000, - func(fctx context.Context, c *AccountServiceClient) error { - r, err = c.UpdateUserStrInfo(req, tctx) - return err - }, - ) - return - } - ` - - var buf *bytes.Buffer - df, _ := ParseSrcFileFromBytes([]byte(src)) - assert.True(t, AddFieldToFuncLitParams(df, &dst.Field{ - Names: []*dst.Ident{dst.NewIdent("fctx")}, - Type: dst.NewIdent("context.Context"), - }, 0)) - buf = printToBuf(df) - assertCodesEqual(t, expected, buf.String()) - }) -} diff --git a/func_lit.go b/func_lit.go new file mode 100644 index 0000000..99d5551 --- /dev/null +++ b/func_lit.go @@ -0,0 +1,72 @@ +package gorefactor + +import ( + "github.com/dave/dst" + "github.com/dave/dst/dstutil" +) + +// AddStmtToFuncLitBody add statement to anonymous function body +func AddStmtToFuncLitBody(df *dst.File, scope Scope, stmt dst.Stmt, pos int) (modified bool) { + pre := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncLit: + if !scope.IsInScope() { + return true + } + nn := node.(*dst.FuncLit) + stmtList := nn.Body.List + pos = normalizePos(pos, len(stmtList)) + + nn.Body.List = append( + stmtList[:pos], + append([]dst.Stmt{dst.Clone(stmt).(dst.Stmt)}, stmtList[pos:]...)...) + modified = true + default: + scope.TryEnterScope(node) + } + return true + } + + post := func(c *dstutil.Cursor) bool { + scope.TryLeaveScope(c.Node()) + return true + } + + dstutil.Apply(df, pre, post) + return +} + +// AddFieldToFuncLitParams add statement to anonymous function params +func AddFieldToFuncLitParams(df *dst.File, scope Scope, field *dst.Field, pos int) (modified bool) { + pre := func(c *dstutil.Cursor) bool { + node := c.Node() + + switch node.(type) { + case *dst.FuncLit: + if !scope.IsInScope() { + return true + } + nn := node.(*dst.FuncLit) + + fieldList := nn.Type.Params.List + pos = normalizePos(pos, len(fieldList)) + nn.Type.Params.List = append( + fieldList[:pos], + append([]*dst.Field{dst.Clone(field).(*dst.Field)}, fieldList[pos:]...)...) + modified = true + default: + scope.TryEnterScope(node) + } + return true + } + + post := func(c *dstutil.Cursor) bool { + scope.TryLeaveScope(c.Node()) + return true + } + + dstutil.Apply(df, pre, post) + return +} \ No newline at end of file diff --git a/func_lit_test.go b/func_lit_test.go new file mode 100644 index 0000000..fdb431a --- /dev/null +++ b/func_lit_test.go @@ -0,0 +1,89 @@ +package gorefactor + +import ( + "github.com/dave/dst" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAddFieldToFuncLitParams(t *testing.T) { + t.Run("case 1", func(t *testing.T) { + var src = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + + func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { + tctx := trace.NewThriftUtilContextFromContext(ctx) + if req == nil { + return nil, fmt.Errorf("req nil") + } + err = rpc(ctx, req.Item, time.Millisecond*3000, + func(c *AccountServiceClient) error { + r, err = c.UpdateUserStrInfo(req, tctx) + return err + }, + ) + return + } + ` + + var expectedWithoutScope = ` + package main + + func rpc(ctx context.Context, hashKey string, timeout time.Duration, fn func(*AccountServiceClient) error) error { + return clientThrift.RpcWithContext(ctx, hashKey, timeout, func(fctx context.Context, c interface{}) error { + ct, ok := c.(*AccountServiceClient) + if ok { + return fn(ct) + } else { + return fmt.Errorf("reflect client thrift error") + } + }) + } + + func UpdateUserStrInfo(ctx context.Context, req *UserStrInfoRequest) (r *UpdateUserInfoRes, err error) { + tctx := trace.NewThriftUtilContextFromContext(ctx) + if req == nil { + return nil, fmt.Errorf("req nil") + } + err = rpc(ctx, req.Item, time.Millisecond*3000, + func(fctx context.Context, c *AccountServiceClient) error { + r, err = c.UpdateUserStrInfo(req, tctx) + return err + }, + ) + return + } + ` + + { + df, _ := ParseSrcFileFromBytes([]byte(src)) + assert.True(t, AddFieldToFuncLitParams(df, EmptyScope, &dst.Field{ + Names: []*dst.Ident{dst.NewIdent("fctx")}, + Type: dst.NewIdent("context.Context"), + }, 0)) + buf := printToBuf(df) + assertCodesEqual(t, expectedWithoutScope, buf.String()) + } + + { + df, _ := ParseSrcFileFromBytes([]byte(src)) + assert.False(t, AddFieldToFuncLitParams(df, Scope{FuncName: "NonExist"}, &dst.Field{ + Names: []*dst.Ident{dst.NewIdent("fctx")}, + Type: dst.NewIdent("context.Context"), + }, 0)) + buf := printToBuf(df) + assertCodesEqual(t, src, buf.String()) + } + }) +} \ No newline at end of file