From db682afdff015ded9e79e26e2ae05102cd7c82e4 Mon Sep 17 00:00:00 2001 From: Janos Guljas Date: Wed, 23 Aug 2023 12:33:38 +0200 Subject: [PATCH] Ensure no function calls after one context is cancelled --- .github/workflows/go.yml | 4 ++-- singleflight.go | 11 +-------- singleflight_test.go | 52 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index ede1b1b..967e6a0 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.16 + go-version: '1.21' - name: Checkout uses: actions/checkout@v1 @@ -34,7 +34,7 @@ jobs: - name: Lint uses: golangci/golangci-lint-action@v2 with: - version: v1.40.1 + version: v1.54.2 args: --timeout 10m - name: Vet diff --git a/singleflight.go b/singleflight.go index 45fe9a3..3492b85 100644 --- a/singleflight.go +++ b/singleflight.go @@ -5,7 +5,7 @@ // Package singleflight provides a duplicate function call suppression // mechanism similar to golang.org/x/sync/singleflight with support -// for context cancelation. +// for context cancellation. package singleflight import ( @@ -76,8 +76,6 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s c.counter-- if c.counter == 0 { c.cancel() - } - if !c.forgotten { delete(g.calls, key) } g.mu.Unlock() @@ -89,9 +87,6 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s // an earlier call to complete. func (g *Group) Forget(key string) { g.mu.Lock() - if c, ok := g.calls[key]; ok { - c.forgotten = true - } delete(g.calls, key) g.mu.Unlock() } @@ -105,10 +100,6 @@ type call struct { // done channel signals that the function call is done. done chan struct{} - // forgotten indicates whether Forget was called with this call's key - // while the call was still in flight. - forgotten bool - // shared indicates if results val and err are passed to multiple callers. shared bool diff --git a/singleflight_test.go b/singleflight_test.go index dadc495..68bf8b0 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -201,6 +201,58 @@ func TestDo_cancelContextSecond(t *testing.T) { } } +func TestDo_callDoAfterCancellation(t *testing.T) { + done := make(chan struct{}) + defer close(done) + + var g singleflight.Group + + callCounter := new(atomic.Uint64) + fn := func(_ context.Context) (interface{}, error) { + callCounter.Add(1) + select { + case <-time.After(time.Second): + case <-done: + } + return "", nil + } + + go func() { + // keep the function call active for long period (1 second) + if _, _, err := g.Do(context.Background(), "key", fn); err != nil { + panic(err) + } + }() + + { // make another call that is canceled shortly (100 milliseconds) + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, _, err := g.Do(ctx, "key", fn) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + } + + want := uint64(1) + + if got := callCounter.Load(); got != want { + t.Errorf("got call counter %v, want %v", got, want) + } + + { // make another call after the previous call cancellation + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + _, _, err := g.Do(ctx, "key", fn) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + } + + if got := callCounter.Load(); got != want { + t.Errorf("got call counter %v, want %v", got, want) + } +} + func TestForget(t *testing.T) { done := make(chan struct{}) defer close(done)