diff --git a/controlplane/clickhouse/migrations/20230825095358_traces.sql b/controlplane/clickhouse/migrations/20230825095358_traces.sql index 04542d45ac..d4db87d0f3 100644 --- a/controlplane/clickhouse/migrations/20230825095358_traces.sql +++ b/controlplane/clickhouse/migrations/20230825095358_traces.sql @@ -15,6 +15,7 @@ CREATE TABLE IF NOT EXISTS traces ( StatusMessage String CODEC (ZSTD(3)), OperationHash String CODEC (ZSTD(3)), OperationContent String CODEC (ZSTD(3)), + OperationVariables String CODEC (ZSTD(3)), OperationPersistedID String CODEC (ZSTD(3)), HttpStatusCode String CODEC (ZSTD(3)), HttpHost String CODEC (ZSTD(3)), diff --git a/controlplane/src/core/repositories/analytics/TraceRepository.ts b/controlplane/src/core/repositories/analytics/TraceRepository.ts index 85d549843a..dda1f270d3 100644 --- a/controlplane/src/core/repositories/analytics/TraceRepository.ts +++ b/controlplane/src/core/repositories/analytics/TraceRepository.ts @@ -43,7 +43,8 @@ export class TraceRepository { SpanAttributes['http.target'] as attrHttpTarget, SpanAttributes['wg.subgraph.name'] as attrSubgraphName, SpanAttributes['wg.engine.plan_cache_hit'] as attrEnginePlanCacheHit, - SpanAttributes['wg.engine.request_tracing_enabled'] as attrEngineRequestTracingEnabled + SpanAttributes['wg.engine.request_tracing_enabled'] as attrEngineRequestTracingEnabled, + SpanAttributes['wg.operation.variables'] as attrOperationVariables FROM ${this.client.database}.otel_traces WHERE (TraceId = trace_id) AND (Timestamp >= start) AND (Timestamp <= end) AND SpanAttributes['wg.organization.id'] = '${organizationID}' ORDER BY Timestamp ASC @@ -81,6 +82,7 @@ export class TraceRepository { subgraphName: result.attrSubgraphName, enginePlanCacheHit: result.attrEnginePlanCacheHit, engineRequestTracingEnabled: result.attrEngineRequestTracingEnabled, + operationVariables: result.attrOperationVariables, }, })); } diff --git a/docker-compose.yml b/docker-compose.yml index 53f76fc78e..b776ce1696 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -175,11 +175,11 @@ services: image: redis:${DC_REDIS_VERSION:-7.2.4}-alpine ports: - '6380:6379' - command: redis-server --slaveof redis 6379 /usr/local/etc/redis/redis.conf + command: redis-server /usr/local/etc/redis/redis.conf depends_on: - redis volumes: - - ./docker/redis/redis.conf:/usr/local/etc/redis/redis.conf + - ./docker/redis/redis-slave.conf:/usr/local/etc/redis/redis.conf - redis-slave:/data networks: - primary diff --git a/docker/redis/redis-slave.conf b/docker/redis/redis-slave.conf new file mode 100644 index 0000000000..0d91a1b253 --- /dev/null +++ b/docker/redis/redis-slave.conf @@ -0,0 +1,2 @@ +slaveof redis 6379 +masterauth test \ No newline at end of file diff --git a/router/cmd/instance.go b/router/cmd/instance.go index de76f6d885..0bc5e3eec3 100644 --- a/router/cmd/instance.go +++ b/router/cmd/instance.go @@ -179,10 +179,13 @@ func traceConfig(cfg *config.Telemetry) *trace.Config { } return &trace.Config{ - Enabled: cfg.Tracing.Enabled, - Name: cfg.ServiceName, - Version: core.Version, - Sampler: cfg.Tracing.SamplingRate, + Enabled: cfg.Tracing.Enabled, + Name: cfg.ServiceName, + Version: core.Version, + Sampler: cfg.Tracing.SamplingRate, + ExportGraphQLVariables: trace.ExportGraphQLVariables{ + Enabled: cfg.Tracing.ExportGraphQLVariables, + }, Exporters: exporters, Propagators: propagators, } diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index 6e9b751ad4..2f4d251617 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -101,8 +101,8 @@ func (h *GraphQLHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { requestLogger := h.log.With(logging.WithRequestID(middleware.GetReqID(r.Context()))) operationCtx := getOperationContext(r.Context()) - executionContext, graphqlExecutionSpan := h.tracer.Start(r.Context(), "Operation - Execution", - trace.WithSpanKind(trace.SpanKindServer), + executionContext, graphqlExecutionSpan := h.tracer.Start(r.Context(), "Operation - Execute", + trace.WithSpanKind(trace.SpanKindInternal), ) defer graphqlExecutionSpan.End() diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index ea5ebdfed5..4428e0ab17 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -5,17 +5,19 @@ import ( "crypto/ecdsa" "errors" "fmt" + "github.com/golang-jwt/jwt/v5" "github.com/wundergraph/cosmo/router/pkg/logging" "github.com/wundergraph/cosmo/router/pkg/otel" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" + "go.opentelemetry.io/otel/attribute" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/trace" "net/http" + "strconv" "sync" "time" "github.com/go-chi/chi/middleware" - "github.com/golang-jwt/jwt/v5" "go.uber.org/zap" "github.com/wundergraph/cosmo/router/internal/cdn" @@ -29,7 +31,7 @@ type PreHandlerOptions struct { Logger *zap.Logger Executor *Executor Metrics RouterMetrics - Parser *OperationParser + OperationProcessor *OperationProcessor Planner *OperationPlanner AccessController *AccessController DevelopmentMode bool @@ -37,13 +39,14 @@ type PreHandlerOptions struct { EnableRequestTracing bool TracerProvider *sdktrace.TracerProvider FlushTelemetryAfterResponse bool + TraceExportVariables bool } type PreHandler struct { log *zap.Logger executor *Executor metrics RouterMetrics - parser *OperationParser + operationProcessor *OperationProcessor planner *OperationPlanner accessController *AccessController developmentMode bool @@ -52,6 +55,7 @@ type PreHandler struct { tracerProvider *sdktrace.TracerProvider flushTelemetryAfterResponse bool tracer trace.Tracer + traceExportVariables bool } func NewPreHandler(opts *PreHandlerOptions) *PreHandler { @@ -59,7 +63,7 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler { log: opts.Logger, executor: opts.Executor, metrics: opts.Metrics, - parser: opts.Parser, + operationProcessor: opts.OperationProcessor, planner: opts.Planner, accessController: opts.AccessController, routerPublicKey: opts.RouterPublicKey, @@ -67,6 +71,8 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler { enableRequestTracing: opts.EnableRequestTracing, flushTelemetryAfterResponse: opts.FlushTelemetryAfterResponse, tracerProvider: opts.TracerProvider, + traceExportVariables: opts.TraceExportVariables, + tracer: opts.TracerProvider.Tracer( "wundergraph/cosmo/router/pre_handler", trace.WithInstrumentationVersion("0.0.1"), @@ -98,9 +104,20 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { tracePlanStart int64 ) + routerSpan := trace.SpanFromContext(r.Context()) + clientInfo := NewClientInfoFromRequest(r) + baseAttributeValues := []attribute.KeyValue{ + otel.WgClientName.String(clientInfo.Name), + otel.WgClientVersion.String(clientInfo.Version), + otel.WgOperationProtocol.String(OperationProtocolHTTP.String()), + } + metrics := h.metrics.StartOperation(clientInfo, requestLogger, r.ContentLength) + routerSpan.SetAttributes(baseAttributeValues...) + metrics.AddAttributes(baseAttributeValues...) + if h.flushTelemetryAfterResponse { defer h.flushMetrics(r.Context(), requestLogger) } @@ -114,7 +131,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { buf := pool.GetBytesBuffer() defer pool.PutBytesBuffer(buf) - body, err := h.parser.ReadBody(r.Context(), buf, r.Body) + body, err := h.operationProcessor.ReadBody(buf, r.Body) if err != nil { finalErr = err requestLogger.Error(err.Error()) @@ -122,6 +139,108 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { return } + /** + * Parse the operation + */ + + engineParseCtx, engineParseSpan := h.tracer.Start(r.Context(), "Operation - Parse", + trace.WithSpanKind(trace.SpanKindInternal), + ) + + operationKit, err := h.operationProcessor.NewKit(body) + defer operationKit.Free() + + if err != nil { + finalErr = err + + rtrace.AttachErrToSpan(engineParseSpan, err) + engineParseSpan.End() + + h.writeOperationError(engineParseCtx, w, requestLogger, err) + return + } + + err = operationKit.Parse(r.Context(), clientInfo, requestLogger) + if err != nil { + finalErr = err + + rtrace.AttachErrToSpan(engineParseSpan, err) + engineParseSpan.End() + + h.writeOperationError(engineParseCtx, w, requestLogger, err) + return + } + + engineParseSpan.End() + + // Set the router span name after we have the operation name + routerSpan.SetName(GetSpanName(operationKit.parsedOperation.Name, operationKit.parsedOperation.Type)) + + baseAttributeValues = []attribute.KeyValue{ + otel.WgOperationName.String(operationKit.parsedOperation.Name), + otel.WgOperationType.String(operationKit.parsedOperation.Type), + } + if operationKit.parsedOperation.PersistedID != "" { + baseAttributeValues = append(baseAttributeValues, otel.WgOperationPersistedID.String(operationKit.parsedOperation.PersistedID)) + } + + routerSpan.SetAttributes(baseAttributeValues...) + metrics.AddAttributes(baseAttributeValues...) + + /** + * Normalize the operation + */ + + engineNormalizeCtx, engineNormalizeSpan := h.tracer.Start(r.Context(), "Operation - Normalize", + trace.WithSpanKind(trace.SpanKindInternal), + ) + + err = operationKit.Normalize() + if err != nil { + finalErr = err + + rtrace.AttachErrToSpan(engineNormalizeSpan, err) + engineNormalizeSpan.End() + + h.writeOperationError(engineNormalizeCtx, w, requestLogger, err) + return + } + + engineNormalizeSpan.End() + + if h.traceExportVariables { + // At this stage the variables are normalized + routerSpan.SetAttributes(otel.WgOperationVariables.String(string(operationKit.parsedOperation.Variables))) + } + + baseAttributeValues = []attribute.KeyValue{ + otel.WgOperationHash.String(strconv.FormatUint(operationKit.parsedOperation.ID, 10)), + } + + // Set the normalized operation as soon as we have it + routerSpan.SetAttributes(otel.WgOperationContent.String(operationKit.parsedOperation.NormalizedRepresentation)) + routerSpan.SetAttributes(baseAttributeValues...) + + metrics.AddAttributes(baseAttributeValues...) + + /** + * Validate the operation + */ + + engineValidateCtx, engineValidateSpan := h.tracer.Start(r.Context(), "Operation - Validate", + trace.WithSpanKind(trace.SpanKindInternal), + ) + err = operationKit.Validate() + if err != nil { + finalErr = err + + rtrace.AttachErrToSpan(engineValidateSpan, err) + engineValidateSpan.End() + + h.writeOperationError(engineValidateCtx, w, requestLogger, err) + return + } + if h.enableRequestTracing { if clientInfo.WGRequestToken != "" && h.routerPublicKey != nil { _, err = jwt.Parse(clientInfo.WGRequestToken, func(token *jwt.Token) (interface{}, error) { @@ -150,6 +269,34 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { r = r.WithContext(resolve.SetTraceStart(r.Context(), traceOptions.EnablePredictableDebugTimings)) } + engineValidateSpan.End() + + /** + * Plan the operation + */ + + enginePlanSpanCtx, enginePlanSpan := h.tracer.Start(r.Context(), "Operation - Plan", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes(otel.WgEngineRequestTracingEnabled.Bool(traceOptions.Enable)), + ) + + opContext, err := h.planner.Plan(operationKit.parsedOperation, clientInfo, OperationProtocolHTTP, traceOptions) + + if err != nil { + finalErr = err + + rtrace.AttachErrToSpan(enginePlanSpan, err) + enginePlanSpan.End() + + requestLogger.Error("failed to plan operation", zap.Error(err)) + h.writeOperationError(enginePlanSpanCtx, w, requestLogger, err) + return + } + + enginePlanSpan.SetAttributes(otel.WgEnginePlanCacheHit.Bool(opContext.planCacheHit)) + + enginePlanSpan.End() + // If we have authenticators, we try to authenticate the request if len(h.accessController.authenticators) > 0 { authenticateSpanCtx, authenticateSpan := h.tracer.Start(r.Context(), "Authenticate", @@ -173,23 +320,6 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { r = validatedReq } - engineParseValidateCtx, engineParseValidateSpan := h.tracer.Start(r.Context(), "Operation - Parse and Validate", - trace.WithSpanKind(trace.SpanKindServer), - ) - - operation, err := h.parser.Parse(r.Context(), clientInfo, body, requestLogger) - if err != nil { - finalErr = err - - rtrace.AttachErrToSpan(engineParseValidateSpan, err) - engineParseValidateSpan.End() - - h.writeOperationError(engineParseValidateCtx, w, requestLogger, err) - return - } - - engineParseValidateSpan.End() - // If the request has a query parameter wg_trace=true we skip the cache // and always plan the operation // this allows us to "write" to the plan @@ -197,34 +327,6 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { tracePlanStart = resolve.GetDurationNanoSinceTraceStart(r.Context()) } - enginePlanSpanCtx, enginePlanSpan := h.tracer.Start(r.Context(), "Operation - Planning", - trace.WithSpanKind(trace.SpanKindServer), - trace.WithAttributes(otel.WgEngineRequestTracingEnabled.Bool(traceOptions.Enable)), - ) - - opContext, err := h.planner.Plan(operation, clientInfo, OperationProtocolHTTP, traceOptions) - - commonAttributeValues := commonMetricAttributes(opContext) - metrics.AddAttributes(commonAttributeValues...) - - // The attributes have to be added on the root router span - initializeSpan(r.Context(), operation, commonAttributeValues) - - if err != nil { - finalErr = err - - rtrace.AttachErrToSpan(enginePlanSpan, err) - enginePlanSpan.End() - - requestLogger.Error("failed to plan operation", zap.Error(err)) - h.writeOperationError(enginePlanSpanCtx, w, requestLogger, err) - return - } - - enginePlanSpan.SetAttributes(otel.WgEnginePlanCacheHit.Bool(opContext.planCacheHit)) - - enginePlanSpan.End() - if !traceOptions.ExcludePlannerStats { planningTime := resolve.GetDurationNanoSinceTraceStart(r.Context()) - tracePlanStart resolve.SetPlannerStats(r.Context(), resolve.PlannerStats{ @@ -255,9 +357,9 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { finalErr = requestContext.error // Mark the root span of the router as failed, so we can easily identify failed requests - span := trace.SpanFromContext(newReq.Context()) + routerSpan = trace.SpanFromContext(newReq.Context()) if finalErr != nil { - rtrace.AttachErrToSpan(span, finalErr) + rtrace.AttachErrToSpan(routerSpan, finalErr) } }) } diff --git a/router/core/operation_metrics.go b/router/core/operation_metrics.go index 9d9d0b5598..e581a64273 100644 --- a/router/core/operation_metrics.go +++ b/router/core/operation_metrics.go @@ -11,7 +11,6 @@ import ( semconv "go.opentelemetry.io/otel/semconv/v1.21.0" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) type OperationProtocol string @@ -101,8 +100,8 @@ func startOperationMetrics(rMetrics RouterMetrics, logger *zap.Logger, requestCo } } -// commonMetricAttributes returns the attributes that are common to both metrics and traces. -func commonMetricAttributes(operationContext *operationContext) []attribute.KeyValue { +// setAttributesFromOperationContext returns the attributes that are common to both metrics and traces. +func setAttributesFromOperationContext(operationContext *operationContext) []attribute.KeyValue { if operationContext == nil { return nil } @@ -124,16 +123,3 @@ func commonMetricAttributes(operationContext *operationContext) []attribute.KeyV return baseMetricAttributeValues } - -// initializeSpan sets the correct span name and attributes for the operation on the current span. -func initializeSpan(ctx context.Context, operation *ParsedOperation, commonAttributeValues []attribute.KeyValue) { - if operation == nil { - return - } - - span := trace.SpanFromContext(ctx) - span.SetName(GetSpanName(operation.Name, operation.Type)) - span.SetAttributes(commonAttributeValues...) - // Only set the operation content on the span - span.SetAttributes(otel.WgOperationContent.String(operation.NormalizedRepresentation)) -} diff --git a/router/core/operation_parser_test.go b/router/core/operation_parser_test.go deleted file mode 100644 index 24eaa95bca..0000000000 --- a/router/core/operation_parser_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package core - -import ( - "context" - "errors" - "strings" - "testing" - - "github.com/buger/jsonparser" - "github.com/stretchr/testify/assert" - "github.com/wundergraph/cosmo/router/internal/pool" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "go.uber.org/zap" -) - -func TestOperationParserExtensions(t *testing.T) { - executor := &Executor{ - PlanConfig: plan.Configuration{}, - Definition: nil, - Resolver: nil, - RenameTypeNames: nil, - Pool: pool.New(), - } - parser := NewOperationParser(OperationParserOptions{ - Executor: executor, - MaxOperationSizeInBytes: 10 << 20, - }) - clientInfo := &ClientInfo{ - Name: "test", - Version: "1.0.0", - } - log := zap.NewNop() - testCases := []struct { - Input string - ValueType jsonparser.ValueType - Valid bool - }{ - { - Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":"this_is_not_valid"}`, - ValueType: jsonparser.String, - }, - { - Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":42}`, - ValueType: jsonparser.Number, - }, - { - Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":true}`, - ValueType: jsonparser.Boolean, - }, - { - Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":{}}`, - Valid: true, - }, - { - Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":null}`, - Valid: true, - }, - { - Input: `{"query":"subscription { initialPayload(repeat:3) }"}`, - Valid: true, - }, - } - var inputError InputError - for _, tc := range testCases { - tc := tc - t.Run(tc.Input, func(t *testing.T) { - _, err := parser.ParseReader(context.Background(), clientInfo, strings.NewReader(tc.Input), log) - isInputError := errors.As(err, &inputError) - if tc.Valid { - assert.False(t, isInputError, "expected invalid extensions to not return an input error, got %s", err) - } else { - assert.True(t, isInputError, "expected invalid extensions to return an input error, got %s", err) - assert.Contains(t, err.Error(), "extensions", "expected error to contain extensions") - assert.Contains(t, err.Error(), tc.ValueType.String(), "expected error to contain value type name") - } - }) - } -} diff --git a/router/core/operation_parser.go b/router/core/operation_processor.go similarity index 62% rename from router/core/operation_parser.go rename to router/core/operation_processor.go index 4e9549714e..3118cad103 100644 --- a/router/core/operation_parser.go +++ b/router/core/operation_processor.go @@ -25,6 +25,35 @@ import ( "go.uber.org/zap" ) +var ( + // staticOperationName is used to replace the operation name in the document when generating the operation ID + // this ensures that the operation ID is the same for the same operation regardless of the operation name + staticOperationName = []byte("O") + parseOperationKeys = [][]string{ + {"query"}, + {"variables"}, + {"operationName"}, + {"extensions"}, + } + + persistedQueryKeys = [][]string{ + {"version"}, + {"sha256Hash"}, + } +) + +const ( + parseOperationKeysQueryIndex = iota + parseOperationKeysVariablesIndex + parseOperationKeysOperationNameIndex + parseOperationKeysExtensionsIndex +) + +const ( + persistedQueryKeysVersionIndex = iota + persistedQueryKeysSha256HashIndex +) + type ParsedOperation struct { // ID represents a unique-ish ID for the operation calculated by hashing // its normalized representation and its variables @@ -38,7 +67,7 @@ type ParsedOperation struct { Variables []byte // NormalizedRepresentation is the normalized representation of the operation // as a string. This is provided for modules to be able to access the - // operation. + // operation. Only available after the operation has been normalized. NormalizedRepresentation string Extensions []byte PersistedID string @@ -62,13 +91,22 @@ var ( _ InputError = invalidExtensionsTypeError(0) ) -type OperationParser struct { +type OperationParserOptions struct { + Executor *Executor + MaxOperationSizeInBytes int64 + PersistentOpClient *cdn.PersistentOperationClient +} + +// OperationProcessor provides shared resources to the parseKit and OperationKit. +// It should be only instantiated once and shared across requests +type OperationProcessor struct { executor *Executor maxOperationSizeInBytes int64 cdn *cdn.PersistentOperationClient parseKitPool *sync.Pool } +// parseKit is a helper struct to parse, normalize and validate operations type parseKit struct { parser *astparser.Parser doc *ast.Document @@ -80,129 +118,42 @@ type parseKit struct { variablesValidator *variablesvalidation.VariablesValidator } -type OperationParserOptions struct { - Executor *Executor - MaxOperationSizeInBytes int64 - PersistentOpClient *cdn.PersistentOperationClient -} - -func NewOperationParser(opts OperationParserOptions) *OperationParser { - return &OperationParser{ - executor: opts.Executor, - maxOperationSizeInBytes: opts.MaxOperationSizeInBytes, - cdn: opts.PersistentOpClient, - parseKitPool: &sync.Pool{ - New: func() interface{} { - return &parseKit{ - parser: astparser.NewParser(), - doc: ast.NewSmallDocument(), - keyGen: xxhash.New(), - normalizer: astnormalization.NewWithOpts( - astnormalization.WithExtractVariables(), - astnormalization.WithInlineFragmentSpreads(), - astnormalization.WithRemoveFragmentDefinitions(), - astnormalization.WithRemoveNotMatchingOperationDefinitions(), - ), - printer: &astprinter.Printer{}, - normalizedOperation: &bytes.Buffer{}, - unescapedDocument: make([]byte, 1024), - variablesValidator: variablesvalidation.NewVariablesValidator(), - } - }, - }, - } -} - -func (p *OperationParser) getKit() *parseKit { - return p.parseKitPool.Get().(*parseKit) -} - -func (p *OperationParser) freeKit(kit *parseKit) { - kit.keyGen.Reset() - kit.doc.Reset() - kit.normalizedOperation.Reset() - kit.unescapedDocument = kit.unescapedDocument[:0] -} - -func (p *OperationParser) entityTooLarge() error { - return &inputError{ - message: "request body too large", - statusCode: http.StatusRequestEntityTooLarge, - } -} - -func (p *OperationParser) ReadBody(ctx context.Context, buf *bytes.Buffer, r io.Reader) ([]byte, error) { - // Use an extra byte for the max size. This way we can check if N became - // zero to detect if the request body was too large. - limitedReader := &io.LimitedReader{R: r, N: p.maxOperationSizeInBytes + 1} - if _, err := io.Copy(buf, limitedReader); err != nil { - return nil, fmt.Errorf("failed to read request body: %w", err) - } - - if limitedReader.N == 0 { - return nil, p.entityTooLarge() - } - - return buf.Bytes(), nil +// OperationKit provides methods to parse, normalize and validate operations. +// After each step, the operation is available as a ParsedOperation. +// It must be created for each request and freed after the request is done. +type OperationKit struct { + data []byte + operationDefinitionRef int + originalOperationNameRef ast.ByteSliceReference + operationParser *OperationProcessor + kit *parseKit + parsedOperation *ParsedOperation } -func (p *OperationParser) ParseReader(ctx context.Context, clientInfo *ClientInfo, r io.Reader, log *zap.Logger) (*ParsedOperation, error) { - buf := pool.GetBytesBuffer() - defer pool.PutBytesBuffer(buf) - data, err := p.ReadBody(ctx, buf, r) - if err != nil { - return nil, err +// NewOperationKit creates a new OperationKit. The kit is used to parse, normalize and validate operations. +// It allocates resources that need to be freed by calling OperationKit.Free() +func NewOperationKit(parser *OperationProcessor, data []byte) *OperationKit { + return &OperationKit{ + operationParser: parser, + kit: parser.getKit(), + operationDefinitionRef: -1, + data: data, } - return p.parse(ctx, clientInfo, data, log) } -func (p *OperationParser) Parse(ctx context.Context, clientInfo *ClientInfo, data []byte, log *zap.Logger) (*ParsedOperation, error) { - if len(data) > int(p.maxOperationSizeInBytes) { - return nil, p.entityTooLarge() - } - return p.parse(ctx, clientInfo, data, log) +// Free releases the resources used by the OperationKit +func (o *OperationKit) Free() { + o.operationParser.freeKit(o.kit) } -var ( - // staticOperationName is used to replace the operation name in the document when generating the operation ID - // this ensures that the operation ID is the same for the same operation regardless of the operation name - staticOperationName = []byte("O") - parseOperationKeys = [][]string{ - {"query"}, - {"variables"}, - {"operationName"}, - {"extensions"}, - } - - persistedQueryKeys = [][]string{ - {"version"}, - {"sha256Hash"}, - } -) - -const ( - parseOperationKeysQueryIndex = iota - parseOperationKeysVariablesIndex - parseOperationKeysOperationNameIndex - parseOperationKeysExtensionsIndex -) - -const ( - persistedQueryKeysVersionIndex = iota - persistedQueryKeysSha256HashIndex -) - -func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, body []byte, log *zap.Logger) (*ParsedOperation, error) { - +func (o *OperationKit) Parse(ctx context.Context, clientInfo *ClientInfo, log *zap.Logger) error { var ( requestOperationType string - operationDefinitionRef = -1 requestOperationNameBytes []byte requestExtensions []byte operationCount = 0 anonymousOperationCount = 0 anonymousOperationDefinitionRef = -1 - originalOperationNameRef ast.ByteSliceReference requestDocumentBytes []byte requestVariableBytes []byte persistedQueryVersion []byte @@ -211,10 +162,7 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod variablesValueType jsonparser.ValueType ) - kit := p.getKit() - defer p.freeKit(kit) - - jsonparser.EachKey(body, func(i int, value []byte, valueType jsonparser.ValueType, err error) { + jsonparser.EachKey(o.data, func(i int, value []byte, valueType jsonparser.ValueType, err error) { if parseErr != nil { // If we already have an error, don't overwrite it return @@ -225,7 +173,7 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod } switch i { case parseOperationKeysQueryIndex: - requestDocumentBytes, err = jsonparser.Unescape(value, kit.unescapedDocument) + requestDocumentBytes, err = jsonparser.Unescape(value, o.kit.unescapedDocument) if err != nil { parseErr = fmt.Errorf("error unescaping query: %w", err) return @@ -276,46 +224,46 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod case jsonparser.Null, jsonparser.Unknown, jsonparser.Object, jsonparser.NotExist: // valid, continue case jsonparser.Array: - return nil, &inputError{ + return &inputError{ message: "variables value must not be an array", statusCode: http.StatusBadRequest, } case jsonparser.String: - return nil, &inputError{ + return &inputError{ message: "variables value must not be a string", statusCode: http.StatusBadRequest, } case jsonparser.Number: - return nil, &inputError{ + return &inputError{ message: "variables value must not be a number", statusCode: http.StatusBadRequest, } case jsonparser.Boolean: - return nil, &inputError{ + return &inputError{ message: "variables value must not be a boolean", statusCode: http.StatusBadRequest, } default: - return nil, &inputError{ + return &inputError{ message: "variables value must be a JSON object", statusCode: http.StatusBadRequest, } } if parseErr != nil { - return nil, errors.WithStack(parseErr) + return errors.WithStack(parseErr) } if len(persistedQuerySha256Hash) > 0 { - if p.cdn == nil { - return nil, &inputError{ + if o.operationParser.cdn == nil { + return &inputError{ message: "could not resolve persisted query, feature is not configured", statusCode: http.StatusOK, } } - persistedOperationData, err := p.cdn.PersistedOperation(ctx, clientInfo.Name, persistedQuerySha256Hash) + persistedOperationData, err := o.operationParser.cdn.PersistedOperation(ctx, clientInfo.Name, persistedQuerySha256Hash) if err != nil { - return nil, errors.WithStack(err) + return errors.WithStack(err) } requestDocumentBytes = persistedOperationData } @@ -326,21 +274,21 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod } report := &operationreport.Report{} - kit.doc.Input.ResetInputBytes(requestDocumentBytes) - kit.parser.Parse(kit.doc, report) + o.kit.doc.Input.ResetInputBytes(requestDocumentBytes) + o.kit.parser.Parse(o.kit.doc, report) if report.HasErrors() { - return nil, &reportError{ + return &reportError{ report: report, } } - for i := range kit.doc.RootNodes { - if kit.doc.RootNodes[i].Kind != ast.NodeKindOperationDefinition { + for i := range o.kit.doc.RootNodes { + if o.kit.doc.RootNodes[i].Kind != ast.NodeKindOperationDefinition { continue } operationCount++ - ref := kit.doc.RootNodes[i].Ref - name := kit.doc.Input.ByteSlice(kit.doc.OperationDefinitions[ref].Name) + ref := o.kit.doc.RootNodes[i].Ref + name := o.kit.doc.Input.ByteSlice(o.kit.doc.OperationDefinitions[ref].Name) if len(name) == 0 { anonymousOperationCount++ if anonymousOperationDefinitionRef == -1 { @@ -349,54 +297,59 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod continue } if requestOperationNameBytes == nil { - operationDefinitionRef = ref - originalOperationNameRef = kit.doc.OperationDefinitions[ref].Name + o.operationDefinitionRef = ref + o.originalOperationNameRef = o.kit.doc.OperationDefinitions[ref].Name requestOperationNameBytes = name continue } - if bytes.Equal(name, requestOperationNameBytes) && operationDefinitionRef == -1 { - operationDefinitionRef = ref - originalOperationNameRef = kit.doc.OperationDefinitions[ref].Name + if bytes.Equal(name, requestOperationNameBytes) && o.operationDefinitionRef == -1 { + o.operationDefinitionRef = ref + o.originalOperationNameRef = o.kit.doc.OperationDefinitions[ref].Name } } if !requestHasOperationName && operationCount > 1 { - return nil, &inputError{ + return &inputError{ message: "operation name is required when multiple operations are defined", statusCode: http.StatusOK, } } - if requestHasOperationName && operationCount != 0 && operationDefinitionRef == -1 { - return nil, &inputError{ + if requestHasOperationName && operationCount != 0 && o.operationDefinitionRef == -1 { + return &inputError{ message: fmt.Sprintf("operation with name '%s' not found", string(requestOperationNameBytes)), statusCode: http.StatusOK, } } - if operationDefinitionRef == -1 { + if o.operationDefinitionRef == -1 { if anonymousOperationCount == 1 { - operationDefinitionRef = anonymousOperationDefinitionRef + o.operationDefinitionRef = anonymousOperationDefinitionRef } else if anonymousOperationCount > 1 { - return nil, &inputError{ + return &inputError{ message: "operation name is required when multiple operations are defined", statusCode: http.StatusOK, } } else { - return nil, &inputError{ + return &inputError{ message: fmt.Sprintf("operation with name '%s' not found", string(requestOperationNameBytes)), statusCode: http.StatusOK, } } } - switch kit.doc.OperationDefinitions[operationDefinitionRef].OperationType { + switch o.kit.doc.OperationDefinitions[o.operationDefinitionRef].OperationType { case ast.OperationTypeQuery: requestOperationType = "query" case ast.OperationTypeMutation: requestOperationType = "mutation" case ast.OperationTypeSubscription: requestOperationType = "subscription" + default: + return &inputError{ + message: "operation type not supported", + statusCode: http.StatusOK, + } } // set variables to empty object if they are null or not present @@ -404,51 +357,157 @@ func (p *OperationParser) parse(ctx context.Context, clientInfo *ClientInfo, bod requestVariableBytes = []byte("{}") } - // set variables on doc input before normalization + // Set variables on doc input before normalization // IMPORTANT: this is required for the normalization to work correctly! // Normalization reads/rewrites/adds variables - kit.doc.Input.Variables = requestVariableBytes + o.kit.doc.Input.Variables = requestVariableBytes + + // Replace the operation name with a static name to avoid different IDs for the same operation + replaceOperationName := o.kit.doc.Input.AppendInputBytes(staticOperationName) + o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = replaceOperationName + + // Here we create a copy of the original variables. After parse, the variables can be consumed or modified. + variablesCopy := make([]byte, len(o.kit.doc.Input.Variables)) + copy(variablesCopy, o.kit.doc.Input.Variables) + + o.parsedOperation = &ParsedOperation{ + ID: 0, // will be set after normalization + NormalizedRepresentation: "", // will be set after normalization + Name: string(requestOperationNameBytes), + Type: requestOperationType, + Extensions: requestExtensions, + PersistedID: string(persistedQuerySha256Hash), + Variables: variablesCopy, + } - // replace the operation name with a static name to avoid different IDs for the same operation - replaceOperationName := kit.doc.Input.AppendInputBytes(staticOperationName) - kit.doc.OperationDefinitions[operationDefinitionRef].Name = replaceOperationName - kit.normalizer.NormalizeNamedOperation(kit.doc, p.executor.Definition, staticOperationName, report) + return nil +} + +// Normalize normalizes the operation. After normalization the normalized representation of the operation +// and variables is available. Also, the final operation ID is generated. +func (o *OperationKit) Normalize() error { + report := &operationreport.Report{} + o.kit.normalizer.NormalizeNamedOperation(o.kit.doc, o.operationParser.executor.Definition, staticOperationName, report) if report.HasErrors() { - return nil, &reportError{ + return &reportError{ report: report, } } - // hash the normalized operation with the static operation name to avoid different IDs for the same operation - err := kit.printer.Print(kit.doc, p.executor.Definition, kit.keyGen) + + // Hash the normalized operation with the static operation name to avoid different IDs for the same operation + err := o.kit.printer.Print(o.kit.doc, o.operationParser.executor.Definition, o.kit.keyGen) if err != nil { - return nil, errors.WithStack(fmt.Errorf("failed to print normalized operation: %w", err)) + return errors.WithStack(fmt.Errorf("failed to print normalized operation: %w", err)) } - operationID := kit.keyGen.Sum64() // generate the operation ID - // print the operation with the original operation name - kit.doc.OperationDefinitions[operationDefinitionRef].Name = originalOperationNameRef - err = kit.printer.Print(kit.doc, p.executor.Definition, kit.normalizedOperation) + + // Generate the operation ID + o.parsedOperation.ID = o.kit.keyGen.Sum64() + + // Print the operation with the original operation name + o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = o.originalOperationNameRef + err = o.kit.printer.Print(o.kit.doc, o.operationParser.executor.Definition, o.kit.normalizedOperation) if err != nil { - return nil, errors.WithStack(fmt.Errorf("failed to print normalized operation: %w", err)) + return errors.WithStack(fmt.Errorf("failed to print normalized operation: %w", err)) } - variablesCopy := make([]byte, len(kit.doc.Input.Variables)) - copy(variablesCopy, kit.doc.Input.Variables) + // Set the normalized representation + o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String() + + // Here we copy the normalized variables. After normalization, the variables can be consumed or modified. + variablesCopy := make([]byte, len(o.kit.doc.Input.Variables)) + copy(variablesCopy, o.kit.doc.Input.Variables) + + o.parsedOperation.Variables = variablesCopy - err = kit.variablesValidator.Validate(kit.doc, p.executor.Definition, variablesCopy) + return nil +} + +// Validate validates the operation variables. +func (o *OperationKit) Validate() error { + err := o.kit.variablesValidator.Validate(o.kit.doc, o.operationParser.executor.Definition, o.parsedOperation.Variables) if err != nil { - return nil, &inputError{ + return &inputError{ message: err.Error(), statusCode: http.StatusBadRequest, } } - return &ParsedOperation{ - ID: operationID, - Name: string(requestOperationNameBytes), - Type: requestOperationType, - Variables: variablesCopy, - NormalizedRepresentation: kit.normalizedOperation.String(), - Extensions: requestExtensions, - PersistedID: string(persistedQuerySha256Hash), - }, nil + return nil +} + +func NewOperationParser(opts OperationParserOptions) *OperationProcessor { + return &OperationProcessor{ + executor: opts.Executor, + maxOperationSizeInBytes: opts.MaxOperationSizeInBytes, + cdn: opts.PersistentOpClient, + parseKitPool: &sync.Pool{ + New: func() interface{} { + return &parseKit{ + parser: astparser.NewParser(), + doc: ast.NewSmallDocument(), + keyGen: xxhash.New(), + normalizer: astnormalization.NewWithOpts( + astnormalization.WithExtractVariables(), + astnormalization.WithInlineFragmentSpreads(), + astnormalization.WithRemoveFragmentDefinitions(), + astnormalization.WithRemoveNotMatchingOperationDefinitions(), + ), + printer: &astprinter.Printer{}, + normalizedOperation: &bytes.Buffer{}, + unescapedDocument: make([]byte, 1024), + variablesValidator: variablesvalidation.NewVariablesValidator(), + } + }, + }, + } +} + +func (p *OperationProcessor) getKit() *parseKit { + return p.parseKitPool.Get().(*parseKit) +} + +func (p *OperationProcessor) freeKit(kit *parseKit) { + kit.keyGen.Reset() + kit.doc.Reset() + kit.normalizedOperation.Reset() + kit.unescapedDocument = kit.unescapedDocument[:0] +} + +func (p *OperationProcessor) entityTooLarge() error { + return &inputError{ + message: "request body too large", + statusCode: http.StatusRequestEntityTooLarge, + } +} + +func (p *OperationProcessor) ReadBody(buf *bytes.Buffer, r io.Reader) ([]byte, error) { + // Use an extra byte for the max size. This way we can check if N became + // zero to detect if the request body was too large. + limitedReader := &io.LimitedReader{R: r, N: p.maxOperationSizeInBytes + 1} + if _, err := io.Copy(buf, limitedReader); err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + if limitedReader.N == 0 { + return nil, p.entityTooLarge() + } + + return buf.Bytes(), nil +} + +func (p *OperationProcessor) NewKitFromReader(r io.Reader) (*OperationKit, error) { + buf := pool.GetBytesBuffer() + defer pool.PutBytesBuffer(buf) + data, err := p.ReadBody(buf, r) + if err != nil { + return nil, err + } + return NewOperationKit(p, data), nil +} + +func (p *OperationProcessor) NewKit(data []byte) (*OperationKit, error) { + if len(data) > int(p.maxOperationSizeInBytes) { + return nil, p.entityTooLarge() + } + return NewOperationKit(p, data), nil } diff --git a/router/core/operation_processor_test.go b/router/core/operation_processor_test.go new file mode 100644 index 0000000000..dcb42dc129 --- /dev/null +++ b/router/core/operation_processor_test.go @@ -0,0 +1,229 @@ +package core + +import ( + "context" + "errors" + "github.com/stretchr/testify/require" + "strings" + "testing" + + "github.com/buger/jsonparser" + "github.com/stretchr/testify/assert" + "github.com/wundergraph/cosmo/router/internal/pool" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "go.uber.org/zap" +) + +func TestOperationParser(t *testing.T) { + executor := &Executor{ + PlanConfig: plan.Configuration{}, + Definition: nil, + Resolver: nil, + RenameTypeNames: nil, + Pool: pool.New(), + } + parser := NewOperationParser(OperationParserOptions{ + Executor: executor, + MaxOperationSizeInBytes: 10 << 20, + }) + clientInfo := &ClientInfo{ + Name: "test", + Version: "1.0.0", + } + log := zap.NewNop() + testCases := []struct { + ExpectedType string + ExpectedError error + Input string + Variables string + }{ + /** + * Test cases parse simple + */ + { + Input: `{"query":"query { employees { name } }"`, + ExpectedType: "query", + Variables: `{}`, + ExpectedError: nil, + }, + /** + * Test cases parse invalid graphql + */ + { + Input: `{"query":"invalid", "variables": {"foo": "bar"}}`, + Variables: `{"foo": "bar"}`, + ExpectedError: errors.New("unexpected literal - got: UNDEFINED want one of: [ENUM TYPE UNION QUERY INPUT EXTEND SCHEMA SCALAR FRAGMENT INTERFACE DIRECTIVE]"), + }, + /** + * Test cases parse operation types + */ + { + ExpectedType: "subscription", + Input: `{"query":"subscription { initialPayload(repeat:3) }", "variables": {"foo": "bar"}}`, + Variables: `{"foo": "bar"}`, + ExpectedError: nil, + }, + { + ExpectedType: "query", + Input: `{"query":"query { initialPayload(repeat:3) }", "variables": {"foo": "bar"}}`, + Variables: `{"foo": "bar"}`, + ExpectedError: nil, + }, + { + ExpectedType: "mutation", + Input: `{"query":"mutation { initialPayload(repeat:3) }", "variables": {"foo": "bar"}}`, + Variables: `{"foo": "bar"}`, + ExpectedError: nil, + }, + /** + * Test cases parse variables + */ + { + ExpectedType: "query", + Input: `{"query":"query { initialPayload(repeat:3) }", "variables": {"foo": ["bar"]}}`, + Variables: `{"foo": ["bar"]}`, + ExpectedError: nil, + }, + { + Input: `{"query":"query { initialPayload(repeat:3) }", "variables": null}`, + ExpectedType: "query", + Variables: "{}", + ExpectedError: nil, + }, + { + Input: `{"query":"query { initialPayload(repeat:3) }", "variables": {"foo": {"bar": "baz"}}`, + ExpectedType: "query", + Variables: `{"foo": {"bar": "baz"}}`, + ExpectedError: nil, + }, + { + Input: `{"query":"mutation", "variables": {"foo": "bar"}}`, + ExpectedError: errors.New("unexpected token - got: EOF want one of: [LBRACE]"), + ExpectedType: "", + Variables: "", + }, + /** + * Test cases parse operation name + */ + { + Input: `{"query":"subscription { initialPayload(repeat:3) }", "variables": {"foo": "bar"}, "operationName": "test"}`, + ExpectedError: errors.New("operation with name 'test' not found"), + ExpectedType: "", + Variables: "", + }, + { + ExpectedType: "subscription", + Input: `{"query":"subscription foo { initialPayload(repeat:3) }", "variables": {"foo": "bar"}, "operationName": "foo"}`, + Variables: `{"foo": "bar"}`, + ExpectedError: nil, + }, + /** + * Test cases parse multiple operations + */ + { + Input: `{"query":"query { initialPayload(repeat:3) } mutation { initialPayload(repeat:3) }", "variables": {"foo": "bar"}}`, + ExpectedError: errors.New("operation name is required when multiple operations are defined"), + ExpectedType: "", + Variables: "", + }, + { + ExpectedType: "query", + Input: `{"query":"query test { initialPayload(repeat:3) } mutation { initialPayload(repeat:3) }", "variables": {"foo": "bar"}, "operationName": "test"}`, + Variables: `{"foo": "bar"}`, + ExpectedError: nil, + }, + /** + * Test cases persist operation + */ + { + Input: `{"operationName": "test", "variables": {"foo": "bar"}, "extensions": {"persistedQuery": {"version": 1, "sha256Hash": "does-not-exist"}}}`, + Variables: `{"foo": "bar"}`, + ExpectedError: errors.New("could not resolve persisted query, feature is not configured"), + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.Input, func(t *testing.T) { + kit, err := parser.NewKitFromReader(strings.NewReader(tc.Input)) + assert.NoError(t, err) + + err = kit.Parse(context.Background(), clientInfo, log) + + if err != nil { + require.EqualError(t, tc.ExpectedError, err.Error()) + } else if kit.parsedOperation != nil { + require.Equal(t, tc.ExpectedType, kit.parsedOperation.Type) + require.JSONEq(t, tc.Variables, string(kit.parsedOperation.Variables)) + require.Equal(t, uint64(0), kit.parsedOperation.ID) + require.Equal(t, "", kit.parsedOperation.NormalizedRepresentation) + } + }) + } +} + +func TestOperationParserExtensions(t *testing.T) { + executor := &Executor{ + PlanConfig: plan.Configuration{}, + Definition: nil, + Resolver: nil, + RenameTypeNames: nil, + Pool: pool.New(), + } + parser := NewOperationParser(OperationParserOptions{ + Executor: executor, + MaxOperationSizeInBytes: 10 << 20, + }) + clientInfo := &ClientInfo{ + Name: "test", + Version: "1.0.0", + } + log := zap.NewNop() + testCases := []struct { + Input string + ValueType jsonparser.ValueType + Valid bool + }{ + { + Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":"this_is_not_valid"}`, + ValueType: jsonparser.String, + }, + { + Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":42}`, + ValueType: jsonparser.Number, + }, + { + Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":true}`, + ValueType: jsonparser.Boolean, + }, + { + Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":{}}`, + Valid: true, + }, + { + Input: `{"query":"subscription { initialPayload(repeat:3) }","extensions":null}`, + Valid: true, + }, + { + Input: `{"query":"subscription { initialPayload(repeat:3) }"}`, + Valid: true, + }, + } + var inputError InputError + for _, tc := range testCases { + tc := tc + t.Run(tc.Input, func(t *testing.T) { + kit, err := parser.NewKitFromReader(strings.NewReader(tc.Input)) + assert.NoError(t, err) + + err = kit.Parse(context.Background(), clientInfo, log) + isInputError := errors.As(err, &inputError) + if tc.Valid { + assert.False(t, isInputError, "expected invalid extensions to not return an input error, got %s", err) + } else { + assert.True(t, isInputError, "expected invalid extensions to return an input error, got %s", err) + assert.Contains(t, err.Error(), "extensions", "expected error to contain extensions") + assert.Contains(t, err.Error(), tc.ValueType.String(), "expected error to contain value type name") + } + }) + } +} diff --git a/router/core/router.go b/router/core/router.go index 05bef80a3c..fe78ff1484 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -310,7 +310,7 @@ func NewRouter(opts ...Option) (*Router, error) { } if r.traceConfig.Enabled { - defaultExporter := rtrace.GetDefaultExporter(r.traceConfig) + defaultExporter := rtrace.DefaultExporter(r.traceConfig) if defaultExporter != nil { disabledFeatures = append(disabledFeatures, "Cosmo Cloud Tracing") defaultExporter.Disabled = true @@ -547,7 +547,7 @@ func (r *Router) NewServer(ctx context.Context) (Server, error) { // bootstrap initializes the Router. It is called by Start() and NewServer(). // It should only be called once for a Router instance. func (r *Router) bootstrap(ctx context.Context) error { - cosmoCloudTracingEnabled := r.traceConfig.Enabled && rtrace.GetDefaultExporter(r.traceConfig) != nil + cosmoCloudTracingEnabled := r.traceConfig.Enabled && rtrace.DefaultExporter(r.traceConfig) != nil artInProductionEnabled := r.engineExecutionConfiguration.EnableRequestTracing && !r.developmentMode needsRegistration := cosmoCloudTracingEnabled || artInProductionEnabled @@ -711,6 +711,7 @@ func (r *Router) newServer(ctx context.Context, routerConfig *nodev1.RouterConfi baseAttributes := []attribute.KeyValue{ otel.WgRouterConfigVersion.String(routerConfig.GetVersion()), otel.WgRouterVersion.String(Version), + otel.WgRouterRootSpan.Bool(true), } if r.graphApiToken != "" { @@ -945,7 +946,7 @@ func (r *Router) newServer(ctx context.Context, routerConfig *nodev1.RouterConfi Logger: r.logger, Executor: executor, Metrics: routerMetrics, - Parser: operationParser, + OperationProcessor: operationParser, Planner: operationPlanner, AccessController: r.accessController, RouterPublicKey: publicKey, @@ -953,10 +954,11 @@ func (r *Router) newServer(ctx context.Context, routerConfig *nodev1.RouterConfi DevelopmentMode: r.developmentMode, TracerProvider: r.tracerProvider, FlushTelemetryAfterResponse: r.awsLambda, + TraceExportVariables: r.traceConfig.ExportGraphQLVariables.Enabled, }) wsMiddleware := NewWebsocketMiddleware(rootContext, WebsocketMiddlewareOptions{ - Parser: operationParser, + OperationProcessor: operationParser, Planner: operationPlanner, GraphQLHandler: graphqlHandler, Metrics: routerMetrics, diff --git a/router/core/transport.go b/router/core/transport.go index a5a1ede502..1507ebf4b7 100644 --- a/router/core/transport.go +++ b/router/core/transport.go @@ -64,7 +64,7 @@ func NewCustomTransport( func (ct *CustomTransport) measureSubgraphMetrics(req *http.Request) func(err error, resp *http.Response) { reqContext := getRequestContext(req.Context()) - baseFields := commonMetricAttributes(reqContext.operation) + baseFields := setAttributesFromOperationContext(reqContext.operation) activeSubgraph := reqContext.ActiveSubgraph(req) if activeSubgraph != nil { @@ -304,7 +304,7 @@ func (t TransportFactory) RoundTripper(enableSingleFlight bool, transport http.R reqContext := getRequestContext(r.Context()) operation := reqContext.operation - commonAttributeValues := commonMetricAttributes(operation) + commonAttributeValues := setAttributesFromOperationContext(operation) subgraph := reqContext.ActiveSubgraph(r) if subgraph != nil { diff --git a/router/core/websocket.go b/router/core/websocket.go index bb909037f4..1c95a5868c 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -35,14 +35,14 @@ var ( ) type WebsocketMiddlewareOptions struct { - Parser *OperationParser - Planner *OperationPlanner - GraphQLHandler *GraphQLHandler - Metrics RouterMetrics - AccessController *AccessController - Logger *zap.Logger - Stats WebSocketsStatistics - ReadTimeout time.Duration + OperationProcessor *OperationProcessor + Planner *OperationPlanner + GraphQLHandler *GraphQLHandler + Metrics RouterMetrics + AccessController *AccessController + Logger *zap.Logger + Stats WebSocketsStatistics + ReadTimeout time.Duration EnableWebSocketEpollKqueue bool EpollKqueuePollTimeout time.Duration @@ -52,16 +52,16 @@ type WebsocketMiddlewareOptions struct { func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { handler := &WebsocketHandler{ - ctx: ctx, - next: next, - parser: opts.Parser, - planner: opts.Planner, - graphqlHandler: opts.GraphQLHandler, - metrics: opts.Metrics, - accessController: opts.AccessController, - logger: opts.Logger, - stats: opts.Stats, - readTimeout: opts.ReadTimeout, + ctx: ctx, + next: next, + operationProcessor: opts.OperationProcessor, + planner: opts.Planner, + graphqlHandler: opts.GraphQLHandler, + metrics: opts.Metrics, + accessController: opts.AccessController, + logger: opts.Logger, + stats: opts.Stats, + readTimeout: opts.ReadTimeout, } handler.handlerPool = pond.New( 64, @@ -138,14 +138,14 @@ func (c *wsConnectionWrapper) Close() error { } type WebsocketHandler struct { - ctx context.Context - next http.Handler - parser *OperationParser - planner *OperationPlanner - graphqlHandler *GraphQLHandler - metrics RouterMetrics - accessController *AccessController - logger *zap.Logger + ctx context.Context + next http.Handler + operationProcessor *OperationProcessor + planner *OperationPlanner + graphqlHandler *GraphQLHandler + metrics RouterMetrics + accessController *AccessController + logger *zap.Logger epoll epoller.Poller connections map[int]*WebSocketConnectionHandler @@ -215,19 +215,19 @@ func (h *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } handler := NewWebsocketConnectionHandler(h.ctx, WebSocketConnectionHandlerOptions{ - Parser: h.parser, - Planner: h.planner, - GraphQLHandler: h.graphqlHandler, - Metrics: h.metrics, - ResponseWriter: w, - Request: r, - Connection: conn, - Protocol: protocol, - Logger: h.logger, - Stats: h.stats, - ConnectionID: h.connectionIDs.Inc(), - ClientInfo: clientInfo, - InitRequestID: requestID, + OperationProcessor: h.operationProcessor, + Planner: h.planner, + GraphQLHandler: h.graphqlHandler, + Metrics: h.metrics, + ResponseWriter: w, + Request: r, + Connection: conn, + Protocol: protocol, + Logger: h.logger, + Stats: h.stats, + ConnectionID: h.connectionIDs.Inc(), + ClientInfo: clientInfo, + InitRequestID: requestID, }) err = handler.Initialize() if err != nil { @@ -484,35 +484,35 @@ type graphqlError struct { } type WebSocketConnectionHandlerOptions struct { - Parser *OperationParser - Planner *OperationPlanner - GraphQLHandler *GraphQLHandler - Metrics RouterMetrics - ResponseWriter http.ResponseWriter - Request *http.Request - Connection *wsConnectionWrapper - Protocol wsproto.Proto - Logger *zap.Logger - Stats WebSocketsStatistics - ConnectionID int64 - RequestContext context.Context - ClientInfo *ClientInfo - InitRequestID string + OperationProcessor *OperationProcessor + Planner *OperationPlanner + GraphQLHandler *GraphQLHandler + Metrics RouterMetrics + ResponseWriter http.ResponseWriter + Request *http.Request + Connection *wsConnectionWrapper + Protocol wsproto.Proto + Logger *zap.Logger + Stats WebSocketsStatistics + ConnectionID int64 + RequestContext context.Context + ClientInfo *ClientInfo + InitRequestID string } type WebSocketConnectionHandler struct { - ctx context.Context - parser *OperationParser - planner *OperationPlanner - graphqlHandler *GraphQLHandler - metrics RouterMetrics - w http.ResponseWriter - r *http.Request - conn *wsConnectionWrapper - protocol wsproto.Proto - initialPayload json.RawMessage - clientInfo *ClientInfo - logger *zap.Logger + ctx context.Context + operationProcessor *OperationProcessor + planner *OperationPlanner + graphqlHandler *GraphQLHandler + metrics RouterMetrics + w http.ResponseWriter + r *http.Request + conn *wsConnectionWrapper + protocol wsproto.Proto + initialPayload json.RawMessage + clientInfo *ClientInfo + logger *zap.Logger initRequestID string connectionID int64 @@ -523,20 +523,20 @@ type WebSocketConnectionHandler struct { func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnectionHandlerOptions) *WebSocketConnectionHandler { return &WebSocketConnectionHandler{ - ctx: ctx, - parser: opts.Parser, - planner: opts.Planner, - graphqlHandler: opts.GraphQLHandler, - metrics: opts.Metrics, - w: opts.ResponseWriter, - r: opts.Request, - conn: opts.Connection, - protocol: opts.Protocol, - logger: opts.Logger, - connectionID: opts.ConnectionID, - stats: opts.Stats, - clientInfo: opts.ClientInfo, - initRequestID: opts.InitRequestID, + ctx: ctx, + operationProcessor: opts.OperationProcessor, + planner: opts.Planner, + graphqlHandler: opts.GraphQLHandler, + metrics: opts.Metrics, + w: opts.ResponseWriter, + r: opts.Request, + conn: opts.Connection, + protocol: opts.Protocol, + logger: opts.Logger, + connectionID: opts.ConnectionID, + stats: opts.Stats, + clientInfo: opts.ClientInfo, + initRequestID: opts.InitRequestID, } } @@ -561,16 +561,30 @@ func (h *WebSocketConnectionHandler) writeErrorMessage(operationID string, err e } func (h *WebSocketConnectionHandler) parseAndPlan(payload []byte) (*ParsedOperation, *operationContext, error) { - operation, err := h.parser.Parse(h.ctx, h.clientInfo, payload, h.logger) + operationKit, err := h.operationProcessor.NewKit(payload) + defer operationKit.Free() if err != nil { return nil, nil, err } - opContext, err := h.planner.Plan(operation, h.clientInfo, OperationProtocolWS, ParseRequestTraceOptions(h.r)) + + if err := operationKit.Parse(h.ctx, h.clientInfo, h.logger); err != nil { + return nil, nil, err + } + + if err := operationKit.Normalize(); err != nil { + return nil, nil, err + } + + if err := operationKit.Validate(); err != nil { + return nil, nil, err + } + + opContext, err := h.planner.Plan(operationKit.parsedOperation, h.clientInfo, OperationProtocolWS, ParseRequestTraceOptions(h.r)) if err != nil { - return operation, nil, err + return operationKit.parsedOperation, nil, err } opContext.initialPayload = h.initialPayload - return operation, opContext, nil + return operationKit.parsedOperation, opContext, nil } func (h *WebSocketConnectionHandler) executeSubscription(msg *wsproto.Message, id resolve.SubscriptionIdentifier) { @@ -579,9 +593,9 @@ func (h *WebSocketConnectionHandler) executeSubscription(msg *wsproto.Message, i _, operationCtx, err := h.parseAndPlan(msg.Payload) if err != nil { - werr := h.writeErrorMessage(msg.ID, err) - if werr != nil { - h.logger.Warn("writing error message", zap.Error(werr)) + wErr := h.writeErrorMessage(msg.ID, err) + if wErr != nil { + h.logger.Warn("writing error message", zap.Error(wErr)) } return } diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index b916edeb33..d79e68a303 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -29,6 +29,10 @@ type TracingExporterConfig struct { ExportTimeout time.Duration `yaml:"export_timeout" default:"30s" validate:"required,min=5s,max=120s"` } +type TracingGlobalFeatures struct { + ExportGraphQLVariables bool `yaml:"export_graphql_variables" default:"true" envconfig:"TRACING_EXPORT_GRAPHQL_VARIABLES"` +} + type TracingExporter struct { Disabled bool `yaml:"disabled"` Exporter otelconfig.Exporter `yaml:"exporter" validate:"oneof=http grpc"` @@ -39,10 +43,11 @@ type TracingExporter struct { } type Tracing struct { - Enabled bool `yaml:"enabled" default:"true" envconfig:"TRACING_ENABLED"` - SamplingRate float64 `yaml:"sampling_rate" default:"1" validate:"required,min=0,max=1" envconfig:"TRACING_SAMPLING_RATE"` - Exporters []TracingExporter `yaml:"exporters"` - Propagation PropagationConfig `yaml:"propagation"` + Enabled bool `yaml:"enabled" default:"true" envconfig:"TRACING_ENABLED"` + SamplingRate float64 `yaml:"sampling_rate" default:"1" validate:"required,min=0,max=1" envconfig:"TRACING_SAMPLING_RATE"` + Exporters []TracingExporter `yaml:"exporters"` + Propagation PropagationConfig `yaml:"propagation"` + TracingGlobalFeatures `yaml:",inline"` } type PropagationConfig struct { diff --git a/router/pkg/otel/attributes.go b/router/pkg/otel/attributes.go index fb2d5cc6da..21e277f38c 100644 --- a/router/pkg/otel/attributes.go +++ b/router/pkg/otel/attributes.go @@ -7,6 +7,7 @@ const ( WgOperationType = attribute.Key("wg.operation.type") WgOperationContent = attribute.Key("wg.operation.content") WgOperationHash = attribute.Key("wg.operation.hash") + WgOperationVariables = attribute.Key("wg.operation.variables") WgOperationProtocol = attribute.Key("wg.operation.protocol") WgComponentName = attribute.Key("wg.component.name") WgClientName = attribute.Key("wg.client.name") @@ -20,6 +21,7 @@ const ( WgOperationPersistedID = attribute.Key("wg.operation.persisted_id") WgEnginePlanCacheHit = attribute.Key("wg.engine.plan_cache_hit") WgEngineRequestTracingEnabled = attribute.Key("wg.engine.request_tracing_enabled") + WgRouterRootSpan = attribute.Key("wg.router.root_span") ) var ( diff --git a/router/pkg/trace/config.go b/router/pkg/trace/config.go index 9610558082..e6b3347bfc 100644 --- a/router/pkg/trace/config.go +++ b/router/pkg/trace/config.go @@ -38,6 +38,10 @@ type Exporter struct { HTTPPath string } +type ExportGraphQLVariables struct { + Enabled bool +} + // Config represents the configuration for the agent. type Config struct { Enabled bool @@ -46,12 +50,14 @@ type Config struct { // Version represents the service version for tracing. The default value is dev. Version string // Sampler represents the sampler for tracing. The default value is 1. - Sampler float64 - Exporters []*Exporter - Propagators []Propagator + Sampler float64 + // ExportGraphQLVariables defines if and how GraphQL variables should be exported as span attributes. + ExportGraphQLVariables ExportGraphQLVariables + Exporters []*Exporter + Propagators []Propagator } -func GetDefaultExporter(cfg *Config) *Exporter { +func DefaultExporter(cfg *Config) *Exporter { for _, exporter := range cfg.Exporters { if exporter.Disabled { continue @@ -78,6 +84,9 @@ func DefaultConfig(serviceVersion string) *Config { Name: ServerName, Version: serviceVersion, Sampler: 1, + ExportGraphQLVariables: ExportGraphQLVariables{ + Enabled: true, + }, Exporters: []*Exporter{ { Disabled: false, diff --git a/studio/src/components/analytics/trace.tsx b/studio/src/components/analytics/trace.tsx index ba4a07e540..b09d219a03 100644 --- a/studio/src/components/analytics/trace.tsx +++ b/studio/src/components/analytics/trace.tsx @@ -146,24 +146,6 @@ function Node({ }); }; - let additionalTooltipContent = ""; - - if (span.scopeName.startsWith("wundergraph/cosmo/router/")) { - if (span.spanName === "Authenticate") { - additionalTooltipContent = - "Authenticates the request against the configured authentication provider."; - } else if (span.spanName === "Operation - Parse and Validate") { - additionalTooltipContent = - "This is the first step in the query execution. It parses the variables, query and validates it against the schema."; - } else if (span.spanName === "Operation - Planning") { - additionalTooltipContent = - "Describes the process of building the optimized query plan for a given GraphQL query. This includes normalization, validating it against the schema."; - } else if (span.spanName === "Operation - Execution") { - additionalTooltipContent = - "Describes the process of executing the query plan for an operation. This includes fetching data from the subgraphs, aggregating the data and returning it to the client."; - } - } - return (