From f856ab87fbe45f15f2b03d770b8f46324f2ae07e Mon Sep 17 00:00:00 2001 From: Janos Guljas Date: Thu, 27 May 2021 23:32:58 +0200 Subject: [PATCH] Pass a context to the calling function for cancellation --- singleflight.go | 37 +++++++--- singleflight_test.go | 157 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 174 insertions(+), 20 deletions(-) diff --git a/singleflight.go b/singleflight.go index 0e7d31c..45fe9a3 100644 --- a/singleflight.go +++ b/singleflight.go @@ -20,15 +20,18 @@ type Group struct { mu sync.Mutex // protects calls } -// Do executes and returns the results of the given function, making -// sure that only one execution is in-flight for a given key at a -// time. If a duplicate comes in, the duplicate caller waits for the -// original to complete and receives the same results. -// Passed context terminates the execution of Do function, not the passed -// function fn. If there are multiple callers, context passed to one caller -// does not effect the execution and returned values of others. +// Do executes and returns the results of the given function, making sure that +// only one execution is in-flight for a given key at a time. If a duplicate +// comes in, the duplicate caller waits for the original to complete and +// receives the same results. +// +// The context passed to the fn function is a new context which is canceled when +// contexts from all callers are canceled, so that no caller is expecting the +// result. If there are multiple callers, context passed to one caller does not +// effect the execution and returned values of others. +// // The return value shared indicates whether v was given to multiple callers. -func (g *Group) Do(ctx context.Context, key string, fn func() (interface{}, error)) (v interface{}, shared bool, err error) { +func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, shared bool, err error) { g.mu.Lock() if g.calls == nil { g.calls = make(map[string]*call) @@ -36,19 +39,24 @@ func (g *Group) Do(ctx context.Context, key string, fn func() (interface{}, erro if c, ok := g.calls[key]; ok { c.shared = true + c.counter++ g.mu.Unlock() return g.wait(ctx, key, c) } + callCtx, cancel := context.WithCancel(context.Background()) + c := &call{ - done: make(chan struct{}), + done: make(chan struct{}), + cancel: cancel, + counter: 1, } g.calls[key] = c g.mu.Unlock() go func() { - c.val, c.err = fn() + c.val, c.err = fn(callCtx) close(c.done) }() @@ -65,6 +73,10 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s err = ctx.Err() } g.mu.Lock() + c.counter-- + if c.counter == 0 { + c.cancel() + } if !c.forgotten { delete(g.calls, key) } @@ -99,4 +111,9 @@ type call struct { // shared indicates if results val and err are passed to multiple callers. shared bool + + // Number of callers that are waiting for the result. + counter int + // Cancel function for the context passed to the executing function. + cancel context.CancelFunc } diff --git a/singleflight_test.go b/singleflight_test.go index 0a2bc74..dadc495 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -6,9 +6,13 @@ package singleflight_test import ( + "bytes" "context" "errors" + "fmt" + "runtime/pprof" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -21,7 +25,7 @@ func TestDo(t *testing.T) { var g singleflight.Group want := "val" - got, shared, err := g.Do(context.Background(), "key", func() (interface{}, error) { + got, shared, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) { return want, nil }) if err != nil { @@ -38,7 +42,7 @@ func TestDo(t *testing.T) { func TestDo_error(t *testing.T) { var g singleflight.Group wantErr := errors.New("test error") - got, _, err := g.Do(context.Background(), "key", func() (interface{}, error) { + got, _, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) { return nil, wantErr }) if err != wantErr { @@ -64,7 +68,7 @@ func TestDo_multipleCalls(t *testing.T) { for i := 0; i < n; i++ { go func(i int) { defer wg.Done() - got[i], shared[i], err[i] = g.Do(context.Background(), "key", func() (interface{}, error) { + got[i], shared[i], err[i] = g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) { atomic.AddInt32(&counter, 1) time.Sleep(100 * time.Millisecond) return want, nil @@ -95,7 +99,7 @@ func TestDo_callRemoval(t *testing.T) { wantPrefix := "val" counter := 0 - fn := func() (interface{}, error) { + fn := func(_ context.Context) (interface{}, error) { counter++ return wantPrefix + strconv.Itoa(counter), nil } @@ -124,6 +128,9 @@ func TestDo_callRemoval(t *testing.T) { } func TestDo_cancelContext(t *testing.T) { + done := make(chan struct{}) + defer close(done) + var g singleflight.Group want := "val" @@ -133,8 +140,11 @@ func TestDo_cancelContext(t *testing.T) { cancel() }() start := time.Now() - got, shared, err := g.Do(ctx, "key", func() (interface{}, error) { - time.Sleep(time.Second) + got, shared, err := g.Do(ctx, "key", func(_ context.Context) (interface{}, error) { + select { + case <-time.After(time.Second): + case <-done: + } return want, nil }) if d := time.Since(start); d < 100*time.Microsecond || d > time.Second { @@ -152,11 +162,17 @@ func TestDo_cancelContext(t *testing.T) { } func TestDo_cancelContextSecond(t *testing.T) { + done := make(chan struct{}) + defer close(done) + var g singleflight.Group want := "val" - fn := func() (interface{}, error) { - time.Sleep(time.Second) + fn := func(_ context.Context) (interface{}, error) { + select { + case <-time.After(time.Second): + case <-done: + } return want, nil } go func() { @@ -186,16 +202,22 @@ func TestDo_cancelContextSecond(t *testing.T) { } func TestForget(t *testing.T) { + done := make(chan struct{}) + defer close(done) + var g singleflight.Group wantPrefix := "val" var counter uint64 firstCall := make(chan struct{}) - fn := func() (interface{}, error) { + fn := func(_ context.Context) (interface{}, error) { c := atomic.AddUint64(&counter, 1) if c == 1 { close(firstCall) - time.Sleep(time.Second) + select { + case <-time.After(time.Second): + case <-done: + } } return wantPrefix + strconv.FormatUint(c, 10), nil } @@ -220,3 +242,118 @@ func TestForget(t *testing.T) { t.Errorf("got value %v, want %v", got, want) } } + +func TestDo_multipleCallsCanceled(t *testing.T) { + const n = 5 + + for lastCall := 0; lastCall < n; lastCall++ { + lastCall := lastCall + t.Run(fmt.Sprintf("last call %v of %v", lastCall, n), func(t *testing.T) { + done := make(chan struct{}) + defer close(done) + + var g singleflight.Group + + var counter int32 + + fnCalled := make(chan struct{}) + fnErrChan := make(chan error) + var mu sync.Mutex + contexts := make([]context.Context, n) + cancelFuncs := make([]context.CancelFunc, n) + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + defer wg.Done() + ctx, cancel := context.WithCancel(context.Background()) + mu.Lock() + contexts[i] = ctx + cancelFuncs[i] = cancel + mu.Unlock() + _, _, _ = g.Do(ctx, "key", func(ctx context.Context) (interface{}, error) { + atomic.AddInt32(&counter, 1) + close(fnCalled) + var err error + select { + case <-ctx.Done(): + err = ctx.Err() + if err == nil { + err = errors.New("got unexpected error from context") + } + case <-time.After(10 * time.Second): + err = errors.New("unexpected timeout, context not canceled") + case <-done: + } + + fnErrChan <- err + + return nil, nil + }) + }(i) + } + select { + case <-fnCalled: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for function to be called") + } + + // Ensure that n goroutines are waiting at the select case in Group.wait. + // Update the line number on changes. + waitStacks(t, "resenje.org/singleflight/singleflight.go:68", n, 2*time.Second) + + // cancel all but one calls + for i := 0; i < n; i++ { + if i == lastCall { + continue + } + mu.Lock() + cancelFuncs[i]() + <-contexts[i].Done() + mu.Unlock() + } + + select { + case err := <-fnErrChan: + t.Fatalf("got unexpected error in function: %v", err) + default: + } + + // Ensure that only the last goroutine is waiting at the select case in Group.wait. + // Update the line number on changes. + waitStacks(t, "resenje.org/singleflight/singleflight.go:68", 1, 2*time.Second) + + mu.Lock() + cancelFuncs[lastCall]() + mu.Unlock() + + wg.Wait() + + select { + case err := <-fnErrChan: + if err != context.Canceled { + t.Fatalf("got unexpected error in function %v, want %v", err, context.Canceled) + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for the error") + } + }) + } +} + +func waitStacks(t *testing.T, loc string, count int, timeout time.Duration) { + t.Helper() + + for deadline := time.Now().Add(timeout); time.Now().Before(deadline); { + // Ensure that exact n goroutines are waiting at the desired stack trace. + var buf bytes.Buffer + if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil { + t.Fatal(err) + } + c := strings.Count(buf.String(), loc) + if c == count { + break + } + time.Sleep(10 * time.Millisecond) + } +}