Skip to content

Commit

Permalink
Refresh JWKS in background when the cached JWKS expires rather than b…
Browse files Browse the repository at this point in the history
…lock for refresh

Previously, whenever the cached JWKS expired there would be a block until the request has completed
and the cache has been updated. Now, when the JWKS cache expires the refresh is triggered in the
background and the cached JWKS will continue to be returned until the background refresh has been
completed. If that background refresh succeeds then the cache will be updated, if it errors then
the cached JWKS will be deleted and on the next call of KeyFunc the JWKS refresh will be attempted
again, which (if the error persists) would raise the error.

This now means that the only time KeyFunc will block is in the instance where there is no JWKS for
an issuer, when a cached JWKS expires we will no longer block
  • Loading branch information
ewanharris committed Dec 1, 2023
1 parent 34d8362 commit 3cff0b1
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 10 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.19
require (
github.com/google/go-cmp v0.6.0
github.com/stretchr/testify v1.8.4
golang.org/x/sync v0.5.0
gopkg.in/go-jose/go-jose.v2 v2.6.1
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-jose/go-jose.v2 v2.6.1 h1:qEzJlIDmG9q5VO0M/o8tGS65QMHMS1w01TQJB1VPJ4U=
Expand Down
25 changes: 22 additions & 3 deletions jwks/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"golang.org/x/sync/semaphore"
"gopkg.in/go-jose/go-jose.v2"

"github.com/auth0/go-jwt-middleware/v2/internal/oidc"
Expand Down Expand Up @@ -97,11 +98,16 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
// CachingProvider handles getting JWKS from the specified IssuerURL
// and caching them for CacheTTL time. It exposes KeyFunc which adheres
// to the keyFunc signature that the Validator requires.
// When the CacheTTL value has been reached, a JWKS refresh will be triggered
// in the background and the existing cached JWKS will be returned until the
// JWKS cache is updated, or if the request errors then it will be evicted from
// the cache.
type CachingProvider struct {
*Provider
CacheTTL time.Duration
mu sync.RWMutex
cache map[string]cachedJWKS
sem semaphore.Weighted
}

type cachedJWKS struct {
Expand All @@ -120,6 +126,7 @@ func NewCachingProvider(issuerURL *url.URL, cacheTTL time.Duration, opts ...Prov
Provider: NewProvider(issuerURL, opts...),
CacheTTL: cacheTTL,
cache: map[string]cachedJWKS{},
sem: *semaphore.NewWeighted(1),
}
}

Expand All @@ -132,10 +139,22 @@ func (c *CachingProvider) KeyFunc(ctx context.Context) (interface{}, error) {
issuer := c.IssuerURL.Hostname()

if cached, ok := c.cache[issuer]; ok {
if !time.Now().After(cached.expiresAt) {
c.mu.RUnlock()
return cached.jwks, nil
if time.Now().After(cached.expiresAt) && c.sem.TryAcquire(1) {
go func() {
defer c.sem.Release(1)
refreshCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.refreshKey(refreshCtx, issuer)

if err != nil {
c.mu.Lock()
delete(c.cache, issuer)
c.mu.Unlock()
}
}()
}
c.mu.RUnlock()
return cached.jwks, nil
}

c.mu.RUnlock()
Expand Down
100 changes: 93 additions & 7 deletions jwks/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/go-jose/go-jose.v2"

Expand Down Expand Up @@ -84,7 +85,8 @@ func Test_JWKSProvider(t *testing.T) {
}
})

t.Run("It re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) {
t.Run("It eventually re-caches the JWKS if they have expired when using CachingProvider", func(t *testing.T) {
requestCount = 0
expiredCachedJWKS, err := generateJWKS()
require.NoError(t, err)

Expand All @@ -94,16 +96,20 @@ func Test_JWKSProvider(t *testing.T) {
expiresAt: time.Now().Add(-10 * time.Minute),
}

actualJWKS, err := provider.KeyFunc(context.Background())
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

if !cmp.Equal(expectedJWKS, actualJWKS) {
t.Fatalf("jwks did not match: %s", cmp.Diff(expectedJWKS, actualJWKS))
if !cmp.Equal(expiredCachedJWKS, returnedJWKS) {
t.Fatalf("jwks did not match: %s", cmp.Diff(expiredCachedJWKS, returnedJWKS))
}

if !cmp.Equal(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks) {
t.Fatalf("cached jwks did not match: %s", cmp.Diff(expectedJWKS, provider.cache[testServerURL.Hostname()].jwks))
}
require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS))
assert.Equal(c, int32(2), requestCount)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not update")

cacheExpiresAt := provider.cache[testServerURL.Hostname()].expiresAt
if !time.Now().Before(cacheExpiresAt) {
Expand Down Expand Up @@ -154,6 +160,86 @@ func Test_JWKSProvider(t *testing.T) {
}
},
)

t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with expired cache", func(t *testing.T) {
initialJWKS, err := generateJWKS()
require.NoError(t, err)
requestCount = 0

provider := NewCachingProvider(testServerURL, 5*time.Minute)
provider.cache[testServerURL.Hostname()] = cachedJWKS{
jwks: initialJWKS,
expiresAt: time.Now(),
}

var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
_, _ = provider.KeyFunc(context.Background())
wg.Done()
}()
}
wg.Wait()

require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)

assert.True(c, cmp.Equal(expectedJWKS, returnedJWKS))
assert.Equal(c, int32(2), requestCount)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not update")
})

t.Run("It only calls the API once when multiple requests come in when using the CachingProvider with no cache", func(t *testing.T) {
provider := NewCachingProvider(testServerURL, 5*time.Minute)
requestCount = 0

var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
_, _ = provider.KeyFunc(context.Background())
wg.Done()
}()
}
wg.Wait()

if requestCount != 2 {
t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", requestCount)
}
})

t.Run("Should delete cache entry if the refresh request fails", func(t *testing.T) {
malformedURL, err := url.Parse(testServer.URL + "/malformed")
require.NoError(t, err)

expiredCachedJWKS, err := generateJWKS()
require.NoError(t, err)

provider := NewCachingProvider(malformedURL, 5*time.Minute)
provider.cache[malformedURL.Hostname()] = cachedJWKS{
jwks: expiredCachedJWKS,
expiresAt: time.Now().Add(-10 * time.Minute),
}

// Trigger the refresh of the JWKS, which should return the cached JWKS
returnedJWKS, err := provider.KeyFunc(context.Background())
require.NoError(t, err)
assert.Equal(t, expiredCachedJWKS, returnedJWKS)

// Eventually it should return a nil JWKS
require.EventuallyWithT(t, func(c *assert.CollectT) {
returnedJWKS, err := provider.KeyFunc(context.Background())
require.Error(t, err)

assert.Nil(c, returnedJWKS)

cachedJWKS := provider.cache[malformedURL.Hostname()].jwks

assert.Nil(t, cachedJWKS)
}, 1*time.Second, 250*time.Millisecond, "JWKS did not get uncached")
})
}

func generateJWKS() (*jose.JSONWebKeySet, error) {
Expand Down

0 comments on commit 3cff0b1

Please sign in to comment.