Skip to content

Commit

Permalink
Remove comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Micah Parks committed Apr 6, 2022
1 parent 81dfc4f commit 16db93f
Show file tree
Hide file tree
Showing 13 changed files with 7 additions and 218 deletions.
13 changes: 2 additions & 11 deletions checksum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ import (
"github.com/MicahParks/keyfunc"
)

// TestChecksum confirms that the JWKS will only perform a refresh if a new JWKS is read from the remote resource.
func TestChecksum(t *testing.T) {

// Create a temporary directory to serve the JWKS from.
tempDir, err := ioutil.TempDir("", "*")
if err != nil {
t.Errorf("Failed to create a temporary directory.\nError: %s", err.Error())
Expand All @@ -31,21 +30,17 @@ func TestChecksum(t *testing.T) {
}
}()

// Create the JWKS file path.
jwksFile := filepath.Join(tempDir, jwksFilePath)

// Write the JWKS.
err = ioutil.WriteFile(jwksFile, []byte(jwksJSON), 0600)
if err != nil {
t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error())
t.FailNow()
}

// Create the HTTP test server.
server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
defer server.Close()

// Create testing options.
testingRefreshErrorHandler := func(err error) {
panic(fmt.Sprintf("Unhandled JWKS error: %s", err.Error()))
}
Expand All @@ -54,7 +49,6 @@ func TestChecksum(t *testing.T) {
RefreshUnknownKID: true,
}

// Set the JWKS URL.
jwksURL := server.URL + jwksFilePath

jwks, err := keyfunc.Get(jwksURL, opts)
Expand All @@ -64,7 +58,6 @@ func TestChecksum(t *testing.T) {
}
defer jwks.EndBackground()

// Get a map of all interface pointers for the JWKS.
cryptoKeyPointers := make(map[string]interface{})
for kid, cryptoKey := range jwks.ReadOnlyKeys() {
cryptoKeyPointers[kid] = cryptoKey
Expand Down Expand Up @@ -95,14 +88,12 @@ func TestChecksum(t *testing.T) {
}
}

// Write a new JWKS to the test file.
// Write a different JWKS.
_, _, jwksBytes, _, err := keysAndJWKS()
if err != nil {
t.Errorf("Failed to create a test JWKS.\nError: %s", err.Error())
t.FailNow()
}

// Write a different JWKS.
err = ioutil.WriteFile(jwksFile, jwksBytes, 0600)
if err != nil {
t.Errorf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error())
Expand Down
10 changes: 0 additions & 10 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
)

const (

// ktyEC is the key type (kty) in the JWT header for ECDSA.
ktyEC = "EC"

Expand All @@ -25,8 +24,6 @@ const (

// ECDSA parses a jsonWebKey and turns it into an ECDSA public key.
func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {

// Confirm everything needed is present.
if j.X == "" || j.Y == "" || j.Curve == "" {
return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyEC)
}
Expand All @@ -39,17 +36,12 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
if err != nil {
return nil, err
}

// Decode the Y coordinate from Base64.
yCoordinate, err := base64.RawURLEncoding.DecodeString(j.Y)
if err != nil {
return nil, err
}

// Create the ECDSA public key.
publicKey = &ecdsa.PublicKey{}

// Set the curve type.
switch j.Curve {
case p256:
publicKey.Curve = elliptic.P256()
Expand All @@ -64,8 +56,6 @@ func (j *jsonWebKey) ECDSA() (publicKey *ecdsa.PublicKey, err error) {
// According to RFC 7517, these numbers are in big-endian format.
// https://tools.ietf.org/html/rfc7517#appendix-A.1
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)

// Turn the Y coordinate into a *big.Int.
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)

return publicKey, nil
Expand Down
3 changes: 0 additions & 3 deletions eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@ import (
)

const (

// ktyEC is the key type (kty) in the JWT header for EdDSA.
ktyOKP = "OKP"
)

// EdDSA parses a jsonWebKey and turns it into a EdDSA public key.
func (j *jsonWebKey) EdDSA() (publicKey ed25519.PublicKey, err error) {

// Confirm everything needed is present.
if j.X == "" {
return nil, fmt.Errorf("%w: %s", ErrMissingAssets, ktyOKP)
}
Expand Down
49 changes: 2 additions & 47 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,39 +18,27 @@ var (

// Get loads the JWKS at the given URL.
func Get(jwksURL string, options Options) (jwks *JWKS, err error) {

// Create the JWKS.
jwks = &JWKS{
jwksURL: jwksURL,
}

// Apply the options to the JWKS.
applyOptions(jwks, options)

// Apply some defaults if options were not provided.
if jwks.client == nil {
jwks.client = http.DefaultClient
}
if jwks.refreshTimeout == 0 {
jwks.refreshTimeout = defaultRefreshTimeout
}

// Get the keys for the JWKS.
err = jwks.refresh()
if err != nil {
return nil, err
}

// Check to see if a background refresh of the JWKS should happen.
if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {

// Attach a context used to end the background goroutine.
jwks.ctx, jwks.cancel = context.WithCancel(context.Background())

// Create a channel that will accept requests to refresh the JWKS.
jwks.refreshRequests = make(chan context.CancelFunc, 1)

// Start the background goroutine for data refresh.
go jwks.backgroundRefresh()
}

Expand All @@ -60,8 +48,6 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
// backgroundRefresh is meant to be a separate goroutine that will update the keys in a JWKS over a given interval of
// time.
func (j *JWKS) backgroundRefresh() {

// Create some rate limiting assets.
var lastRefresh time.Time
var queueOnce sync.Once
var refreshMux sync.Mutex
Expand All @@ -74,16 +60,11 @@ func (j *JWKS) backgroundRefresh() {

// Enter an infinite loop that ends when the background ends.
for {

// If there is a refresh interval, create the channel for it.
if j.refreshInterval != 0 {
refreshInterval = time.After(j.refreshInterval)
}

// Wait for a refresh to occur or the background to end.
select {

// Send a refresh request the JWKS after the given interval.
case <-refreshInterval:
select {
case <-j.ctx.Done():
Expand All @@ -92,23 +73,16 @@ func (j *JWKS) backgroundRefresh() {
default: // If the j.refreshRequests channel is full, don't send another request.
}

// Accept refresh requests.
case cancel := <-j.refreshRequests:

// Rate limit, if needed.
refreshMux.Lock()
if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {

// Don't make the JWT parsing goroutine wait for the JWKS to refresh.
cancel()

// Only queue a refresh once.
// Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return.
queueOnce.Do(func() {

// Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return.
go func() {

// Wait for the next time to refresh.
refreshMux.Lock()
wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
refreshMux.Unlock()
Expand All @@ -118,30 +92,23 @@ func (j *JWKS) backgroundRefresh() {
case <-time.After(wait):
}

// Refresh the JWKS.
refreshMux.Lock()
defer refreshMux.Unlock()
err := j.refresh()
if err != nil && j.refreshErrorHandler != nil {
j.refreshErrorHandler(err)
}

// Reset the last time for the refresh to now.
lastRefresh = time.Now()

// Allow another queue.
queueOnce = sync.Once{}
}()
})
} else {

// Refresh the JWKS.
err := j.refresh()
if err != nil && j.refreshErrorHandler != nil {
j.refreshErrorHandler(err)
}

// Reset the last time for the refresh to now.
lastRefresh = time.Now()

// Allow the JWT parsing goroutine to continue with the refreshed JWKS.
Expand All @@ -158,8 +125,6 @@ func (j *JWKS) backgroundRefresh() {

// refresh does an HTTP GET on the JWKS URL to rebuild the JWKS.
func (j *JWKS) refresh() (err error) {

// Create a context for the request.
var ctx context.Context
var cancel context.CancelFunc
if j.ctx != nil {
Expand All @@ -169,21 +134,18 @@ func (j *JWKS) refresh() (err error) {
}
defer cancel()

// Create the HTTP request.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.jwksURL, bytes.NewReader(nil))
if err != nil {
return err
}

// Get the JWKS as JSON from the given URL.
resp, err := j.client.Do(req)
if err != nil {
return err
}
//goland:noinspection GoUnhandledErrorResult
defer resp.Body.Close()

// Read the raw JWKS from the body of the response.
jwksBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
Expand All @@ -195,32 +157,25 @@ func (j *JWKS) refresh() (err error) {
}
j.raw = jwksBytes

// Create an updated JWKS.
updated, err := NewJSON(jwksBytes)
if err != nil {
return err
}

// Lock the JWKS for async safe usage.
j.mux.Lock()
defer j.mux.Unlock()

// Update the keys.
j.keys = updated.keys

// If given keys were provided, add them back into the refreshed JWKS.
var ok bool
if j.givenKeys != nil {
for kid, key := range j.givenKeys {

// Only overwrite the key if configured to do so.
if !j.givenKIDOverride {
if _, ok = j.keys[kid]; ok {
if _, ok := j.keys[kid]; ok {
continue
}
}

// Write the given key to the JWKS.
j.keys[kid] = key.inter
}
}
Expand Down
4 changes: 0 additions & 4 deletions given.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ type GivenKey struct {

// NewGiven creates a JWKS from a map of given keys.
func NewGiven(givenKeys map[string]GivenKey) (jwks *JWKS) {

// Initialize the map of kid to cryptographic keys.
keys := make(map[string]interface{})

// Copy the given keys to the map of cryptographic keys.
for kid, given := range givenKeys {
keys[kid] = given.inter
}

// Return a JWKS with the map of cryptographic keys.
return &JWKS{
keys: keys,
}
Expand Down
Loading

0 comments on commit 16db93f

Please sign in to comment.