diff --git a/tracing/middleware.go b/tracing/middleware.go index 9053357..6206c15 100644 --- a/tracing/middleware.go +++ b/tracing/middleware.go @@ -129,6 +129,10 @@ func ClientMiddleware(opts ...Option) client.Middleware { func ServerMiddleware(cfg *Config) app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { + if cfg.shouldIgnore(ctx, c) { + c.Next(ctx) + return + } // get tracer carrier tc := internal.TraceCarrierFromContext(ctx) if tc == nil { diff --git a/tracing/options.go b/tracing/options.go index 5488357..98d56f6 100644 --- a/tracing/options.go +++ b/tracing/options.go @@ -40,6 +40,8 @@ func (fn option) apply(cfg *Config) { fn(cfg) } +type ConditionFunc func(ctx context.Context, c *app.RequestContext) bool + type Config struct { tracer trace.Tracer meter metric.Meter @@ -54,6 +56,7 @@ type Config struct { recordSourceOperation bool customResponseHandler app.HandlerFunc + shouldIgnore ConditionFunc } func newConfig(opts []Option) *Config { @@ -95,6 +98,9 @@ func defaultConfig() *Config { } return route }, + shouldIgnore: func(ctx context.Context, c *app.RequestContext) bool { + return false + }, } } @@ -132,3 +138,10 @@ func WithServerHttpRouteFormatter(serverHttpRouteFormatter func(c *app.RequestCo cfg.serverHttpRouteFormatter = serverHttpRouteFormatter }) } + +// WithShouldIgnore allows you to define the condition for enabling distributed tracing +func WithShouldIgnore(condition ConditionFunc) Option { + return option(func(cfg *Config) { + cfg.shouldIgnore = condition + }) +} diff --git a/tracing/tracer_server.go b/tracing/tracer_server.go index ca7bea1..b4ea744 100644 --- a/tracing/tracer_server.go +++ b/tracing/tracer_server.go @@ -73,7 +73,10 @@ func (s *serverTracer) createMeasures() { s.histogramRecorder[ServerLatency] = serverLatencyMeasure } -func (s *serverTracer) Start(ctx context.Context, _ *app.RequestContext) context.Context { +func (s *serverTracer) Start(ctx context.Context, c *app.RequestContext) context.Context { + if s.config.shouldIgnore(ctx, c) { + return ctx + } tc := &internal.TraceCarrier{} tc.SetTracer(s.config.tracer) @@ -81,6 +84,9 @@ func (s *serverTracer) Start(ctx context.Context, _ *app.RequestContext) context } func (s *serverTracer) Finish(ctx context.Context, c *app.RequestContext) { + if s.config.shouldIgnore(ctx, c) { + return + } // trace carrier from context tc := internal.TraceCarrierFromContext(ctx) if tc == nil {