diff --git a/semgroup.go b/semgroup.go index cd08311..7582d0a 100644 --- a/semgroup.go +++ b/semgroup.go @@ -10,6 +10,7 @@ package semgroup import ( "context" + "errors" "fmt" "strings" "sync" @@ -98,3 +99,21 @@ func (e multiError) ErrorOrNil() error { return e } + +func (e multiError) Is(target error) bool { + for _, err := range e { + if errors.Is(err, target) { + return true + } + } + return false +} + +func (e multiError) As(target interface{}) bool { + for _, err := range e { + if errors.As(err, target) { + return true + } + } + return false +} diff --git a/semgroup_test.go b/semgroup_test.go index dc15c27..255206d 100644 --- a/semgroup_test.go +++ b/semgroup_test.go @@ -3,6 +3,7 @@ package semgroup import ( "context" "errors" + "os" "sync" "testing" ) @@ -91,3 +92,68 @@ func TestGroup_deadlock(t *testing.T) { t.Errorf("error should be:\n%s\ngot:\n%s\n", wantErr, err.Error()) } } + +func TestGroup_multiple_tasks_errors_Is(t *testing.T) { + ctx := context.Background() + g := NewGroup(ctx, 1) + + var ( + fooErr = errors.New("foo") + barErr = errors.New("bar") + bazErr = errors.New("baz") + ) + + g.Go(func() error { return fooErr }) + g.Go(func() error { return nil }) + g.Go(func() error { return barErr }) + g.Go(func() error { return nil }) + + err := g.Wait() + if err == nil { + t.Fatalf("g.Wait() should return an error") + } + + if !errors.Is(err, fooErr) { + t.Errorf("error should be contained %v\n", fooErr) + } + + if !errors.Is(err, barErr) { + t.Errorf("error should be contained %v\n", barErr) + } + + if errors.Is(err, bazErr) { + t.Errorf("error should not be contained %v\n", bazErr) + } +} + +type foobarErr struct{ str string } + +func (e foobarErr) Error() string { + return "foobar" +} + +func TestGroup_multiple_tasks_errors_As(t *testing.T) { + ctx := context.Background() + g := NewGroup(ctx, 1) + + g.Go(func() error { return foobarErr{"baz"} }) + g.Go(func() error { return nil }) + + err := g.Wait() + if err == nil { + t.Fatalf("g.Wait() should return an error") + } + + var ( + fbe foobarErr + pe *os.PathError + ) + + if !errors.As(err, &fbe) { + t.Error("error should be matched foobarErr") + } + + if errors.As(err, &pe) { + t.Error("error should not be matched os.PathError") + } +}