From 24680eb3700dcb812c20352c8d85239722caeb23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20=C3=81lvarez?= <86166683+ralvarezdev@users.noreply.github.com> Date: Sun, 16 Feb 2025 13:21:59 -0400 Subject: [PATCH] refactor: improved jwt authenticator middleware to handle expired cookies --- http/middleware/auth/authenticator.go | 10 +- http/middleware/auth/middleware.go | 182 +++++++++++++++----------- 2 files changed, 109 insertions(+), 83 deletions(-) diff --git a/http/middleware/auth/authenticator.go b/http/middleware/auth/authenticator.go index 8297cca..e96c00b 100644 --- a/http/middleware/auth/authenticator.go +++ b/http/middleware/auth/authenticator.go @@ -15,21 +15,17 @@ type Authenticator interface { err error, errorCode *string, ), - refreshTokenFn func( - w http.ResponseWriter, - r *http.Request, - ) error, - authenticateFn func(next http.Handler) http.Handler, ) func(next http.Handler) http.Handler AuthenticateFromHeader( token gojwttoken.Token, ) func(next http.Handler) http.Handler AuthenticateFromCookie( token gojwttoken.Token, - cookieName string, + cookieRefreshTokenName, + cookieAccessTokenName string, refreshTokenFn func( w http.ResponseWriter, r *http.Request, - ) error, + ) (*map[gojwttoken.Token]string, error), ) func(next http.Handler) http.Handler } diff --git a/http/middleware/auth/middleware.go b/http/middleware/auth/middleware.go index 0f53c36..c2c2324 100644 --- a/http/middleware/auth/middleware.go +++ b/http/middleware/auth/middleware.go @@ -2,7 +2,6 @@ package auth import ( "errors" - "github.com/golang-jwt/jwt/v5" gojwt "github.com/ralvarezdev/go-jwt" gojwtnethttp "github.com/ralvarezdev/go-jwt/net/http" gojwtnethttpctx "github.com/ralvarezdev/go-jwt/net/http/context" @@ -49,11 +48,6 @@ func (m *Middleware) Authenticate( err error, errorCode *string, ), - refreshTokenFn func( - w http.ResponseWriter, - r *http.Request, - ) error, - authenticateFn func(next http.Handler) http.Handler, ) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc( @@ -64,33 +58,10 @@ func (m *Middleware) Authenticate( token, ) if err != nil { - // Check if the error is a token expired error - if token == gojwttoken.RefreshToken || !errors.Is( - err, - jwt.ErrTokenExpired, - ) || refreshTokenFn == nil { - failHandler( - w, - err, - ErrCodeInvalidTokenClaims, - ) - return - } - - // Refresh the token - if err = refreshTokenFn(w, r); err != nil { - failHandler( - w, - err, - ErrCodeFailedToRefreshToken, - ) - return - } - - // Authenticate again - authenticateFn(next).ServeHTTP( + failHandler( w, - r, + err, + ErrCodeInvalidTokenClaims, ) return } @@ -153,8 +124,6 @@ func (m *Middleware) AuthenticateFromHeader( token, rawToken, failHandler, - nil, - nil, )(next).ServeHTTP( w, r, @@ -167,64 +136,125 @@ func (m *Middleware) AuthenticateFromHeader( // AuthenticateFromCookie return the middleware function that authenticates the request from the cookie func (m *Middleware) AuthenticateFromCookie( token gojwttoken.Token, - cookieName string, + cookieRefreshTokenName, + cookieAccessTokenName string, refreshTokenFn func( w http.ResponseWriter, r *http.Request, - ) error, + ) (*map[gojwttoken.Token]string, error), ) func(next http.Handler) http.Handler { + var cookieName string + if token == gojwttoken.AccessToken { + cookieName = cookieAccessTokenName + } else if token == gojwttoken.RefreshToken { + cookieName = cookieRefreshTokenName + } + // Create the fail handler function failHandler := func( + cookieName string, + ) func( w http.ResponseWriter, err error, errorCode *string, ) { - m.handler.HandleError( - w, - gonethttpresponse.NewFailResponseError( - cookieName, - err.Error(), - errorCode, - http.StatusUnauthorized, - ), - ) + return func( + w http.ResponseWriter, + err error, + errorCode *string, + ) { + m.handler.HandleError( + w, + gonethttpresponse.NewFailResponseError( + cookieName, + err.Error(), + errorCode, + http.StatusUnauthorized, + ), + ) + } } // Create the authenticate function - var authenticateFn func(next http.Handler) http.Handler - authenticateFn = func(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - // Get the cookie - cookie, err := r.Cookie(cookieName) + var authenticateFn func(*map[gojwttoken.Token]string) func(next http.Handler) http.Handler + authenticateFn = func(rawTokens *map[gojwttoken.Token]string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + var rawToken string + var cookie *http.Cookie + var err error + var ok bool - // Return an error if the cookie is missing - if err != nil { - failHandler( - w, - gonethttp.ErrCookieNotFound, - gonethttp.ErrCodeCookieNotFound, - ) - return - } + // Get the cookie + if rawTokens != nil { + // Get the raw token from the map + rawToken, ok = (*rawTokens)[token] - // Get the raw token from the cookie - rawToken := cookie.Value + // Return an error if the token is missing + if !ok { + failHandler(cookieName)( + w, + gonethttp.ErrCookieNotFound, + gonethttp.ErrCodeCookieNotFound, + ) + return + } + } else { + // Get the cookie from the request + cookie, err = r.Cookie(cookieAccessTokenName) - // Call the Authenticate function - m.Authenticate( - token, - rawToken, - failHandler, - refreshTokenFn, - authenticateFn, - )(next).ServeHTTP( - w, - r, - ) - }, - ) + // Check if there was an error getting the cookie + if err == nil { + // Get the raw token from the cookie + rawToken = cookie.Value + } else if errors.Is(err, http.ErrNoCookie) { + // Check if the token can be refreshed + if token == gojwttoken.AccessToken && refreshTokenFn != nil { + // Refresh the token + rawTokens, err = refreshTokenFn(w, r) + if err != nil { + failHandler(cookieRefreshTokenName)( + w, + err, + ErrCodeFailedToRefreshToken, + ) + return + } + + // Authenticate again + authenticateFn(rawTokens)(next).ServeHTTP( + w, + r, + ) + return + } + } + } + + // Check if the raw token is empty + if rawToken == "" { + failHandler(cookieAccessTokenName)( + w, + gonethttp.ErrCookieNotFound, + gonethttp.ErrCodeCookieNotFound, + ) + return + } + + // Call the Authenticate function + m.Authenticate( + token, + rawToken, + failHandler, + )(next).ServeHTTP( + w, + r, + ) + }, + ) + } } - return authenticateFn + return authenticateFn(nil) }