Skip to content

Commit

Permalink
refactor(openid/client): remove indirection layer for login callback
Browse files Browse the repository at this point in the history
Co-authored-by: sindrerh2 <sindre.rodseth.hansen@nav.no>
  • Loading branch information
tronghn and sindrerh2 committed Jan 21, 2025
1 parent 75f98de commit 64e9167
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 205 deletions.
13 changes: 1 addition & 12 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,17 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
return
}

loginCallback, err := s.Client.LoginCallback(r, loginCookie)
tokens, err := s.Client.LoginCallback(r, loginCookie)
if err != nil {
if errors.Is(err, openidclient.ErrCallbackInvalidState) || errors.Is(err, openidclient.ErrCallbackInvalidIssuer) {
s.Unauthorized(w, r, err)
return
}

if errors.Is(err, openidclient.ErrCallbackIdentityProvider) {
s.InternalError(w, r, err)
return
}

s.InternalError(w, r, err)
return
}

tokens, err := loginCallback.RedeemTokens(r.Context())
if err != nil {
s.InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
return
}

sessionLifetime := s.Config.Session.MaxLifetime

sess, err := s.SessionManager.Create(r, tokens, sessionLifetime)
Expand Down
9 changes: 0 additions & 9 deletions pkg/openid/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,6 @@ func (c *Client) Login(r *http.Request) (*Login, error) {
return login, nil
}

func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) {
loginCallback, err := NewLoginCallback(c, r, cookie)
if err != nil {
return nil, fmt.Errorf("callback: %w", err)
}

return loginCallback, nil
}

func (c *Client) Logout(r *http.Request) (*Logout, error) {
logout, err := NewLogout(c, r)
if err != nil {
Expand Down
88 changes: 41 additions & 47 deletions pkg/openid/client/login_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,28 @@ import (
"errors"
"fmt"
"net/http"
"net/url"

"golang.org/x/oauth2"

"github.com/nais/wonderwall/pkg/openid"
urlpkg "github.com/nais/wonderwall/pkg/url"
)

var (
ErrCallbackIdentityProvider = errors.New("identity provider error")
ErrCallbackInvalidState = errors.New("invalid state")
ErrCallbackInvalidIssuer = errors.New("invalid issuer")
ErrCallbackIdentityProvider = errors.New("callback: identity provider error")
ErrCallbackInvalidCookie = errors.New("callback: invalid cookie")
ErrCallbackInvalidState = errors.New("callback: invalid state")
ErrCallbackInvalidIssuer = errors.New("callback: invalid issuer")
ErrCallbackRedeemTokens = errors.New("callback: redeeming tokens")
)

type LoginCallback struct {
*Client
cookie *openid.LoginCookie
query url.Values
}

func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) {
func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*openid.Tokens, error) {
if cookie == nil {
return nil, fmt.Errorf("cookie is nil")
}

// redirect_uri not set in cookie (e.g. login initiated at instance running older version, callback handled at newer version)
if len(cookie.RedirectURI) == 0 {
callbackURL, err := urlpkg.LoginCallback(r)
if err != nil {
return nil, fmt.Errorf("generating callback url: %w", err)
}

cookie.RedirectURI = callbackURL
return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidCookie, "cookie is nil")
}

query := r.URL.Query()
if query.Get("error") != "" {
oauthError := query.Get("error")

if oauthError := query.Get("error"); len(oauthError) > 0 {
oauthErrorDescription := query.Get("error_description")
return nil, fmt.Errorf("%w: %s: %s", ErrCallbackIdentityProvider, oauthError, oauthErrorDescription)
}
Expand All @@ -51,49 +35,59 @@ func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*
return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidState, err)
}

if c.cfg.Provider().AuthorizationResponseIssParameterSupported() {
iss := query.Get("iss")
expectedIss := c.cfg.Provider().Issuer()
if err := c.authorizationServerIssuerIdentification(query.Get("iss")); err != nil {
return nil, fmt.Errorf("%w: %s", ErrCallbackInvalidIssuer, err)
}

tokens, err := c.redeemTokens(r.Context(), query.Get("code"), cookie)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrCallbackRedeemTokens, err)
}

return tokens, nil
}

// Verify iss parameter if provider supports RFC 9207 - OAuth 2.0 Authorization Server Issuer Identification
func (c *Client) authorizationServerIssuerIdentification(iss string) error {
if !c.cfg.Provider().AuthorizationResponseIssParameterSupported() {
return nil
}

if len(iss) == 0 {
return nil, fmt.Errorf("%w: missing issuer parameter", ErrCallbackInvalidIssuer)
}
if len(iss) == 0 {
return fmt.Errorf("missing issuer parameter")
}

if iss != expectedIss {
return nil, fmt.Errorf("%w: issuer mismatch: expected %s, got %s", ErrCallbackInvalidIssuer, expectedIss, iss)
}
expectedIss := c.cfg.Provider().Issuer()
if iss != expectedIss {
return fmt.Errorf("issuer mismatch: expected %q, got %q", expectedIss, iss)
}

return &LoginCallback{
Client: c,
cookie: cookie,
query: query,
}, nil
return nil
}

func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
params, err := in.AuthParams()
func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) {
params, err := c.AuthParams()
if err != nil {
return nil, err
}

rawTokens, err := in.AuthCodeGrant(ctx, in.query.Get("code"), params.AuthCodeOptions([]oauth2.AuthCodeOption{
openid.RedirectURIOption(in.cookie.RedirectURI),
oauth2.VerifierOption(in.cookie.CodeVerifier),
rawTokens, err := c.AuthCodeGrant(ctx, code, params.AuthCodeOptions([]oauth2.AuthCodeOption{
openid.RedirectURIOption(cookie.RedirectURI),
oauth2.VerifierOption(cookie.CodeVerifier),
}))
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
}

jwkSet, err := in.jwksProvider.GetPublicJwkSet(ctx)
jwkSet, err := c.jwksProvider.GetPublicJwkSet(ctx)
if err != nil {
return nil, fmt.Errorf("getting jwks: %w", err)
}

tokens, err := openid.NewTokens(rawTokens, jwkSet, in.cfg, in.cookie)
tokens, err := openid.NewTokens(rawTokens, jwkSet, c.cfg, cookie)
if err != nil {
// JWKS might not be up to date, so we'll want to force a refresh for the next attempt
_, _ = in.jwksProvider.RefreshPublicJwkSet(ctx)
_, _ = c.jwksProvider.RefreshPublicJwkSet(ctx)
return nil, fmt.Errorf("parsing tokens: %w", err)
}

Expand Down
Loading

0 comments on commit 64e9167

Please sign in to comment.