diff --git a/circuitbreaker.go b/circuitbreaker.go index dd9bc55..9b9cdd1 100644 --- a/circuitbreaker.go +++ b/circuitbreaker.go @@ -302,7 +302,7 @@ func (cb *Breaker) Success() { cb.backoffLock.Unlock() state := cb.state() - if state == halfopen { + if state != closed { cb.Reset() } atomic.StoreInt64(&cb.consecFailures, 0) diff --git a/circuitbreaker_test.go b/circuitbreaker_test.go index 79aa143..2247b14 100644 --- a/circuitbreaker_test.go +++ b/circuitbreaker_test.go @@ -430,6 +430,50 @@ func TestRateBreakerResets(t *testing.T) { } } +func TestRateBreakerResetsOnSuccess(t *testing.T) { + serviceError := fmt.Errorf("service error") + + called := 0 + success := false + circuit := func() error { + if called < 4 { + called++ + return serviceError + } + success = true + return nil + } + + c := clock.NewMock() + cb := NewRateBreaker(0.5, 4) + cb.Clock = c + var err error + for i := 0; i < 4; i++ { + err = cb.Call(circuit, 0) + if err == nil { + t.Fatal("Expected cb to return an error (closed breaker, service failure)") + } else if err != serviceError { + t.Fatalf("Expected cb to return error from service; got %v", err) + } + } + + err = cb.Call(circuit, 0) + if err == nil { + t.Fatal("Expected cb to return an error (open breaker)") + } else if err != ErrBreakerOpen { + t.Fatalf("Expected cb to return open open breaker error; got %v", err) + } + + cb.Success() + err = cb.Call(circuit, 0) + if err != nil { + t.Fatalf("Expected cb to be successful after Success() call; got %v", err) + } + if !success { + t.Fatal("Expected cb to have been reset after Success() call") + } +} + func TestNeverRetryAfterBackoffStops(t *testing.T) { cb := NewBreakerWithOptions(&Options{ BackOff: &backoff.StopBackOff{},