diff --git a/sqlcon/connector.go b/sqlcon/connector.go index 92d2a1be..d48ebf13 100644 --- a/sqlcon/connector.go +++ b/sqlcon/connector.go @@ -23,7 +23,6 @@ package sqlcon import ( "database/sql" "fmt" - "github.com/satori/go.uuid" "github.com/go-sql-driver/mysql" "github.com/lib/pq" @@ -212,8 +211,8 @@ func (c *SQLConnection) registerDriver() (string, error) { driverName := c.URL.Scheme if c.UseTracedDriver { driverName = "instrumented-sql-driver" - if c.useRandomDriverName { - driverName = uuid.NewV4().String() + if len(c.options.forcedDriverName) > 0 { + driverName = c.options.forcedDriverName } tracingOpts := []instrumentedsql.Opt{instrumentedsql.WithTracer(opentracing.NewTracer(c.AllowRoot))} diff --git a/sqlcon/connector_test.go b/sqlcon/connector_test.go index 54af4676..136f1f0b 100644 --- a/sqlcon/connector_test.go +++ b/sqlcon/connector_test.go @@ -21,6 +21,7 @@ package sqlcon import ( + "context" "flag" "fmt" "log" @@ -30,6 +31,9 @@ import ( "testing" "time" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" + _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -69,33 +73,92 @@ func merge(u *url.URL, params map[string]string) *url.URL { return b } +func TestDistributedTracing(t *testing.T) { + for _, testCase := range []struct { + description string + sqlConnection *SQLConnection + }{ + { + description: "mysql: when tracing has been configured - spans should be created", + sqlConnection: mustSQL(t, mysqlUrl.String(), + WithDistributedTracing(), + withRandomDriverName(), // this test option is set to ensure a unique driver name is registered + WithAllowRoot()), + }, + { + description: "postgres: when tracing has been configured - spans should be created", + sqlConnection: mustSQL(t, postgresUrl.String(), + WithDistributedTracing(), + withRandomDriverName(), // this test option is set to ensure a unique driver name is registered + WithAllowRoot()), + }, + { + description: "mysql: no spans should be created if parent span is missing from context when `WithAllowRoot` has NOT been set", + sqlConnection: mustSQL(t, mysqlUrl.String(), + WithDistributedTracing(), // Notice that the WithAllowRoot() option has NOT been set + withRandomDriverName()), // this test option is set to ensure a unique driver name is registered + }, + { + description: "postgres: no spans should be created if parent span is missing from context when `WithAllowRoot` has NOT been set", + sqlConnection: mustSQL(t, postgresUrl.String(), + WithDistributedTracing(), // Notice that the WithAllowRoot() option has NOT been set + withRandomDriverName()), // this test option is set to ensure a unique driver name is registered + }, + { + description: "mysql: when tracing has NOT been configured - NO spans should be created", + sqlConnection: mustSQL(t, mysqlUrl.String()), // Notice that the WithDistributedTracing() option has NOT been set + }, + { + description: "postgres: when tracing has NOT been configured - NO spans should be created", + sqlConnection: mustSQL(t, postgresUrl.String()), // Notice that the WithDistributedTracing() option has NOT been set + }, + } { + t.Run(fmt.Sprintf("case=%s", testCase.description), func(t *testing.T) { + mockedTracer := mocktracer.New() + defer mockedTracer.Reset() + opentracing.SetGlobalTracer(mockedTracer) + + db := testCase.sqlConnection.GetDatabase() + // Notice how no parent span exists in the provided context! + db.QueryRowContext(context.TODO(), "SELECT NOW()") + + spans := mockedTracer.FinishedSpans() + if testCase.sqlConnection.UseTracedDriver && testCase.sqlConnection.AllowRoot { + assert.NotEmpty(t, spans) + } else { + assert.Empty(t, spans) + } + }) + } +} + func TestRegisterDriver(t *testing.T) { unsupportedDSN := "unsupported://unsupported:secret@localhost:1337/mydb" supportedDSN := "mysql://foo@bar:baz@qux/db" for _, testCase := range []struct { - description string - sqlConnection *SQLConnection + description string + sqlConnection *SQLConnection expectedDriverName string - shouldError bool - } { + shouldError bool + }{ { - description: "should return error if supplied DSN is unsupported for tracing", - sqlConnection: mustSQL(t, unsupportedDSN, WithDistributedTracing()), + description: "should return error if supplied DSN is unsupported for tracing", + sqlConnection: mustSQL(t, unsupportedDSN, WithDistributedTracing()), expectedDriverName: "", - shouldError: true, + shouldError: true, }, { - description: "should return registered driver name if supplied DSN is valid for tracing", - sqlConnection: mustSQL(t, supportedDSN, WithDistributedTracing()), + description: "should return registered driver name if supplied DSN is valid for tracing", + sqlConnection: mustSQL(t, supportedDSN, WithDistributedTracing()), expectedDriverName: "instrumented-sql-driver", - shouldError: false, + shouldError: false, }, { - description: "should return registered driver name if tracing is NOT configured", - sqlConnection: mustSQL(t, supportedDSN), + description: "should return registered driver name if tracing is NOT configured", + sqlConnection: mustSQL(t, supportedDSN), expectedDriverName: "mysql", - shouldError: false, + shouldError: false, }, } { t.Run(fmt.Sprintf("case=%s", testCase.description), func(t *testing.T) { diff --git a/sqlcon/options.go b/sqlcon/options.go index 38fb5d95..7c9986bf 100644 --- a/sqlcon/options.go +++ b/sqlcon/options.go @@ -1,10 +1,12 @@ package sqlcon +import "github.com/satori/go.uuid" + type options struct { - UseTracedDriver bool - OmitArgs bool - AllowRoot bool - useRandomDriverName bool + UseTracedDriver bool + OmitArgs bool + AllowRoot bool + forcedDriverName string } type Opt func(*options) @@ -35,6 +37,6 @@ func WithAllowRoot() Opt { // Reason for this option is because you can't register a driver with the same name more than once func withRandomDriverName() Opt { return func(o *options) { - o.useRandomDriverName = true + o.forcedDriverName = uuid.NewV4().String() } }