Skip to content

Commit

Permalink
Make panic behavior match x/sync package (#13)
Browse files Browse the repository at this point in the history
If the original caller panics, then have all callers panic,
to match the x/sync package behavior.
  • Loading branch information
bduffany authored Jun 25, 2024
1 parent bfebb53 commit 9843622
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 1 deletion.
42 changes: 41 additions & 1 deletion singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,32 @@ package singleflight

import (
"context"
"fmt"
"runtime/debug"
"sync"
)

// A panicError is an arbitrary value recovered from a panic
// with the stack trace during the execution of given function.
type panicError struct {
value interface{}
stack []byte
}

// Error implements error interface.
func (p *panicError) Error() string {
return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
}

func (p *panicError) Unwrap() error {
err, ok := p.value.(error)
if !ok {
return nil
}

return err
}

// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
// K is the type of the key used for deduplication, and V is
Expand Down Expand Up @@ -62,19 +85,27 @@ func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context
g.mu.Unlock()

go func() {
defer func() {
if v := recover(); v != nil {
c.panicErr = &panicError{value: v, stack: debug.Stack()}
}
close(c.done)
}()

c.val, c.err = fn(callCtx)
close(c.done)
}()

return g.wait(ctx, key, c)
}

// wait for function passed to Do to finish or context to be done.
func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared bool, err error) {
var panicErr *panicError
select {
case <-c.done:
v = c.val
err = c.err
panicErr = c.panicErr
case <-ctx.Done():
err = ctx.Err()
}
Expand All @@ -86,6 +117,11 @@ func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared
}
shared = c.shared
g.mu.Unlock()

if panicErr != nil {
panic(panicErr)
}

return v, shared, err
}

Expand All @@ -104,6 +140,10 @@ type call[V any] struct {
val V
err error

// panicError wraps the value passed to panic() if the function call panicked.
// val and err should be ignored if this is non-nil.
panicErr *panicError

// done channel signals that the function call is done.
done chan struct{}

Expand Down
47 changes: 47 additions & 0 deletions singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,53 @@ func TestDo_callDoAfterCancellation(t *testing.T) {
}
}

func TestDo_panic(t *testing.T) {
// Start a few goroutines all waiting on the same call.
// The call just waits for a short duration then panics.
// Each goroutine will recover from the panic, and send the recovered
// value on a channel. At the end, we make sure that every goroutine
// panicked, not just the first goroutine that triggered the call.
// This matches the behavior of x/sync/singleflight.

const numGoroutines = 3
const panicMessage = "test-panic-message"

recoveries := make(chan any, numGoroutines)
ctx := context.Background()
var g singleflight.Group[string, string]
for i := 0; i < numGoroutines; i++ {
go func() {
defer func() {
recoveries <- recover()
}()

g.Do(ctx, "key", func(_ context.Context) (string, error) {

Check failure on line 304 in singleflight_test.go

View workflow job for this annotation

GitHub Actions / Build (ubuntu-latest)

Error return value of `g.Do` is not checked (errcheck)

Check failure on line 304 in singleflight_test.go

View workflow job for this annotation

GitHub Actions / Build (macos-latest)

Error return value of `g.Do` is not checked (errcheck)
time.Sleep(200 * time.Millisecond)
panic(panicMessage)
})
t.Errorf("This line should not be reached - Do() should have panicked")
}()
}

for i := 0; i < numGoroutines; i++ {
panicValue := <-recoveries
if err, ok := panicValue.(error); !ok || !strings.Contains(err.Error(), panicMessage) {
t.Errorf("got unexpected panic value %+#v", panicValue)
}
}

// The work for "key" should be complete, and we should be able to
// start a new call for the same key without panicking.

const want = "hello"
got, shared, err := g.Do(ctx, "key", func(_ context.Context) (string, error) {
return want, nil
})
if got != want || shared || err != nil {
t.Errorf("unexpected result (value=%v, shared=%v, err=%v)", got, shared, err)
}
}

func TestForget(t *testing.T) {
done := make(chan struct{})
defer close(done)
Expand Down

0 comments on commit 9843622

Please sign in to comment.