Skip to content

Commit

Permalink
Merge pull request #91 from MicahParks/tolerate_startup_fail
Browse files Browse the repository at this point in the history
Add TolerateInitialJWKHTTPError option
  • Loading branch information
MicahParks authored May 31, 2023
2 parents a65b424 + 62e682a commit 94c4af8
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
9 changes: 8 additions & 1 deletion get.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {

err = jwks.refresh()
if err != nil {
return nil, err
if options.TolerateInitialJWKHTTPError {
if jwks.refreshErrorHandler != nil {
jwks.refreshErrorHandler(err)
}
jwks.keys = make(map[string]parsedJWK)
} else {
return nil, err
}
}

if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
Expand Down
8 changes: 8 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ type Options struct {
// ResponseExtractor consumes a *http.Response and produces the raw JSON for the JWKS. By default, the
// ResponseExtractorStatusOK function is used. The default behavior changed in v1.4.0.
ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)

// TolerateInitialJWKHTTPError will tolerate any error from the initial HTTP JWKS request. If an error occurs,
// the RefreshErrorHandler will be given the error. The program will continue to run as if the error did not occur
// and a valid JWK Set with no keys was received in the response. This allows for the background goroutine to
// request the JWKS at a later time.
//
// It does not make sense to mark this field as true unless the background refresh goroutine is active.
TolerateInitialJWKHTTPError bool
}

// MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS.
Expand Down
50 changes: 50 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keyfunc_test

import (
"errors"
"github.com/golang-jwt/jwt/v5"
"net/http"
"net/http/httptest"
"sync"
Expand Down Expand Up @@ -77,3 +78,52 @@ func TestResponseExtractorStatusAny(t *testing.T) {
t.Fatalf("Expected error no error for 500 status code.\nError: %s", err)
}
}

func TestTolerateStartupFailure(t *testing.T) {
var mux sync.Mutex
shouldError := true

server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
mux.Lock()
defer mux.Unlock()
if shouldError {
writer.WriteHeader(http.StatusInternalServerError)
} else {
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(jwksJSON))
}
}))
defer server.Close()

options := keyfunc.Options{
TolerateInitialJWKHTTPError: true,
RefreshUnknownKID: true,
}
jwks, err := keyfunc.Get(server.URL, options)
if err != nil {
t.Fatalf("TolerateInitialJWKHTTPError should not return error on bad HTTP startup.\nError: %s", err)
}

if len(jwks.ReadOnlyKeys()) != 0 {
t.Fatalf("Expected JWK Set to have no keys.")
}

const token = "eyJhbGciOiJFUzI1NiIsInR5cCIgOiAiSldUIiwia2lkIiA6ICJDR3QwWldTNExjNWZhaUtTZGkwdFUwZmpDQWR2R1JPUVJHVTlpUjd0VjBBIn0.eyJleHAiOjE2MTU0MDY4NjEsImlhdCI6MTYxNTQwNjgwMSwianRpIjoiYWVmOWQ5YjItN2EyYy00ZmQ4LTk4MzktODRiMzQ0Y2VmYzZhIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwL2F1dGgvcmVhbG1zL21hc3RlciIsImF1ZCI6ImFjY291bnQiLCJzdWIiOiJhZDEyOGRmMS0xMTQwLTRlNGMtYjA5Ny1hY2RjZTcwNWJkOWIiLCJ0eXAiOiJCZWFyZXIiLCJhenAiOiJ0b2tlbmRlbG1lIiwiYWNyIjoiMSIsInJlYWxtX2FjY2VzcyI6eyJyb2xlcyI6WyJvZmZsaW5lX2FjY2VzcyIsInVtYV9hdXRob3JpemF0aW9uIl19LCJyZXNvdXJjZV9hY2Nlc3MiOnsiYWNjb3VudCI6eyJyb2xlcyI6WyJtYW5hZ2UtYWNjb3VudCIsIm1hbmFnZS1hY2NvdW50LWxpbmtzIiwidmlldy1wcm9maWxlIl19fSwic2NvcGUiOiJlbWFpbCBwcm9maWxlIiwiY2xpZW50SG9zdCI6IjE3Mi4yMC4wLjEiLCJjbGllbnRJZCI6InRva2VuZGVsbWUiLCJlbWFpbF92ZXJpZmllZCI6ZmFsc2UsInByZWZlcnJlZF91c2VybmFtZSI6InNlcnZpY2UtYWNjb3VudC10b2tlbmRlbG1lIiwiY2xpZW50QWRkcmVzcyI6IjE3Mi4yMC4wLjEifQ.iQ77QGoPDNjR2oWLu3zT851mswP8J-h_nrGhs3fpa_tFB3FT1deKPGkjef9JOTYFI-CIVxdCFtW3KODOaw9Nrw"
_, err = jwt.Parse(token, jwks.Keyfunc)
if !errors.Is(err, keyfunc.ErrKIDNotFound) {
t.Fatalf("Expected error to be ErrKIDNotFound.\nError: %s", err)
}

mux.Lock()
shouldError = false
mux.Unlock()

_, err = jwt.Parse(token, jwks.Keyfunc)
if !errors.Is(err, jwt.ErrTokenExpired) {
t.Fatalf("Expected error to be jwt.ErrTokenExpired.\nError: %s", err)
}

if len(jwks.ReadOnlyKeys()) == 0 {
t.Fatalf("Expected JWK Set to have keys.")
}
}

0 comments on commit 94c4af8

Please sign in to comment.