From 98436221d63b1311d3b69c868fc4d9ca5bb1c0a0 Mon Sep 17 00:00:00 2001 From: Brandon Duffany Date: Tue, 25 Jun 2024 04:44:11 -0400 Subject: [PATCH] Make panic behavior match x/sync package (#13) If the original caller panics, then have all callers panic, to match the x/sync package behavior. --- singleflight.go | 42 ++++++++++++++++++++++++++++++++++++++- singleflight_test.go | 47 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/singleflight.go b/singleflight.go index 820f0e1..d8c074d 100644 --- a/singleflight.go +++ b/singleflight.go @@ -10,9 +10,32 @@ package singleflight import ( "context" + "fmt" + "runtime/debug" "sync" ) +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func (p *panicError) Unwrap() error { + err, ok := p.value.(error) + if !ok { + return nil + } + + return err +} + // Group represents a class of work and forms a namespace in // which units of work can be executed with duplicate suppression. // K is the type of the key used for deduplication, and V is @@ -62,8 +85,14 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context g.mu.Unlock() go func() { + defer func() { + if v := recover(); v != nil { + c.panicErr = &panicError{value: v, stack: debug.Stack()} + } + close(c.done) + }() + c.val, c.err = fn(callCtx) - close(c.done) }() return g.wait(ctx, key, c) @@ -71,10 +100,12 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context // wait for function passed to Do to finish or context to be done. func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared bool, err error) { + var panicErr *panicError select { case <-c.done: v = c.val err = c.err + panicErr = c.panicErr case <-ctx.Done(): err = ctx.Err() } @@ -86,6 +117,11 @@ func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared } shared = c.shared g.mu.Unlock() + + if panicErr != nil { + panic(panicErr) + } + return v, shared, err } @@ -104,6 +140,10 @@ type call[V any] struct { val V err error + // panicError wraps the value passed to panic() if the function call panicked. + // val and err should be ignored if this is non-nil. + panicErr *panicError + // done channel signals that the function call is done. done chan struct{} diff --git a/singleflight_test.go b/singleflight_test.go index 475a6bd..40d752b 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -281,6 +281,53 @@ func TestDo_callDoAfterCancellation(t *testing.T) { } } +func TestDo_panic(t *testing.T) { + // Start a few goroutines all waiting on the same call. + // The call just waits for a short duration then panics. + // Each goroutine will recover from the panic, and send the recovered + // value on a channel. At the end, we make sure that every goroutine + // panicked, not just the first goroutine that triggered the call. + // This matches the behavior of x/sync/singleflight. + + const numGoroutines = 3 + const panicMessage = "test-panic-message" + + recoveries := make(chan any, numGoroutines) + ctx := context.Background() + var g singleflight.Group[string, string] + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { + recoveries <- recover() + }() + + g.Do(ctx, "key", func(_ context.Context) (string, error) { + time.Sleep(200 * time.Millisecond) + panic(panicMessage) + }) + t.Errorf("This line should not be reached - Do() should have panicked") + }() + } + + for i := 0; i < numGoroutines; i++ { + panicValue := <-recoveries + if err, ok := panicValue.(error); !ok || !strings.Contains(err.Error(), panicMessage) { + t.Errorf("got unexpected panic value %+#v", panicValue) + } + } + + // The work for "key" should be complete, and we should be able to + // start a new call for the same key without panicking. + + const want = "hello" + got, shared, err := g.Do(ctx, "key", func(_ context.Context) (string, error) { + return want, nil + }) + if got != want || shared || err != nil { + t.Errorf("unexpected result (value=%v, shared=%v, err=%v)", got, shared, err) + } +} + func TestForget(t *testing.T) { done := make(chan struct{}) defer close(done)