From 16db93fb36ce987dbd7ffc712219fc8b6ee06267 Mon Sep 17 00:00:00 2001 From: Micah Parks Date: Wed, 6 Apr 2022 10:55:37 -0400 Subject: [PATCH] Remove comments --- checksum_test.go | 13 ++----------- ecdsa.go | 10 ---------- eddsa.go | 3 --- get.go | 49 ++---------------------------------------------- given.go | 4 ---- given_test.go | 43 ------------------------------------------ jwks.go | 15 --------------- jwks_test.go | 47 ---------------------------------------------- keyfunc.go | 3 --- oct.go | 3 --- options.go | 2 ++ override_test.go | 25 +----------------------- rsa.go | 8 -------- 13 files changed, 7 insertions(+), 218 deletions(-) diff --git a/checksum_test.go b/checksum_test.go index b1b4e40..6b7ca3a 100644 --- a/checksum_test.go +++ b/checksum_test.go @@ -15,9 +15,8 @@ import ( "github.com/MicahParks/keyfunc" ) +// TestChecksum confirms that the JWKS will only perform a refresh if a new JWKS is read from the remote resource. func TestChecksum(t *testing.T) { - - // Create a temporary directory to serve the JWKS from. tempDir, err := ioutil.TempDir("", "*") if err != nil { t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error()) @@ -31,21 +30,17 @@ func TestChecksum(t *testing.T) { } }() - // Create the JWKS file path. jwksFile := filepath.Join(tempDir, jwksFilePath) - // Write the JWKS. err = ioutil.WriteFile(jwksFile, []byte(jwksJSON), 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) t.FailNow() } - // Create the HTTP test server. server := httptest.NewServer(http.FileServer(http.Dir(tempDir))) defer server.Close() - // Create testing options. testingRefreshErrorHandler := func(err error) { panic(fmt.Sprintf("Unhandled JWKS error: %s", err.Error())) } @@ -54,7 +49,6 @@ func TestChecksum(t *testing.T) { RefreshUnknownKID: true, } - // Set the JWKS URL. jwksURL := server.URL + jwksFilePath jwks, err := keyfunc.Get(jwksURL, opts) @@ -64,7 +58,6 @@ func TestChecksum(t *testing.T) { } defer jwks.EndBackground() - // Get a map of all interface pointers for the JWKS. cryptoKeyPointers := make(map[string]interface{}) for kid, cryptoKey := range jwks.ReadOnlyKeys() { cryptoKeyPointers[kid] = cryptoKey @@ -95,14 +88,12 @@ func TestChecksum(t *testing.T) { } } - // Write a new JWKS to the test file. + // Write a different JWKS. _, _, jwksBytes, _, err := keysAndJWKS() if err != nil { t.Errorf("Failed to create a test JWKS.\nError: %s", err.Error()) t.FailNow() } - - // Write a different JWKS. err = ioutil.WriteFile(jwksFile, jwksBytes, 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) diff --git a/ecdsa.go b/ecdsa.go index 9b7f662..f874ac8 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -9,7 +9,6 @@ import ( ) const ( - // ktyEC is the key type (kty) in the JWT header for ECDSA. ktyEC = "EC" @@ -25,8 +24,6 @@ const ( // ECDSA parses a jsonWebKey and turns it into an ECDSA public key. func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { - - // Confirm everything needed is present. if j.X == "" || j.Y == "" || j.Curve == "" { return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyEC) } @@ -39,17 +36,12 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { if err != nil { return nil, err } - - // Decode the Y coordinate from Base64. yCoordinate, err := base64.RawURLEncoding.DecodeString(j.Y) if err != nil { return nil, err } - // Create the ECDSA public key. publicKey = &ecdsa.PublicKey{} - - // Set the curve type. switch j.Curve { case p256: publicKey.Curve = elliptic.P256() @@ -64,8 +56,6 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) { // According to RFC 7517, these numbers are in big-endian format. // https://tools.ietf.org/html/rfc7517#appendix-A.1 publicKey.X = big.NewInt(0).SetBytes(xCoordinate) - - // Turn the Y coordinate into a *big.Int. publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) return publicKey, nil diff --git a/eddsa.go b/eddsa.go index 8b5856e..0e9ad68 100644 --- a/eddsa.go +++ b/eddsa.go @@ -7,15 +7,12 @@ import ( ) const ( - // ktyEC is the key type (kty) in the JWT header for EdDSA. ktyOKP = "OKP" ) // EdDSA parses a jsonWebKey and turns it into a EdDSA public key. func (j *jsonWebKey) EdDSA() (publicKey ed25519.PublicKey, err error) { - - // Confirm everything needed is present. if j.X == "" { return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyOKP) } diff --git a/get.go b/get.go index a7718b8..728cd6b 100644 --- a/get.go +++ b/get.go @@ -18,16 +18,12 @@ var ( // Get loads the JWKS at the given URL. func Get(jwksURL string, options Options) (jwks *JWKS, err error) { - - // Create the JWKS. jwks = &JWKS{ jwksURL: jwksURL, } - // Apply the options to the JWKS. applyOptions(jwks, options) - // Apply some defaults if options were not provided. if jwks.client == nil { jwks.client = http.DefaultClient } @@ -35,22 +31,14 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) { jwks.refreshTimeout = defaultRefreshTimeout } - // Get the keys for the JWKS. err = jwks.refresh() if err != nil { return nil, err } - // Check to see if a background refresh of the JWKS should happen. if jwks.refreshInterval != 0 || jwks.refreshUnknownKID { - - // Attach a context used to end the background goroutine. jwks.ctx, jwks.cancel = context.WithCancel(context.Background()) - - // Create a channel that will accept requests to refresh the JWKS. jwks.refreshRequests = make(chan context.CancelFunc, 1) - - // Start the background goroutine for data refresh. go jwks.backgroundRefresh() } @@ -60,8 +48,6 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) { // backgroundRefresh is meant to be a separate goroutine that will update the keys in a JWKS over a given interval of // time. func (j *JWKS) backgroundRefresh() { - - // Create some rate limiting assets. var lastRefresh time.Time var queueOnce sync.Once var refreshMux sync.Mutex @@ -74,16 +60,11 @@ func (j *JWKS) backgroundRefresh() { // Enter an infinite loop that ends when the background ends. for { - - // If there is a refresh interval, create the channel for it. if j.refreshInterval != 0 { refreshInterval = time.After(j.refreshInterval) } - // Wait for a refresh to occur or the background to end. select { - - // Send a refresh request the JWKS after the given interval. case <-refreshInterval: select { case <-j.ctx.Done(): @@ -92,23 +73,16 @@ func (j *JWKS) backgroundRefresh() { default: // If the j.refreshRequests channel is full, don't send another request. } - // Accept refresh requests. case cancel := <-j.refreshRequests: - - // Rate limit, if needed. refreshMux.Lock() if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) { // Don't make the JWT parsing goroutine wait for the JWKS to refresh. cancel() - // Only queue a refresh once. + // Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return. queueOnce.Do(func() { - - // Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return. go func() { - - // Wait for the next time to refresh. refreshMux.Lock() wait := time.Until(lastRefresh.Add(j.refreshRateLimit)) refreshMux.Unlock() @@ -118,7 +92,6 @@ func (j *JWKS) backgroundRefresh() { case <-time.After(wait): } - // Refresh the JWKS. refreshMux.Lock() defer refreshMux.Unlock() err := j.refresh() @@ -126,22 +99,16 @@ func (j *JWKS) backgroundRefresh() { j.refreshErrorHandler(err) } - // Reset the last time for the refresh to now. lastRefresh = time.Now() - - // Allow another queue. queueOnce = sync.Once{} }() }) } else { - - // Refresh the JWKS. err := j.refresh() if err != nil && j.refreshErrorHandler != nil { j.refreshErrorHandler(err) } - // Reset the last time for the refresh to now. lastRefresh = time.Now() // Allow the JWT parsing goroutine to continue with the refreshed JWKS. @@ -158,8 +125,6 @@ func (j *JWKS) backgroundRefresh() { // refresh does an HTTP GET on the JWKS URL to rebuild the JWKS. func (j *JWKS) refresh() (err error) { - - // Create a context for the request. var ctx context.Context var cancel context.CancelFunc if j.ctx != nil { @@ -169,13 +134,11 @@ func (j *JWKS) refresh() (err error) { } defer cancel() - // Create the HTTP request. req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.jwksURL, bytes.NewReader(nil)) if err != nil { return err } - // Get the JWKS as JSON from the given URL. resp, err := j.client.Do(req) if err != nil { return err @@ -183,7 +146,6 @@ func (j *JWKS) refresh() (err error) { //goland:noinspection GoUnhandledErrorResult defer resp.Body.Close() - // Read the raw JWKS from the body of the response. jwksBytes, err := ioutil.ReadAll(resp.Body) if err != nil { return err @@ -195,32 +157,25 @@ func (j *JWKS) refresh() (err error) { } j.raw = jwksBytes - // Create an updated JWKS. updated, err := NewJSON(jwksBytes) if err != nil { return err } - // Lock the JWKS for async safe usage. j.mux.Lock() defer j.mux.Unlock() - - // Update the keys. j.keys = updated.keys - // If given keys were provided, add them back into the refreshed JWKS. - var ok bool if j.givenKeys != nil { for kid, key := range j.givenKeys { // Only overwrite the key if configured to do so. if !j.givenKIDOverride { - if _, ok = j.keys[kid]; ok { + if _, ok := j.keys[kid]; ok { continue } } - // Write the given key to the JWKS. j.keys[kid] = key.inter } } diff --git a/given.go b/given.go index 3321d10..f7bf6ad 100644 --- a/given.go +++ b/given.go @@ -13,16 +13,12 @@ type GivenKey struct { // NewGiven creates a JWKS from a map of given keys. func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) { - - // Initialize the map of kid to cryptographic keys. keys := make(map[string]interface{}) - // Copy the given keys to the map of cryptographic keys. for kid, given := range givenKeys { keys[kid] = given.inter } - // Return a JWKS with the map of cryptographic keys. return &JWKS{ keys: keys, } diff --git a/given_test.go b/given_test.go index 1086abe..9975634 100644 --- a/given_test.go +++ b/given_test.go @@ -17,7 +17,6 @@ import ( ) const ( - // algAttribute is the JSON attribute for the JWT encryption algorithm. algAttribute = "alg" @@ -30,32 +29,24 @@ const ( // TestNewGivenCustom tests that a custom jwt.SigningMethod can be used to create a JWKS and a proper jwt.Keyfunc. func TestNewGivenCustom(t *testing.T) { - - // Register the signing method. jwt.RegisterSigningMethod(method.CustomAlg, func() jwt.SigningMethod { return method.EmptyCustom{} }) - // Create the map of given keys. givenKeys := make(map[string]keyfunc.GivenKey) key := addCustom(givenKeys, testKID) - // Use the custom key to create a JWKS. jwks := keyfunc.NewGiven(givenKeys) - // Create the JWT with the appropriate key ID. token := jwt.New(method.EmptyCustom{}) token.Header[algAttribute] = method.CustomAlg token.Header[kidAttribute] = testKID - // Sign, parse, and validate the JWT. signParseValidate(t, token, key, jwks) } // TestNewGivenKeyECDSA tests that a generated ECDSA key can be added to the JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyECDSA(t *testing.T) { - - // Create the map of given keys. givenKeys := make(map[string]keyfunc.GivenKey) key, err := addECDSA(givenKeys, testKID) if err != nil { @@ -63,21 +54,16 @@ func TestNewGivenKeyECDSA(t *testing.T) { t.FailNow() } - // Use the RSA public key to create a JWKS. jwks := keyfunc.NewGiven(givenKeys) - // Create the JWT with the appropriate key ID. token := jwt.New(jwt.SigningMethodES256) token.Header[kidAttribute] = testKID - // Sign, parse, and validate the JWT. signParseValidate(t, token, key, jwks) } // TestNewGivenKeyEdDSA tests that a generated EdDSA key can be added to the JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyEdDSA(t *testing.T) { - - // Create the map of given keys. givenKeys := make(map[string]keyfunc.GivenKey) key, err := addEdDSA(givenKeys, testKID) if err != nil { @@ -85,21 +71,16 @@ func TestNewGivenKeyEdDSA(t *testing.T) { t.FailNow() } - // Use the RSA public key to create a JWKS. jwks := keyfunc.NewGiven(givenKeys) - // Create the JWT with the appropriate key ID. token := jwt.New(jwt.SigningMethodEdDSA) token.Header[kidAttribute] = testKID - // Sign, parse, and validate the JWT. signParseValidate(t, token, key, jwks) } // TestNewGivenKeyHMAC tests that a generated HMAC key can be added to a JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyHMAC(t *testing.T) { - - // Create the map of given keys. givenKeys := make(map[string]keyfunc.GivenKey) key, err := addHMAC(givenKeys, testKID) if err != nil { @@ -107,21 +88,16 @@ func TestNewGivenKeyHMAC(t *testing.T) { t.FailNow() } - // Use an HMAC secret to create a given JWKS. jwks := keyfunc.NewGiven(givenKeys) - // Create a JWT with the appropriate key ID. token := jwt.New(jwt.SigningMethodHS256) token.Header[kidAttribute] = testKID - // Sign, parse, and validate the JWT. signParseValidate(t, token, key, jwks) } // TestNewGivenKeyRSA tests that a generated RSA key can be added to the JWKS and create a proper jwt.Keyfunc. func TestNewGivenKeyRSA(t *testing.T) { - - // Create the map of given keys. givenKeys := make(map[string]keyfunc.GivenKey) key, err := addRSA(givenKeys, testKID) if err != nil { @@ -129,14 +105,11 @@ func TestNewGivenKeyRSA(t *testing.T) { t.FailNow() } - // Use the RSA public key to create a JWKS. jwks := keyfunc.NewGiven(givenKeys) - // Create the JWT with the appropriate key ID. token := jwt.New(jwt.SigningMethodRS256) token.Header[kidAttribute] = testKID - // Sign, parse, and validate the JWT. signParseValidate(t, token, key, jwks) } @@ -149,14 +122,11 @@ func addCustom(givenKeys map[string]keyfunc.GivenKey, kid string) (key string) { // addECDSA adds a new ECDSA key to the given keys map. func addECDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *ecdsa.PrivateKey, err error) { - - // Create the ECDSA key. key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, fmt.Errorf("failed to create ECDSA key: %w", err) } - // Add the new ECDSA public key to the keys map. givenKeys[kid] = keyfunc.NewGivenECDSA(&key.PublicKey) return key, nil @@ -164,14 +134,11 @@ func addECDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *ecdsa.Pri // addEdDSA adds a new EdDSA key to the given keys map. func addEdDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key ed25519.PrivateKey, err error) { - - // Create the ECDSA key. pub, key, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, fmt.Errorf("failed to create ECDSA key: %w", err) } - // Add the new ECDSA public key to the keys map. givenKeys[kid] = keyfunc.NewGivenEdDSA(pub) return key, nil @@ -179,15 +146,12 @@ func addEdDSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key ed25519.Pr // addHMAC creates a new HMAC secret stuff. func addHMAC(givenKeys map[string]keyfunc.GivenKey, kid string) (secret []byte, err error) { - - // Create the HMAC secret. secret = make([]byte, sha256.BlockSize) _, err = rand.Read(secret) if err != nil { return nil, fmt.Errorf("failed to create HMAC secret: %w", err) } - // Add the new HMAC key to the keys map. givenKeys[kid] = keyfunc.NewGivenHMAC(secret) return secret, nil @@ -195,14 +159,11 @@ func addHMAC(givenKeys map[string]keyfunc.GivenKey, kid string) (secret []byte, // addRSA adds a new RSA key to the given keys map. func addRSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *rsa.PrivateKey, err error) { - - // Create the RSA key. key, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, fmt.Errorf("failed to create RSA key: %w", err) } - // Add the new RSA public key to the keys map. givenKeys[kid] = keyfunc.NewGivenRSA(&key.PublicKey) return key, nil @@ -210,22 +171,18 @@ func addRSA(givenKeys map[string]keyfunc.GivenKey, kid string) (key *rsa.Private // signParseValidate signs the JWT, parses it using the given JWKS, then validates it. func signParseValidate(t *testing.T, token *jwt.Token, key interface{}, jwks *keyfunc.JWKS) { - - // Sign the token. jwtB64, err := token.SignedString(key) if err != nil { t.Errorf("Failed to sign the JWT.\nError: %s", err.Error()) t.FailNow() } - // Parse the JWT using the JWKS. parsed, err := jwt.Parse(jwtB64, jwks.Keyfunc) if err != nil { t.Errorf("Failed to parse the JWT.\nError: %s.", err.Error()) t.FailNow() } - // Confirm the JWT is valid. if !parsed.Valid { t.Errorf("The JWT was not valid.") t.FailNow() diff --git a/jwks.go b/jwks.go index 1b2dffc..39b2618 100644 --- a/jwks.go +++ b/jwks.go @@ -59,8 +59,6 @@ type rawJWKS struct { // NewJSON creates a new JWKS from a raw JSON message. func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) { - - // Turn the raw JWKS into the correct Go type. var rawKS rawJWKS err = json.Unmarshal(jwksBytes, &rawKS) if err != nil { @@ -72,8 +70,6 @@ func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) { keys: make(map[string]interface{}, len(rawKS.Keys)), } for _, key := range rawKS.Keys { - - // Determine the key's type and create the appropriate public key. var keyInter interface{} switch keyType := key.Type; keyType { case ktyEC: @@ -141,19 +137,12 @@ func (j *JWKS) ReadOnlyKeys() map[string]interface{} { // getKey gets the jsonWebKey from the given KID from the JWKS. It may refresh the JWKS if configured to. func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) { - - // Get the jsonWebKey from the JWKS. j.mux.RLock() jsonKey, ok := j.keys[kid] j.mux.RUnlock() - // Check if the key was present. if !ok { - - // Check to see if configured to refresh on unknown kid. if j.refreshUnknownKID { - - // Create a context for refreshing the JWKS. ctx, cancel := context.WithCancel(j.ctx) // Refresh the JWKS. @@ -162,7 +151,6 @@ func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) { return case j.refreshRequests <- cancel: default: - // If the j.refreshRequests channel is full, return the error early. return nil, ErrKIDNotFound } @@ -170,11 +158,8 @@ func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) { // Wait for the JWKS refresh to finish. <-ctx.Done() - // Lock the JWKS for async safe use. j.mux.RLock() defer j.mux.RUnlock() - - // Check if the JWKS refresh contained the requested key. if jsonKey, ok = j.keys[kid]; ok { return jsonKey, nil } diff --git a/jwks_test.go b/jwks_test.go index f29a622..5a4a435 100644 --- a/jwks_test.go +++ b/jwks_test.go @@ -20,7 +20,6 @@ import ( ) const ( - // emptyJWKSJSON is a hard-coded empty JWKS in JSON format. emptyJWKSJSON = `{"keys":[]}` @@ -34,8 +33,6 @@ const ( // TestInvalidServer performs initialization + refresh initialization with a server providing invalid data. // The test ensures that background refresh goroutine does not cause any trouble in case of init failure. func TestInvalidServer(t *testing.T) { - - // Create the HTTP test server. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write(nil) if err != nil { @@ -45,20 +42,17 @@ func TestInvalidServer(t *testing.T) { })) defer server.Close() - // Create testing options. testingRefreshErrorHandler := func(err error) { t.Errorf("Unhandled JWKS error: %s", err.Error()) t.FailNow() } - // Set the options to refresh KID when unknown. refreshInterval := time.Second options := keyfunc.Options{ RefreshInterval: refreshInterval, RefreshErrorHandler: testingRefreshErrorHandler, } - // Create the JWKS. _, err := keyfunc.Get(server.URL, options) if err == nil { t.Errorf("Creation of *keyfunc.JWKS with invalid server must fail.") @@ -68,8 +62,6 @@ func TestInvalidServer(t *testing.T) { // TestJWKS performs a table test on the JWKS code. func TestJWKS(t *testing.T) { - - // Create a temporary directory to serve the JWKS from. tempDir, err := ioutil.TempDir("", "*") if err != nil { t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error()) @@ -83,21 +75,17 @@ func TestJWKS(t *testing.T) { } }() - // Create the JWKS file path. jwksFile := filepath.Join(tempDir, jwksFilePath) - // Write the JWKS. err = ioutil.WriteFile(jwksFile, []byte(jwksJSON), 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) t.FailNow() } - // Create the HTTP test server. server := httptest.NewServer(http.FileServer(http.Dir(tempDir))) defer server.Close() - // Create testing options. testingRefreshInterval := time.Second testingRateLimit := time.Millisecond * 500 testingRefreshTimeout := time.Second @@ -105,10 +93,8 @@ func TestJWKS(t *testing.T) { panic(fmt.Sprintf("Unhandled JWKS error: %s", err.Error())) } - // Set the JWKS URL. jwksURL := server.URL + jwksFilePath - // Create a table of options to test. options := []keyfunc.Options{ {}, // Default options. { @@ -131,17 +117,13 @@ func TestJWKS(t *testing.T) { }, } - // Iterate through all options. for _, opts := range options { - - // Create the JWKS from the resource at the testing URL. jwks, err := keyfunc.Get(jwksURL, opts) if err != nil { t.Errorf("Failed to get JWKS from testing URL.\nError: %s", err.Error()) t.FailNow() } - // Create the test cases. testCases := []struct { token string }{ @@ -161,12 +143,10 @@ func TestJWKS(t *testing.T) { {"eyJhbGciOiJIUzI1NiIsImtpZCI6ImhtYWMiLCJrdHkiOiJvY3QiLCJ0eXAiOiJKV1QifQ.e30.vZ8H2-9j1pDXLNL2GFKbZOkC2qyA0dr7AiTJpNjgLcY"}, // HMAC } - // Wait for the interval to pass, if required. if opts.RefreshInterval != 0 { time.Sleep(opts.RefreshInterval) } - // Iterate through the test cases. for _, tc := range testCases { t.Run(fmt.Sprintf("token: %s", tc.token), func(t *testing.T) { @@ -184,22 +164,18 @@ func TestJWKS(t *testing.T) { }) } - // End the background goroutine. jwks.EndBackground() } } // TestKIDs confirms the JWKS.KIDs returns the key IDs (`kid`) stored in the JWKS. func TestJWKS_KIDs(t *testing.T) { - - // Create the JWKS from JSON. jwks, err := keyfunc.NewJSON([]byte(jwksJSON)) if err != nil { t.Errorf("Failed to create a JWKS from JSON.\nError: %s", err.Error()) t.FailNow() } - // The expected key IDs. expectedKIDs := []string{ "zXew0UJ1h6Q4CCcd_9wxMzvcp5cEBifH0KWrCz2Kyxc", "ebJxnm9B3QDBljB5XJWEu72qx6BawDaMAhwz4aKPkQ0", @@ -214,10 +190,8 @@ func TestJWKS_KIDs(t *testing.T) { "hmac", } - // Get all key IDs in the JWKS. actual := jwks.KIDs() - // Confirm the length is the same. actualLen := len(actual) expectedLen := len(expectedKIDs) if actualLen != expectedLen { @@ -225,7 +199,6 @@ func TestJWKS_KIDs(t *testing.T) { t.FailNow() } - // Confirm all expected keys are present. for _, expectedKID := range expectedKIDs { found := false for _, kid := range actual { @@ -242,8 +215,6 @@ func TestJWKS_KIDs(t *testing.T) { // TestRateLimit performs a test to confirm the rate limiter works as expected. func TestRateLimit(t *testing.T) { - - // Create a temporary directory to serve the JWKS from. tempDir, err := ioutil.TempDir("", "*") if err != nil { t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error()) @@ -257,19 +228,14 @@ func TestRateLimit(t *testing.T) { } }() - // Create an integer to keep track of how many times the JWKS has been refreshed. refreshes := uint(0) refreshMux := sync.Mutex{} - // Create the HTTP test server. server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - - // Increment the number of refreshes that have occurred. refreshMux.Lock() refreshes++ refreshMux.Unlock() - // Write the JWKS to the response, regardless of the request. writer.WriteHeader(200) if _, serveErr := writer.Write([]byte(jwksJSON)); serveErr != nil { t.Errorf("Failed to serve JWKS.\nError: %s", err.Error()) @@ -277,10 +243,8 @@ func TestRateLimit(t *testing.T) { })) defer server.Close() - // Set the JWKS URL. jwksURL := server.URL + jwksFilePath - // Create the testing options. refreshInterval := time.Second refreshRateLimit := time.Millisecond * 500 refreshTimeout := time.Second @@ -294,7 +258,6 @@ func TestRateLimit(t *testing.T) { RefreshUnknownKID: true, } - // Create the JWKS. jwks, err := keyfunc.Get(jwksURL, options) if err != nil { t.Errorf("Failed to create *keyfunc.JWKS.\nError: %s", err.Error()) @@ -381,8 +344,6 @@ func TestRateLimit(t *testing.T) { // TestUnknownKIDRefresh performs a test to confirm that an Unknown kid with refresh the JWKS. func TestUnknownKIDRefresh(t *testing.T) { - - // Create a temporary directory to serve the JWKS from. tempDir, err := ioutil.TempDir("", "*") if err != nil { t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error()) @@ -396,36 +357,29 @@ func TestUnknownKIDRefresh(t *testing.T) { } }() - // Create the JWKS file path. jwksFile := filepath.Join(tempDir, strings.TrimPrefix(jwksFilePath, "/")) - // Write the empty JWKS. err = ioutil.WriteFile(jwksFile, []byte(emptyJWKSJSON), 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) t.FailNow() } - // Create the HTTP test server. server := httptest.NewServer(http.FileServer(http.Dir(tempDir))) defer server.Close() - // Create testing options. testingRefreshErrorHandler := func(err error) { t.Errorf("Unhandled JWKS error: %s", err.Error()) t.FailNow() } - // Set the JWKS URL. jwksURL := server.URL + jwksFilePath - // Set the options to refresh KID when unknown. options := keyfunc.Options{ RefreshErrorHandler: testingRefreshErrorHandler, RefreshUnknownKID: true, } - // Create the JWKS. jwks, err := keyfunc.Get(jwksURL, options) if err != nil { t.Errorf("Failed to create *keyfunc.JWKS.\nError: %s", err.Error()) @@ -433,7 +387,6 @@ func TestUnknownKIDRefresh(t *testing.T) { } defer jwks.EndBackground() - // Write the populated JWKS. err = ioutil.WriteFile(jwksFile, []byte(jwksJSON), 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) diff --git a/keyfunc.go b/keyfunc.go index 05800c2..703c92f 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -16,8 +16,6 @@ var ( // Keyfunc is a compatibility function that matches the signature of github.com/golang-jwt/jwt/v4's jwt.Keyfunc // function. func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { - - // Get the kid from the token header. kidInter, ok := token.Header["kid"] if !ok { return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKID) @@ -27,6 +25,5 @@ func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) { return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID) } - // Get the Go type for the correct cryptographic key. return j.getKey(kid) } diff --git a/oct.go b/oct.go index 0f4dcb8..26ab527 100644 --- a/oct.go +++ b/oct.go @@ -6,15 +6,12 @@ import ( ) const ( - // ktyOct is the key type (kty) in the JWT header for oct. ktyOct = "oct" ) // Oct parses a jsonWebKey and turns it into a raw byte slice (octet). This includes HMAC keys. func (j *jsonWebKey) Oct() (publicKey []byte, err error) { - - // Confirm everything needed is present. if j.K == "" { return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyOct) } diff --git a/options.go b/options.go index 51961bc..a076190 100644 --- a/options.go +++ b/options.go @@ -62,12 +62,14 @@ func applyOptions(jwks *JWKS, options Options) { if options.Ctx != nil { jwks.ctx, jwks.cancel = context.WithCancel(options.Ctx) } + if options.GivenKeys != nil { jwks.givenKeys = make(map[string]GivenKey) for kid, key := range options.GivenKeys { jwks.givenKeys[kid] = key } } + jwks.client = options.Client jwks.givenKIDOverride = options.GivenKIDOverride jwks.refreshErrorHandler = options.RefreshErrorHandler diff --git a/override_test.go b/override_test.go index d7455fb..7e41eef 100644 --- a/override_test.go +++ b/override_test.go @@ -21,11 +21,7 @@ import ( ) const ( - - // givenKID is the key ID for the given key with a unique ID. - givenKID = "givenKID" - - // remoteKID is the key ID for the remote key and given key that has a conflicting key ID. + givenKID = "givenKID" remoteKID = "remoteKID" ) @@ -44,8 +40,6 @@ type pseudoJSONKey struct { // TestNewGiven tests that given keys will be added to a JWKS with a remote resource. func TestNewGiven(t *testing.T) { - - // Create a temporary directory to serve the JWKS from. tempDir, err := ioutil.TempDir("", "*") if err != nil { t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error()) @@ -59,42 +53,34 @@ func TestNewGiven(t *testing.T) { } }() - // Create the JWKS file path. jwksFile := filepath.Join(tempDir, jwksFilePath) - // Create the keys used for this test. givenKeys, givenPrivateKeys, jwksBytes, remotePrivateKeys, err := keysAndJWKS() if err != nil { t.Errorf("Failed to create cryptographic keys for the test.\nError: %s.", err.Error()) t.FailNow() } - // Write the empty JWKS. err = ioutil.WriteFile(jwksFile, jwksBytes, 0600) if err != nil { t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error()) t.FailNow() } - // Create the HTTP test server. server := httptest.NewServer(http.FileServer(http.Dir(tempDir))) defer server.Close() - // Create testing options. testingRefreshErrorHandler := func(err error) { panic(fmt.Sprintf("Unhandled JWKS error.\nError: %s", err.Error())) } - // Set the JWKS URL. jwksURL := server.URL + jwksFilePath - // Create the test options. options := keyfunc.Options{ GivenKeys: givenKeys, RefreshErrorHandler: testingRefreshErrorHandler, } - // Get the remote JWKS. jwks, err := keyfunc.Get(jwksURL, options) if err != nil { t.Errorf("Failed to get the JWKS the testing URL.\nError: %s", err.Error()) @@ -130,19 +116,15 @@ func TestNewGiven(t *testing.T) { // createSignParseValidate creates, signs, parses, and validates a JWT. func createSignParseValidate(t *testing.T, keys map[string]*rsa.PrivateKey, jwks *keyfunc.JWKS, kid string, shouldValidate bool) { - - // Create the JWT. unsignedToken := jwt.New(jwt.SigningMethodRS256) unsignedToken.Header[kidAttribute] = kid - // Sign the JWT. jwtB64, err := unsignedToken.SignedString(keys[kid]) if err != nil { t.Errorf("Failed to sign the JWT.\nError: %s.", err.Error()) t.FailNow() } - // Parse the JWT. token, err := jwt.Parse(jwtB64, jwks.Keyfunc) if err != nil { if !shouldValidate && !errors.Is(err, rsa.ErrVerification) { @@ -157,7 +139,6 @@ func createSignParseValidate(t *testing.T, keys map[string]*rsa.PrivateKey, jwks t.FailNow() } - // Validate the JWT. if !token.Valid { t.Errorf("The JWT is not valid.") t.FailNow() @@ -166,8 +147,6 @@ func createSignParseValidate(t *testing.T, keys map[string]*rsa.PrivateKey, jwks // keysAndJWKS creates a couple of cryptographic keys and the remote JWKS associated with them. func keysAndJWKS() (givenKeys map[string]keyfunc.GivenKey, givenPrivateKeys map[string]*rsa.PrivateKey, jwksBytes []byte, remotePrivateKeys map[string]*rsa.PrivateKey, err error) { - - // Initialize the function's assets. const rsaErrMessage = "failed to create RSA key: %w" givenKeys = make(map[string]keyfunc.GivenKey) givenPrivateKeys = make(map[string]*rsa.PrivateKey) @@ -194,7 +173,6 @@ func keysAndJWKS() (givenKeys map[string]keyfunc.GivenKey, givenPrivateKeys map[ } remotePrivateKeys[remoteKID] = key3 - // Create a pseudo-JWKS. jwks := pseudoJWKS{Keys: []pseudoJSONKey{{ KID: remoteKID, KTY: "RSA", @@ -202,7 +180,6 @@ func keysAndJWKS() (givenKeys map[string]keyfunc.GivenKey, givenPrivateKeys map[ N: base64.RawURLEncoding.EncodeToString(key3.PublicKey.N.Bytes()), }}} - // Marshal the JWKS to JSON. jwksBytes, err = json.Marshal(jwks) if err != nil { return nil, nil, nil, nil, fmt.Errorf("failed to marshal the JWKS to JSON: %w", err) diff --git a/rsa.go b/rsa.go index 79915e8..7184bf6 100644 --- a/rsa.go +++ b/rsa.go @@ -8,15 +8,12 @@ import ( ) const ( - // ktyRSA is the key type (kty) in the JWT header for RSA. ktyRSA = "RSA" ) // RSA parses a jsonWebKey and turns it into an RSA public key. func (j *jsonWebKey) RSA() (publicKey *rsa.PublicKey, err error) { - - // Confirm everything needed is present. if j.Exponent == "" || j.Modulus == "" { return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyRSA) } @@ -29,14 +26,11 @@ func (j *jsonWebKey) RSA() (publicKey *rsa.PublicKey, err error) { if err != nil { return nil, err } - - // Decode the modulus from Base64. modulus, err := base64.RawURLEncoding.DecodeString(j.Modulus) if err != nil { return nil, err } - // Create the RSA public key. publicKey = &rsa.PublicKey{} // Turn the exponent into an integer. @@ -44,8 +38,6 @@ func (j *jsonWebKey) RSA() (publicKey *rsa.PublicKey, err error) { // According to RFC 7517, these numbers are in big-endian format. // https://tools.ietf.org/html/rfc7517#appendix-A.1 publicKey.E = int(big.NewInt(0).SetBytes(exponent).Uint64()) - - // Turn the modulus into a *big.Int. publicKey.N = big.NewInt(0).SetBytes(modulus) return publicKey, nil