diff --git a/v2/pkg/astparser/parser.go b/v2/pkg/astparser/parser.go index 384c9db34..a1d9e6229 100644 --- a/v2/pkg/astparser/parser.go +++ b/v2/pkg/astparser/parser.go @@ -29,7 +29,7 @@ func ParseGraphqlDocumentString(input string) (ast.Document, operationreport.Rep // Instead create a parser as well as AST objects and re-use them. func ParseGraphqlDocumentBytes(input []byte) (ast.Document, operationreport.Report) { parser := NewParser() - doc := *ast.NewDocument() + doc := *ast.NewSmallDocument() doc.Input.ResetInputBytes(input) report := operationreport.Report{} parser.Parse(&doc, &report) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 4ab73e273..85f25d008 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -101,7 +101,7 @@ func (p *Planner) UpstreamSchema(dataSourceConfig plan.DataSourceConfiguration) panic(err) } - definition := ast.NewDocument() + definition := ast.NewSmallDocument() definitionParser := astparser.NewParser() report := &operationreport.Report{} @@ -738,7 +738,7 @@ func (p *Planner) EnterArgument(_ int) { func (p *Planner) EnterDocument(_, _ *ast.Document) { if p.upstreamOperation == nil { - p.upstreamOperation = ast.NewDocument() + p.upstreamOperation = ast.NewSmallDocument() } else { p.upstreamOperation.Reset() } @@ -763,7 +763,7 @@ func (p *Planner) EnterDocument(_, _ *ast.Document) { p.upstreamDefinition = nil if p.config.UpstreamSchema != "" { - p.upstreamDefinition = ast.NewDocument() + p.upstreamDefinition = ast.NewSmallDocument() p.upstreamDefinition.Input.ResetInputString(p.config.UpstreamSchema) parser := astparser.NewParser() var report operationreport.Report @@ -1354,8 +1354,8 @@ func (p *Planner) printOperation() []byte { rawQuery := buf.Bytes() // create empty operation and definition documents - operation := ast.NewDocument() - definition := ast.NewDocument() + operation := ast.NewSmallDocument() + definition := ast.NewSmallDocument() report := &operationreport.Report{} operationParser := astparser.NewParser() definitionParser := astparser.NewParser() diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index dceea9968..f9bf92b1c 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -3,27 +3,19 @@ package resolve import ( - "bytes" "context" - "fmt" "io" - "strconv" "sync" - "github.com/buger/jsonparser" "github.com/pkg/errors" - "github.com/tidwall/gjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "golang.org/x/sync/singleflight" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/fastbuffer" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) type Resolver struct { ctx context.Context enableSingleFlightLoader bool - sf *singleflight.Group + sf *Group toolPool sync.Pool } @@ -37,7 +29,7 @@ func New(ctx context.Context, enableSingleFlightLoader bool) *Resolver { return &Resolver{ ctx: ctx, enableSingleFlightLoader: enableSingleFlightLoader, - sf: &singleflight.Group{}, + sf: &Group{}, toolPool: sync.Pool{ New: func() interface{} { return &tools{ @@ -62,40 +54,6 @@ func (r *Resolver) putTools(t *tools) { r.toolPool.Put(t) } -func (r *Resolver) resolveNode(ctx *Context, node Node, data []byte, bufPair *BufPair) (err error) { - switch n := node.(type) { - case *Object: - return r.resolveObject(ctx, n, data, bufPair) - case *Array: - return r.resolveArray(ctx, n, data, bufPair) - case *Null: - r.resolveNull(bufPair.Data) - return - case *String: - return r.resolveString(ctx, n, data, bufPair) - case *Boolean: - return r.resolveBoolean(ctx, n, data, bufPair) - case *Integer: - return r.resolveInteger(ctx, n, data, bufPair) - case *Float: - return r.resolveFloat(ctx, n, data, bufPair) - case *BigInt: - return r.resolveBigInt(ctx, n, data, bufPair) - case *Scalar: - return r.resolveScalar(ctx, n, data, bufPair) - case *EmptyObject: - r.resolveEmptyObject(bufPair.Data) - return - case *EmptyArray: - r.resolveEmptyArray(bufPair.Data) - return - case *CustomNode: - return r.resolveCustom(ctx, n, data, bufPair) - default: - return - } -} - func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, data []byte, writer io.Writer) (err error) { if response.Info == nil { @@ -179,492 +137,3 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } } } - -func (r *Resolver) resolveEmptyArray(b *fastbuffer.FastBuffer) { - b.WriteBytes(lBrack) - b.WriteBytes(rBrack) -} - -func (r *Resolver) resolveEmptyObject(b *fastbuffer.FastBuffer) { - b.WriteBytes(lBrace) - b.WriteBytes(rBrace) -} - -func (r *Resolver) resolveArray(ctx *Context, array *Array, data []byte, arrayBuf *BufPair) (parentErr error) { - if len(array.Path) != 0 { - data, _, _, _ = jsonparser.Get(data, array.Path...) - } - - if bytes.Equal(data, emptyArray) { - r.resolveEmptyArray(arrayBuf.Data) - return - } - - itemBuf := r.getBufPair() - defer r.freeBufPair(itemBuf) - - reset := arrayBuf.Data.Len() - i := 0 - hasData := false - resolveArrayAsNull := false - - _, err := jsonparser.ArrayEach(data, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { - if parentErr != nil { - return - } - if err == nil && dataType == jsonparser.String { - value = data[offset-2 : offset+len(value)] // add quotes to string values - } - itemBuf.Reset() - ctx.addIntegerPathElement(i) - err = r.resolveNode(ctx, array.Item, value, itemBuf) - ctx.removeLastPathElement() - if err != nil { - if errors.Is(err, errNonNullableFieldValueIsNull) && array.Nullable { - resolveArrayAsNull = true - return - } - parentErr = err - return - } - if !hasData { - arrayBuf.Data.WriteBytes(lBrack) - } - r.MergeBufPairs(itemBuf, arrayBuf, hasData) - hasData = true - i++ - }) - if err != nil { - if array.Nullable { - arrayBuf.Data.Reslice(0, reset) - r.resolveNull(arrayBuf.Data) - return nil - } - return errNonNullableFieldValueIsNull - } - if resolveArrayAsNull { - arrayBuf.Data.Reslice(0, reset) - r.resolveNull(arrayBuf.Data) - return nil - } - if hasData { - arrayBuf.Data.WriteBytes(rBrack) - } - return -} - -func (r *Resolver) resolveArraySynchronous(ctx *Context, array *Array, arrayItems *[][]byte, arrayBuf *BufPair) (err error) { - arrayBuf.Data.WriteBytes(lBrack) - start := arrayBuf.Data.Len() - - itemBuf := r.getBufPair() - defer r.freeBufPair(itemBuf) - - for i := range *arrayItems { - ctx.addIntegerPathElement(i) - if arrayBuf.Data.Len() > start { - arrayBuf.Data.WriteBytes(comma) - } - err = r.resolveNode(ctx, array.Item, (*arrayItems)[i], arrayBuf) - ctx.removeLastPathElement() - if err != nil { - if errors.Is(err, errNonNullableFieldValueIsNull) && array.Nullable { - arrayBuf.Data.Reset() - r.resolveNull(arrayBuf.Data) - return nil - } - if errors.Is(err, errTypeNameSkipped) { - err = nil - continue - } - return - } - } - - arrayBuf.Data.WriteBytes(rBrack) - return -} - -func (r *Resolver) exportField(ctx *Context, export *FieldExport, value []byte) { - if export == nil { - return - } - if export.AsString { - value = append(quote, append(value, quote...)...) - } - ctx.Variables, _ = jsonparser.Set(ctx.Variables, value, export.Path...) -} - -func (r *Resolver) resolveInteger(ctx *Context, integer *Integer, data []byte, integerBuf *BufPair) error { - value, dataType, _, err := jsonparser.Get(data, integer.Path...) - if err != nil || dataType != jsonparser.Number { - if !integer.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(integerBuf.Data) - return nil - } - integerBuf.Data.WriteBytes(value) - r.exportField(ctx, integer.Export, value) - return nil -} - -func (r *Resolver) resolveFloat(ctx *Context, floatValue *Float, data []byte, floatBuf *BufPair) error { - value, dataType, _, err := jsonparser.Get(data, floatValue.Path...) - if err != nil || dataType != jsonparser.Number { - if !floatValue.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(floatBuf.Data) - return nil - } - floatBuf.Data.WriteBytes(value) - r.exportField(ctx, floatValue.Export, value) - return nil -} - -func (r *Resolver) resolveBigInt(ctx *Context, bigIntValue *BigInt, data []byte, bigIntBuf *BufPair) error { - value, valueType, _, err := jsonparser.Get(data, bigIntValue.Path...) - switch { - case err != nil, valueType == jsonparser.Null: - if !bigIntValue.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(bigIntBuf.Data) - return nil - case valueType == jsonparser.Number: - bigIntBuf.Data.WriteBytes(value) - case valueType == jsonparser.String: - bigIntBuf.Data.WriteBytes(quote) - bigIntBuf.Data.WriteBytes(value) - bigIntBuf.Data.WriteBytes(quote) - default: - return fmt.Errorf("invalid value type '%s' for path %s, expecting number or string, got: %v", valueType, string(ctx.path()), string(value)) - - } - r.exportField(ctx, bigIntValue.Export, value) - return nil -} - -func (r *Resolver) resolveScalar(ctx *Context, scalarValue *Scalar, data []byte, scalarBuf *BufPair) error { - value, valueType, _, err := jsonparser.Get(data, scalarValue.Path...) - switch { - case err != nil, valueType == jsonparser.Null: - if !scalarValue.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(scalarBuf.Data) - return nil - case valueType == jsonparser.String: - scalarBuf.Data.WriteBytes(quote) - scalarBuf.Data.WriteBytes(value) - scalarBuf.Data.WriteBytes(quote) - default: - scalarBuf.Data.WriteBytes(value) - } - r.exportField(ctx, scalarValue.Export, value) - return nil -} - -func (r *Resolver) resolveCustom(ctx *Context, customValue *CustomNode, data []byte, customBuf *BufPair) error { - value, dataType, _, _ := jsonparser.Get(data, customValue.Path...) - if dataType == jsonparser.Null && !customValue.Nullable { - return errNonNullableFieldValueIsNull - } - resolvedValue, err := customValue.Resolve(value) - if err != nil { - return fmt.Errorf("failed to resolve value type %s for path %s via custom resolver", dataType, string(ctx.path())) - } - customBuf.Data.WriteBytes(resolvedValue) - return nil -} - -func (r *Resolver) resolveBoolean(ctx *Context, boolean *Boolean, data []byte, booleanBuf *BufPair) error { - value, valueType, _, err := jsonparser.Get(data, boolean.Path...) - if err != nil || valueType != jsonparser.Boolean { - if !boolean.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(booleanBuf.Data) - return nil - } - booleanBuf.Data.WriteBytes(value) - r.exportField(ctx, boolean.Export, value) - return nil -} - -func (r *Resolver) resolveString(ctx *Context, str *String, data []byte, stringBuf *BufPair) error { - var ( - value []byte - valueType jsonparser.ValueType - err error - ) - - value, valueType, _, err = jsonparser.Get(data, str.Path...) - if err != nil || valueType != jsonparser.String { - if err == nil && str.UnescapeResponseJson { - switch valueType { - case jsonparser.Object, jsonparser.Array, jsonparser.Boolean, jsonparser.Number, jsonparser.Null: - stringBuf.Data.WriteBytes(value) - return nil - } - } - if value != nil && valueType != jsonparser.Null { - return fmt.Errorf("invalid value type '%s' for path %s, expecting string, got: %v. You can fix this by configuring this field as Int/Float/JSON Scalar", valueType, string(ctx.path()), string(value)) - } - if !str.Nullable { - return errNonNullableFieldValueIsNull - } - r.resolveNull(stringBuf.Data) - return nil - } - - if value == nil && !str.Nullable { - return errNonNullableFieldValueIsNull - } - - if str.UnescapeResponseJson { - value = bytes.ReplaceAll(value, []byte(`\"`), []byte(`"`)) - - // Do not modify values which was strings - // When the original value from upstream response was a plain string value `"hello"`, `"true"`, `"1"`, `"2.0"`, - // after getting it via jsonparser.Get we will get unquoted values `hello`, `true`, `1`, `2.0` - // which is not string anymore, so we need to quote it again - if !(bytes.ContainsAny(value, `{}[]`) && gjson.ValidBytes(value)) { - // wrap value in quotes to make it valid json - value = append(quote, append(value, quote...)...) - } - - stringBuf.Data.WriteBytes(value) - r.exportField(ctx, str.Export, value) - return nil - } - - value = r.renameTypeName(ctx, str, value) - - stringBuf.Data.WriteBytes(quote) - stringBuf.Data.WriteBytes(value) - stringBuf.Data.WriteBytes(quote) - r.exportField(ctx, str.Export, value) - return nil -} - -func (r *Resolver) renameTypeName(ctx *Context, str *String, typeName []byte) []byte { - if !str.IsTypeName { - return typeName - } - for i := range ctx.RenameTypeNames { - if bytes.Equal(ctx.RenameTypeNames[i].From, typeName) { - return ctx.RenameTypeNames[i].To - } - } - return typeName -} - -func (r *Resolver) resolveNull(b *fastbuffer.FastBuffer) { - b.WriteBytes(null) -} - -func (r *Resolver) addResolveError(ctx *Context, objectBuf *BufPair) { - locations, path := pool.BytesBuffer.Get(), pool.BytesBuffer.Get() - defer pool.BytesBuffer.Put(locations) - defer pool.BytesBuffer.Put(path) - - var pathBytes []byte - - locations.Write(lBrack) - locations.Write(lBrace) - locations.Write(quote) - locations.Write(literalLine) - locations.Write(quote) - locations.Write(colon) - locations.Write([]byte(strconv.Itoa(int(ctx.position.Line)))) - locations.Write(comma) - locations.Write(quote) - locations.Write(literalColumn) - locations.Write(quote) - locations.Write(colon) - locations.Write([]byte(strconv.Itoa(int(ctx.position.Column)))) - locations.Write(rBrace) - locations.Write(rBrack) - - if len(ctx.pathElements) > 0 { - path.Write(lBrack) - path.Write(quote) - path.Write(bytes.Join(ctx.pathElements, quotedComma)) - path.Write(quote) - path.Write(rBrack) - - pathBytes = path.Bytes() - } - - objectBuf.WriteErr(unableToResolveMsg, locations.Bytes(), pathBytes, nil) -} - -func (r *Resolver) resolveObject(ctx *Context, object *Object, data []byte, parentBuf *BufPair) (err error) { - if len(object.Path) != 0 { - data, _, _, _ = jsonparser.Get(data, object.Path...) - if len(data) == 0 || bytes.Equal(data, null) { - if object.Nullable { - r.resolveNull(parentBuf.Data) - return - } - - r.addResolveError(ctx, parentBuf) - return errNonNullableFieldValueIsNull - } - } - - if object.UnescapeResponseJson { - data = bytes.ReplaceAll(data, []byte(`\"`), []byte(`"`)) - } - - fieldBuf := r.getBufPair() - defer r.freeBufPair(fieldBuf) - - typeNameSkip := false - first := true - skipCount := 0 - for i := range object.Fields { - if object.Fields[i].SkipDirectiveDefined { - skip, err := jsonparser.GetBoolean(ctx.Variables, object.Fields[i].SkipVariableName) - if err == nil && skip { - skipCount++ - continue - } - } - - if object.Fields[i].IncludeDirectiveDefined { - include, err := jsonparser.GetBoolean(ctx.Variables, object.Fields[i].IncludeVariableName) - if err != nil || !include { - skipCount++ - continue - } - } - - if object.Fields[i].OnTypeNames != nil { - typeName, _, _, _ := jsonparser.Get(data, "__typename") - hasMatch := false - for _, onTypeName := range object.Fields[i].OnTypeNames { - if bytes.Equal(typeName, onTypeName) { - hasMatch = true - break - } - } - if !hasMatch { - typeNameSkip = true - continue - } - } - - if first { - fieldBuf.Data.WriteBytes(lBrace) - first = false - } else { - fieldBuf.Data.WriteBytes(comma) - } - fieldBuf.Data.WriteBytes(quote) - fieldBuf.Data.WriteBytes(object.Fields[i].Name) - fieldBuf.Data.WriteBytes(quote) - fieldBuf.Data.WriteBytes(colon) - ctx.addPathElement(object.Fields[i].Name) - ctx.setPosition(object.Fields[i].Position) - err = r.resolveNode(ctx, object.Fields[i].Value, data, fieldBuf) - ctx.removeLastPathElement() - if err != nil { - if errors.Is(err, errTypeNameSkipped) { - fieldBuf.Data.Reset() - r.resolveEmptyObject(parentBuf.Data) - return nil - } - if errors.Is(err, errNonNullableFieldValueIsNull) { - fieldBuf.Data.Reset() - r.MergeBufPairErrors(fieldBuf, parentBuf) - - if object.Nullable { - r.resolveNull(parentBuf.Data) - return nil - } - - // if field is of object type than we should not add resolve error here - if _, ok := object.Fields[i].Value.(*Object); !ok { - r.addResolveError(ctx, parentBuf) - } - } - - return - } - r.MergeBufPairs(fieldBuf, parentBuf, false) - } - allSkipped := len(object.Fields) != 0 && len(object.Fields) == skipCount - if allSkipped { - // return empty object if all fields have been skipped - r.resolveEmptyObject(parentBuf.Data) - return - } - if first { - if typeNameSkip { - r.resolveEmptyObject(parentBuf.Data) - return - } - if !object.Nullable { - r.addResolveError(ctx, parentBuf) - return errNonNullableFieldValueIsNull - } - r.resolveNull(parentBuf.Data) - return - } - parentBuf.Data.WriteBytes(rBrace) - return -} - -func (r *Resolver) MergeBufPairs(from, to *BufPair, prefixDataWithComma bool) { - r.MergeBufPairData(from, to, prefixDataWithComma) - r.MergeBufPairErrors(from, to) -} - -func (r *Resolver) MergeBufPairData(from, to *BufPair, prefixDataWithComma bool) { - if !from.HasData() { - return - } - if prefixDataWithComma { - to.Data.WriteBytes(comma) - } - to.Data.WriteBytes(from.Data.Bytes()) - from.Data.Reset() -} - -func (r *Resolver) MergeBufPairErrors(from, to *BufPair) { - if !from.HasErrors() { - return - } - if to.HasErrors() { - to.Errors.WriteBytes(comma) - } - to.Errors.WriteBytes(from.Errors.Bytes()) - from.Errors.Reset() -} - -func (r *Resolver) getBufPair() *BufPair { - return nil -} - -func (r *Resolver) freeBufPair(pair *BufPair) {} - -func (r *Resolver) getBufPairSlice() *[]*BufPair { - return nil -} - -func (r *Resolver) freeBufPairSlice(slice *[]*BufPair) {} - -func (r *Resolver) getErrChan() chan error { - return nil -} - -func (r *Resolver) freeErrChan(ch chan error) {} - -func (r *Resolver) getWaitGroup() *sync.WaitGroup { - return nil -} - -func (r *Resolver) freeWaitGroup(wg *sync.WaitGroup) {} diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 74736841f..8c2a8a46c 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -1792,6 +1792,14 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, OnTypeNames: [][]byte{[]byte("ConcreteOne")}, }, + { + Name: []byte("__typename"), + Value: &String{ + Nullable: true, + Path: []string{"__typename"}, + }, + OnTypeNames: [][]byte{[]byte("ConcreteOne")}, + }, }, }, }, @@ -1806,7 +1814,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("interface response with matching type", testFn(false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return obj(`{"thing":{"id":"1","abstractThing":{"__typename":"ConcreteOne","name":"foo"}}}`), Context{ctx: context.Background()}, - `{"data":{"thing":{"id":"1","abstractThing":{"name":"foo"}}}}` + `{"data":{"thing":{"id":"1","abstractThing":{"name":"foo","__typename":"ConcreteOne"}}}}` })) t.Run("interface response with not matching type", testFn(false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { diff --git a/v2/pkg/engine/resolve/v2load.go b/v2/pkg/engine/resolve/v2load.go index 666588edb..c312e8f9a 100644 --- a/v2/pkg/engine/resolve/v2load.go +++ b/v2/pkg/engine/resolve/v2load.go @@ -5,14 +5,14 @@ import ( "context" "fmt" "io" - "unsafe" + "runtime" + "runtime/debug" + "sync" "github.com/pkg/errors" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/singleflight" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" + "golang.org/x/sync/errgroup" ) type V2Loader struct { @@ -20,7 +20,7 @@ type V2Loader struct { dataRoot int errorsRoot int ctx *Context - sf *singleflight.Group + sf *Group enableSingleFlight bool path []string } @@ -616,8 +616,14 @@ func (l *V2Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight b if !l.enableSingleFlight || disallowSingleFlight { return source.Load(ctx, input, out) } - key := *(*string)(unsafe.Pointer(&input)) - maybeSharedBuf, err, _ := l.sf.Do(key, func() (interface{}, error) { + keyGen := pool.Hash64.Get() + defer pool.Hash64.Put(keyGen) + _, err := keyGen.Write(input) + if err != nil { + return errors.WithStack(err) + } + key := keyGen.Sum64() + data, err, _ := l.sf.Do(key, func() ([]byte, error) { singleBuffer := pool.BytesBuffer.Get() defer pool.BytesBuffer.Put(singleBuffer) err := source.Load(ctx, input, singleBuffer) @@ -632,7 +638,193 @@ func (l *V2Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight b if err != nil { return errors.WithStack(err) } - sharedBuf := maybeSharedBuf.([]byte) - _, err = out.Write(sharedBuf) + _, err = out.Write(data) return errors.WithStack(err) } + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val []byte + err error + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result +} + +type Group struct { + mu sync.Mutex // protects m + m map[uint64]*call // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result struct { + Val []byte + Err error + Shared bool +} + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value []byte + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v []byte) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group) Do(key uint64, fn func() ([]byte, error)) (v []byte, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[uint64]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group) DoChan(key uint64, fn func() ([]byte, error)) <-chan Result { + ch := make(chan Result, 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[uint64]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call{chans: []chan<- Result{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group) doCall(c *call, key uint64, fn func() ([]byte, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r.([]byte)) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group) Forget(key uint64) { + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() +} diff --git a/v2/pkg/federation/schema.go b/v2/pkg/federation/schema.go index a7669e690..e6f540e8a 100644 --- a/v2/pkg/federation/schema.go +++ b/v2/pkg/federation/schema.go @@ -44,7 +44,7 @@ func (s *schemaBuilder) buildFederationSchema(baseSchema, serviceSDL string) (st } func (s *schemaBuilder) extendQueryTypeWithFederationFields(schema string, hasEntities bool) string { - doc := ast.NewDocument() + doc := ast.NewSmallDocument() doc.Input.ResetInputString(schema) parser := astparser.NewParser() report := &operationreport.Report{} @@ -120,7 +120,7 @@ func (s *schemaBuilder) extendQueryType(doc *ast.Document, ref int, hasEntities // _service: _Service! func (s *schemaBuilder) entityUnionTypes(serviceSDL string) []string { - doc := ast.NewDocument() + doc := ast.NewSmallDocument() doc.Input.ResetInputString(serviceSDL) parser := astparser.NewParser() report := &operationreport.Report{}