From acdaf47598762aa20712abd4eb38250bb10cfd33 Mon Sep 17 00:00:00 2001 From: Jens Neuse Date: Mon, 2 Dec 2024 12:10:22 +0100 Subject: [PATCH] feat: execute subscription writes on main goroutine in synchronous resolve subscriptions --- v2/pkg/engine/resolve/resolve.go | 61 +++++++++++++++++++-------- v2/pkg/engine/resolve/resolve_test.go | 5 +++ 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 57207633b..7db2ac172 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -280,6 +280,10 @@ type sub struct { id SubscriptionIdentifier completed chan struct{} lastWrite time.Time + // executor is an optional argument that allows us to "schedule" the execution of an update on another thread + // e.g. if we're using SSE/Multipart Fetch, we can run the execution on the goroutine of the http request + // this ensures that ctx cancellation works properly when a client disconnects + executor chan func() } func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput []byte) { @@ -495,6 +499,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) id: add.id, completed: add.completed, lastWrite: time.Now(), + executor: add.executor, } if add.ctx.ExecutionOptions.SendHeartbeat { r.heartbeatSubscriptions[add.ctx] = s @@ -687,6 +692,9 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { trig.inFlight = wg for c, s := range trig.subscriptions { c, s := c, s + if err := c.ctx.Err(); err != nil { + continue // no need to schedule an event update when the client already disconnected + } skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf) if err != nil { r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer) @@ -695,12 +703,22 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { if skip { continue } - wg.Add(1) - go func() { - defer wg.Done() + fn := func() { r.executeSubscriptionUpdate(c, s, data) - }() + } + go func(fn func()) { + defer wg.Done() + if s.executor != nil { + select { + case <-r.ctx.Done(): + case <-c.ctx.Done(): + case s.executor <- fn: + } + } else { + fn() + } + }(fn) } } @@ -825,6 +843,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) } completed := make(chan struct{}) + executor := make(chan func()) select { case <-r.ctx.Done(): return r.ctx.Err() @@ -838,25 +857,32 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ writer: writer, id: id, completed: completed, + executor: executor, }, }: } - select { - case <-r.ctx.Done(): - // the resolver ctx was canceled - // this will trigger the shutdown of the trigger (on another goroutine) - // as such, we need to wait for the trigger to be shutdown - // otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer - // as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done +Loop: // execute fn on the main thread of the incoming request until ctx is done + for { select { - case <-completed: - return r.ctx.Err() - case <-time.After(time.Second * 5): - return r.ctx.Err() + case <-r.ctx.Done(): + // the resolver ctx was canceled + // this will trigger the shutdown of the trigger (on another goroutine) + // as such, we need to wait for the trigger to be shutdown + // otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer + // as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done + select { + case <-completed: + return r.ctx.Err() + case <-time.After(time.Second * 5): + return r.ctx.Err() + case <-ctx.Context().Done(): + return ctx.Context().Err() + } case <-ctx.Context().Done(): - return ctx.Context().Err() + break Loop + case fn := <-executor: + fn() } - case <-ctx.Context().Done(): } if r.options.Debug { fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID) @@ -1008,6 +1034,7 @@ type addSubscription struct { writer SubscriptionResponseWriter id SubscriptionIdentifier completed chan struct{} + executor chan func() } type subscriptionEventKind int diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 2b94fe811..a198e13a1 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -5201,6 +5201,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { resolver := newResolver(c) ctx := &Context{ + ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":1}`)), } @@ -5296,6 +5297,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { resolver := newResolver(c) ctx := &Context{ + ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"id":2}`)), } @@ -5389,6 +5391,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { resolver := newResolver(c) ctx := &Context{ + ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"ids":[1,2]}`)), } @@ -5487,6 +5490,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { resolver := newResolver(c) ctx := &Context{ + ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"ids":["2","3"]}`)), } @@ -5595,6 +5599,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) { resolver := newResolver(c) ctx := &Context{ + ctx: context.Background(), Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)), }