diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 65e6819..4b89447 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -137,11 +137,11 @@ func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) { "redirect_after_login": canonicalRedirect, } - if acr := login.Acr; acr != "" { - fields["acr"] = acr + if acrValues := login.AcrValues; acrValues != "" { + fields["acr"] = acrValues } - if locale := login.Locale; locale != "" { + if locale := login.UILocales; locale != "" { fields["locale"] = locale } diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index 7b6f0db..e46224c 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -89,16 +89,16 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A } func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) { - params, err := c.AuthParams() + params, err := c.ClientAuthenticationParams() if err != nil { return nil, err } - payload := params.URLValues(map[string]string{ + payload := params.Merge(openid.AuthParams{ "grant_type": "refresh_token", "refresh_token": refreshToken, "client_id": c.cfg.Client().ClientID(), - }).Encode() + }).URLValues().Encode() endpoint := c.cfg.Provider().TokenEndpoint() body, err := c.oauthPostRequest(ctx, endpoint, payload) @@ -114,7 +114,7 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid return &tokenResponse, nil } -func (c *Client) AuthParams() (openid.AuthParams, error) { +func (c *Client) ClientAuthenticationParams() (openid.AuthParams, error) { switch c.cfg.Client().AuthMethod() { case openidconfig.AuthMethodPrivateKeyJWT: assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime) @@ -122,10 +122,10 @@ func (c *Client) AuthParams() (openid.AuthParams, error) { return nil, fmt.Errorf("creating client assertion: %w", err) } - return openid.AuthParamsJwtBearer(assertion), nil + return openid.ClientAuthParamsJwtBearer(assertion), nil case openidconfig.AuthMethodClientSecret: - return openid.AuthParamsClientSecret(c.cfg.Client().ClientSecret()), nil + return openid.ClientAuthParamsSecret(c.cfg.Client().ClientSecret()), nil } return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod()) diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 7a317b6..5bdbcd0 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -25,7 +25,6 @@ const ( LocaleURLParameter = "locale" SecurityLevelURLParameter = "level" PromptURLParameter = "prompt" - MaxAgeURLParameter = "max_age" ) var ( @@ -34,43 +33,17 @@ var ( ErrInvalidPrompt = errors.New("InvalidPrompt") ErrInvalidLoginParameter = errors.New("InvalidLoginParameter") - // LoginParameterMapping maps incoming login parameters to OpenID Connect parameters - LoginParameterMapping = map[string]string{ - LocaleURLParameter: "ui_locales", - SecurityLevelURLParameter: "acr_values", - } - PromptAllowedValues = []string{"login", "select_account"} ) type Login struct { - authorizationRequest + openid.AuthorizationCodeParams AuthCodeURL string Cookie openid.LoginCookie } -type authorizationRequest struct { - Acr string - CallbackURL string - CodeVerifier string - Locale string - Nonce string - Prompt string - State string -} - -func (a authorizationRequest) ToCookie() openid.LoginCookie { - return openid.LoginCookie{ - Acr: a.Acr, - CodeVerifier: a.CodeVerifier, - Nonce: a.Nonce, - State: a.State, - RedirectURI: a.CallbackURL, - } -} - func (c *Client) Login(r *http.Request) (*Login, error) { - request, err := c.newAuthorizationRequest(r) + request, err := c.newAuthorizationCodeParams(r) if err != nil { return nil, fmt.Errorf("login: %w", err) } @@ -81,14 +54,14 @@ func (c *Client) Login(r *http.Request) (*Login, error) { } return &Login{ - AuthCodeURL: authCodeURL, - authorizationRequest: request, - Cookie: request.ToCookie(), + AuthCodeURL: authCodeURL, + AuthorizationCodeParams: request, + Cookie: request.Cookie(), }, nil } -func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, error) { - var req authorizationRequest +func (c *Client) newAuthorizationCodeParams(r *http.Request) (openid.AuthorizationCodeParams, error) { + var req openid.AuthorizationCodeParams callbackURL, err := url.LoginCallback(r) if err != nil { @@ -120,86 +93,42 @@ func (c *Client) newAuthorizationRequest(r *http.Request) (authorizationRequest, return req, fmt.Errorf("creating state: %w", err) } + resource := c.cfg.Client().ResourceIndicator() codeVerifier := oauth2.GenerateVerifier() - return authorizationRequest{ - Acr: acrParam, - CallbackURL: callbackURL, + return openid.AuthorizationCodeParams{ + AcrValues: acrParam, + ClientID: c.oauth2Config.ClientID, CodeVerifier: codeVerifier, - Locale: locale, Nonce: nonce, Prompt: prompt, + RedirectURI: callbackURL, + Resource: resource, + Scope: c.oauth2Config.Scopes, State: state, + UILocales: locale, }, nil } -func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) (string, error) { +func (c *Client) authCodeURL(ctx context.Context, request openid.AuthorizationCodeParams) (string, error) { var authCodeURL string if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" { - opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("nonce", request.Nonce), - oauth2.SetAuthURLParam("response_mode", "query"), - oauth2.S256ChallengeOption(request.CodeVerifier), - openid.RedirectURIOption(request.CallbackURL), - } - - if resource := c.cfg.Client().ResourceIndicator(); resource != "" { - opts = append(opts, oauth2.SetAuthURLParam("resource", resource)) - } - - if len(request.Acr) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[SecurityLevelURLParameter], request.Acr)) - } - - if len(request.Locale) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(LoginParameterMapping[LocaleURLParameter], request.Locale)) - } - - if len(request.Prompt) > 0 { - opts = append(opts, oauth2.SetAuthURLParam(PromptURLParameter, request.Prompt)) - opts = append(opts, oauth2.SetAuthURLParam(MaxAgeURLParameter, "0")) - } + opts := request.AuthParams().AuthCodeOptions() + // TODO: replace with separate function authCodeURL = c.oauth2Config.AuthCodeURL(request.State, opts...) } else { - params := map[string]string{ - "client_id": c.oauth2Config.ClientID, - "code_challenge": oauth2.S256ChallengeFromVerifier(request.CodeVerifier), - "code_challenge_method": "S256", - "nonce": request.Nonce, - "redirect_uri": request.CallbackURL, - "response_mode": "query", - "response_type": "code", - "scope": stringslib.Join(c.oauth2Config.Scopes, " "), - "state": request.State, - } - - if resource := c.cfg.Client().ResourceIndicator(); resource != "" { - params["resource"] = resource - } - - if len(request.Acr) > 0 { - params[LoginParameterMapping[SecurityLevelURLParameter]] = request.Acr - } - - if len(request.Locale) > 0 { - params[LoginParameterMapping[LocaleURLParameter]] = request.Locale - } - - if len(request.Prompt) > 0 { - params[PromptURLParameter] = request.Prompt - params[MaxAgeURLParameter] = "0" - } - - authParams, err := c.AuthParams() + clientAuthParams, err := c.ClientAuthenticationParams() if err != nil { return "", fmt.Errorf("generating client authentication parameters: %w", err) } - payload := authParams.URLValues(params).Encode() endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint() - body, err := c.oauthPostRequest(ctx, endpoint, payload) + body, err := c.oauthPostRequest(ctx, endpoint, request.AuthParams(). + Merge(clientAuthParams). + URLValues(). + Encode()) if err != nil { return "", err } @@ -209,6 +138,7 @@ func (c *Client) authCodeURL(ctx context.Context, request authorizationRequest) return "", fmt.Errorf("unmarshalling token response: %w", err) } + // TODO: this can a separate function to replace oauth2config.AuthCodeURL v := urllib.Values{ "client_id": {c.oauth2Config.ClientID}, "request_uri": {pushedAuthorizationResponse.RequestUri}, diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index 9d6ad26..b53050a 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -6,8 +6,6 @@ import ( "fmt" "net/http" - "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/openid" ) @@ -66,15 +64,15 @@ func (c *Client) authorizationServerIssuerIdentification(iss string) error { } func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) { - params, err := c.AuthParams() + params, err := c.ClientAuthenticationParams() if err != nil { return nil, err } - rawTokens, err := c.AuthCodeGrant(ctx, code, params.AuthCodeOptions([]oauth2.AuthCodeOption{ - openid.RedirectURIOption(cookie.RedirectURI), - oauth2.VerifierOption(cookie.CodeVerifier), - })) + rawTokens, err := c.AuthCodeGrant(ctx, code, params.Merge(openid.AuthParams{ + "redirect_uri": cookie.RedirectURI, + "code_verifier": cookie.CodeVerifier, + }).AuthCodeOptions()) if err != nil { return nil, fmt.Errorf("exchanging authorization code for token: %w", err) } diff --git a/pkg/openid/oauth2.go b/pkg/openid/oauth2.go index ef05828..998bcc5 100644 --- a/pkg/openid/oauth2.go +++ b/pkg/openid/oauth2.go @@ -3,6 +3,7 @@ package openid import ( "fmt" "net/url" + "strings" "golang.org/x/oauth2" ) @@ -27,10 +28,71 @@ type TokenErrorResponse struct { ErrorDescription string `json:"error_description"` } +// AuthorizationCodeParams represents the (variable) parameters for the authorization code flow. +type AuthorizationCodeParams struct { + AcrValues string + ClientID string + CodeVerifier string + Nonce string + Prompt string + RedirectURI string + Resource string + Scope []string + State string + UILocales string +} + +// AuthParams converts AuthorizationCodeParams the actual parameters to be sent to the authorization server as part of the authorization code flow. +func (a AuthorizationCodeParams) AuthParams() AuthParams { + params := AuthParams{ + "client_id": a.ClientID, + "code_challenge": oauth2.S256ChallengeFromVerifier(a.CodeVerifier), + "code_challenge_method": "S256", + "nonce": a.Nonce, + "prompt": a.Prompt, + "redirect_uri": a.RedirectURI, + "response_mode": "query", + "response_type": "code", + "scope": strings.Join(a.Scope, " "), + "state": a.State, + } + + if len(a.AcrValues) > 0 { + params["acr_values"] = a.AcrValues + } + + if len(a.UILocales) > 0 { + params["ui_locales"] = a.UILocales + } + + if len(a.Prompt) > 0 { + params["max_age"] = "0" + } + + if len(a.Resource) > 0 { + params["resource"] = a.Resource + } + + return params +} + +// Cookie creates a LoginCookie for storing client-side state as part of the authorization code flow. +func (a AuthorizationCodeParams) Cookie() LoginCookie { + return LoginCookie{ + Acr: a.AcrValues, + CodeVerifier: a.CodeVerifier, + Nonce: a.Nonce, + State: a.State, + RedirectURI: a.RedirectURI, + } +} + type AuthParams map[string]string -// AuthCodeOptions adds AuthParams to the given [oauth2.AuthCodeOption] slice and returns the updated slice. -func (a AuthParams) AuthCodeOptions(opts []oauth2.AuthCodeOption) []oauth2.AuthCodeOption { +// AuthCodeOptions converts AuthParams to a slice of [oauth2.AuthCodeOption]. +func (a AuthParams) AuthCodeOptions() []oauth2.AuthCodeOption { + opts := make([]oauth2.AuthCodeOption, 0, len(a)) + for key, val := range a { opts = append(opts, oauth2.SetAuthURLParam(key, val)) } @@ -38,14 +100,10 @@ func (a AuthParams) AuthCodeOptions(opts []oauth2.AuthCodeOption) []oauth2.AuthC return opts } -// URLValues adds AuthParams to the given map of parameters and returns a [url.Values]. -func (a AuthParams) URLValues(params map[string]string) url.Values { +// URLValues converts AuthParams to a [url.Values]. +func (a AuthParams) URLValues() url.Values { v := url.Values{} - for key, val := range params { - v.Set(key, val) - } - for key, val := range a { v.Set(key, val) } @@ -53,27 +111,33 @@ func (a AuthParams) URLValues(params map[string]string) url.Values { return v } -// AuthParamsClientSecret returns a map of parameters to be sent to the authorization server when using a client secret for client authentication in RFC 6749, section 2.3.1. +// Merge merges two AuthParams into one. +// Conflicting keys are overridden by the given AuthParams. +func (a AuthParams) Merge(other AuthParams) AuthParams { + for key, val := range other { + a[key] = val + } + + return a +} + +// ClientAuthParamsSecret returns a map of parameters to be sent to the authorization server when using a client secret for client authentication in RFC 6749, section 2.3.1. // The target authorization server must support the "client_secret_post" client authentication method. -func AuthParamsClientSecret(clientSecret string) AuthParams { - return map[string]string{ +func ClientAuthParamsSecret(clientSecret string) AuthParams { + return AuthParams{ "client_secret": clientSecret, } } -// AuthParamsJwtBearer returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2. +// ClientAuthParamsJwtBearer returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2. // The target authorization server must support the "private_key_jwt" client authentication method. -func AuthParamsJwtBearer(clientAssertion string) AuthParams { - return map[string]string{ +func ClientAuthParamsJwtBearer(clientAssertion string) AuthParams { + return AuthParams{ "client_assertion": clientAssertion, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", } } -func RedirectURIOption(redirectUri string) oauth2.AuthCodeOption { - return oauth2.SetAuthURLParam("redirect_uri", redirectUri) -} - func StateMismatchError(queryParams url.Values, expectedState string) error { actualState := queryParams.Get("state")