Skip to content

Commit

Permalink
refactor: improved jwt authenticator middleware to handle expired coo…
Browse files Browse the repository at this point in the history
…kies
  • Loading branch information
ralvarezdev committed Feb 16, 2025
1 parent db022d2 commit 24680eb
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 83 deletions.
10 changes: 3 additions & 7 deletions http/middleware/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@ type Authenticator interface {
err error,
errorCode *string,
),
refreshTokenFn func(
w http.ResponseWriter,
r *http.Request,
) error,
authenticateFn func(next http.Handler) http.Handler,
) func(next http.Handler) http.Handler
AuthenticateFromHeader(
token gojwttoken.Token,
) func(next http.Handler) http.Handler
AuthenticateFromCookie(
token gojwttoken.Token,
cookieName string,
cookieRefreshTokenName,
cookieAccessTokenName string,
refreshTokenFn func(
w http.ResponseWriter,
r *http.Request,
) error,
) (*map[gojwttoken.Token]string, error),
) func(next http.Handler) http.Handler
}
182 changes: 106 additions & 76 deletions http/middleware/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package auth

import (
"errors"
"github.com/golang-jwt/jwt/v5"
gojwt "github.com/ralvarezdev/go-jwt"
gojwtnethttp "github.com/ralvarezdev/go-jwt/net/http"
gojwtnethttpctx "github.com/ralvarezdev/go-jwt/net/http/context"
Expand Down Expand Up @@ -49,11 +48,6 @@ func (m *Middleware) Authenticate(
err error,
errorCode *string,
),
refreshTokenFn func(
w http.ResponseWriter,
r *http.Request,
) error,
authenticateFn func(next http.Handler) http.Handler,
) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(
Expand All @@ -64,33 +58,10 @@ func (m *Middleware) Authenticate(
token,
)
if err != nil {
// Check if the error is a token expired error
if token == gojwttoken.RefreshToken || !errors.Is(
err,
jwt.ErrTokenExpired,
) || refreshTokenFn == nil {
failHandler(
w,
err,
ErrCodeInvalidTokenClaims,
)
return
}

// Refresh the token
if err = refreshTokenFn(w, r); err != nil {
failHandler(
w,
err,
ErrCodeFailedToRefreshToken,
)
return
}

// Authenticate again
authenticateFn(next).ServeHTTP(
failHandler(
w,
r,
err,
ErrCodeInvalidTokenClaims,
)
return
}
Expand Down Expand Up @@ -153,8 +124,6 @@ func (m *Middleware) AuthenticateFromHeader(
token,
rawToken,
failHandler,
nil,
nil,
)(next).ServeHTTP(
w,
r,
Expand All @@ -167,64 +136,125 @@ func (m *Middleware) AuthenticateFromHeader(
// AuthenticateFromCookie return the middleware function that authenticates the request from the cookie
func (m *Middleware) AuthenticateFromCookie(
token gojwttoken.Token,
cookieName string,
cookieRefreshTokenName,
cookieAccessTokenName string,
refreshTokenFn func(
w http.ResponseWriter,
r *http.Request,
) error,
) (*map[gojwttoken.Token]string, error),
) func(next http.Handler) http.Handler {
var cookieName string
if token == gojwttoken.AccessToken {
cookieName = cookieAccessTokenName
} else if token == gojwttoken.RefreshToken {
cookieName = cookieRefreshTokenName
}

// Create the fail handler function
failHandler := func(
cookieName string,
) func(
w http.ResponseWriter,
err error,
errorCode *string,
) {
m.handler.HandleError(
w,
gonethttpresponse.NewFailResponseError(
cookieName,
err.Error(),
errorCode,
http.StatusUnauthorized,
),
)
return func(
w http.ResponseWriter,
err error,
errorCode *string,
) {
m.handler.HandleError(
w,
gonethttpresponse.NewFailResponseError(
cookieName,
err.Error(),
errorCode,
http.StatusUnauthorized,
),
)
}
}

// Create the authenticate function
var authenticateFn func(next http.Handler) http.Handler
authenticateFn = func(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
// Get the cookie
cookie, err := r.Cookie(cookieName)
var authenticateFn func(*map[gojwttoken.Token]string) func(next http.Handler) http.Handler
authenticateFn = func(rawTokens *map[gojwttoken.Token]string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
var rawToken string
var cookie *http.Cookie
var err error
var ok bool

// Return an error if the cookie is missing
if err != nil {
failHandler(
w,
gonethttp.ErrCookieNotFound,
gonethttp.ErrCodeCookieNotFound,
)
return
}
// Get the cookie
if rawTokens != nil {
// Get the raw token from the map
rawToken, ok = (*rawTokens)[token]

// Get the raw token from the cookie
rawToken := cookie.Value
// Return an error if the token is missing
if !ok {
failHandler(cookieName)(
w,
gonethttp.ErrCookieNotFound,
gonethttp.ErrCodeCookieNotFound,
)
return
}
} else {
// Get the cookie from the request
cookie, err = r.Cookie(cookieAccessTokenName)

// Call the Authenticate function
m.Authenticate(
token,
rawToken,
failHandler,
refreshTokenFn,
authenticateFn,
)(next).ServeHTTP(
w,
r,
)
},
)
// Check if there was an error getting the cookie
if err == nil {
// Get the raw token from the cookie
rawToken = cookie.Value
} else if errors.Is(err, http.ErrNoCookie) {
// Check if the token can be refreshed
if token == gojwttoken.AccessToken && refreshTokenFn != nil {
// Refresh the token
rawTokens, err = refreshTokenFn(w, r)
if err != nil {
failHandler(cookieRefreshTokenName)(
w,
err,
ErrCodeFailedToRefreshToken,
)
return
}

// Authenticate again
authenticateFn(rawTokens)(next).ServeHTTP(
w,
r,
)
return
}
}
}

// Check if the raw token is empty
if rawToken == "" {
failHandler(cookieAccessTokenName)(
w,
gonethttp.ErrCookieNotFound,
gonethttp.ErrCodeCookieNotFound,
)
return
}

// Call the Authenticate function
m.Authenticate(
token,
rawToken,
failHandler,
)(next).ServeHTTP(
w,
r,
)
},
)
}
}

return authenticateFn
return authenticateFn(nil)
}

0 comments on commit 24680eb

Please sign in to comment.