Skip to content

Commit

Permalink
refactor(openid): extract request params for remaining grants, minor …
Browse files Browse the repository at this point in the history
…cleanups
  • Loading branch information
tronghn committed Jan 24, 2025
1 parent 062e7b0 commit c147a5a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 53 deletions.
21 changes: 9 additions & 12 deletions pkg/openid/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,15 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
}

func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) {
params, err := c.ClientAuthenticationParams()
clientAuth, err := c.ClientAuthenticationParams()
if err != nil {
return nil, err
}

payload := params.Merge(openid.AuthParams{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": c.cfg.Client().ClientID(),
}).URLValues().Encode()

endpoint := c.cfg.Provider().TokenEndpoint()
payload := openid.RefreshGrantParams(c.cfg.Client().ClientID(), refreshToken).
With(clientAuth)

body, err := c.oauthPostRequest(ctx, endpoint, payload)
if err != nil {
return nil, err
Expand All @@ -114,18 +111,18 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
return &tokenResponse, nil
}

func (c *Client) ClientAuthenticationParams() (openid.AuthParams, error) {
func (c *Client) ClientAuthenticationParams() (openid.RequestParams, error) {
switch c.cfg.Client().AuthMethod() {
case openidconfig.AuthMethodPrivateKeyJWT:
assertion, err := c.MakeAssertion(DefaultClientAssertionLifetime)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
}

return openid.ClientAuthParamsJwtBearer(assertion), nil
return openid.ClientAuthJwtBearerParams(assertion), nil

case openidconfig.AuthMethodClientSecret:
return openid.ClientAuthParamsSecret(c.cfg.Client().ClientSecret()), nil
return openid.ClientAuthSecretParams(c.cfg.Client().ClientSecret()), nil
}

return nil, fmt.Errorf("unsupported client authentication method: %q", c.cfg.Client().AuthMethod())
Expand Down Expand Up @@ -163,8 +160,8 @@ func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
return string(encoded), nil
}

func (c *Client) oauthPostRequest(ctx context.Context, endpoint, payload string) ([]byte, error) {
r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload))
func (c *Client) oauthPostRequest(ctx context.Context, endpoint string, payload openid.RequestParams) ([]byte, error) {
r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(payload.URLValues().Encode()))
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}
Expand Down
27 changes: 10 additions & 17 deletions pkg/openid/client/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,16 @@ func (c *Client) newAuthorizationCodeParams(r *http.Request) (openid.Authorizati
}, nil
}

func (c *Client) authCodeURL(ctx context.Context, request openid.AuthorizationCodeParams) (string, error) {
var authCodeURL string

if c.cfg.Provider().PushedAuthorizationRequestEndpoint() == "" {
opts := request.AuthParams().AuthCodeOptions()

// TODO: replace with separate function
authCodeURL = c.oauth2Config.AuthCodeURL(request.State, opts...)
} else {
clientAuthParams, err := c.ClientAuthenticationParams()
func (c *Client) authCodeURL(ctx context.Context, authCodeParams openid.AuthorizationCodeParams) (string, error) {
usePushedAuthorization := len(c.cfg.Provider().PushedAuthorizationRequestEndpoint()) > 0
if usePushedAuthorization {
clientAuth, err := c.ClientAuthenticationParams()
if err != nil {
return "", fmt.Errorf("generating client authentication parameters: %w", err)
}

endpoint := c.cfg.Provider().PushedAuthorizationRequestEndpoint()
body, err := c.oauthPostRequest(ctx, endpoint, request.AuthParams().
Merge(clientAuthParams).
URLValues().
Encode())
body, err := c.oauthPostRequest(ctx, endpoint, authCodeParams.RequestParams().With(clientAuth))
if err != nil {
return "", err
}
Expand All @@ -138,7 +129,7 @@ func (c *Client) authCodeURL(ctx context.Context, request openid.AuthorizationCo
return "", fmt.Errorf("unmarshalling token response: %w", err)
}

// TODO: this can a separate function to replace oauth2config.AuthCodeURL
// TODO: this can be a separate function to replace oauth2config.AuthCodeURL
v := urllib.Values{
"client_id": {c.oauth2Config.ClientID},
"request_uri": {pushedAuthorizationResponse.RequestUri},
Expand All @@ -151,10 +142,12 @@ func (c *Client) authCodeURL(ctx context.Context, request openid.AuthorizationCo
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
authCodeURL = buf.String()
return buf.String(), nil
}

return authCodeURL, nil
opts := authCodeParams.RequestParams().AuthCodeOptions()
// TODO: replace with separate function
return c.oauth2Config.AuthCodeURL(authCodeParams.State, opts...), nil
}

func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
Expand Down
14 changes: 9 additions & 5 deletions pkg/openid/client/login_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,19 @@ func (c *Client) authorizationServerIssuerIdentification(iss string) error {
}

func (c *Client) redeemTokens(ctx context.Context, code string, cookie *openid.LoginCookie) (*openid.Tokens, error) {
params, err := c.ClientAuthenticationParams()
clientAuth, err := c.ClientAuthenticationParams()
if err != nil {
return nil, err
}

rawTokens, err := c.AuthCodeGrant(ctx, code, params.Merge(openid.AuthParams{
"redirect_uri": cookie.RedirectURI,
"code_verifier": cookie.CodeVerifier,
}).AuthCodeOptions())
payload := openid.ExchangeAuthorizationCodeParams(
c.cfg.Client().ClientID(),
code,
cookie.CodeVerifier,
cookie.RedirectURI,
).With(clientAuth).AuthCodeOptions()

rawTokens, err := c.AuthCodeGrant(ctx, code, payload)
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
}
Expand Down
63 changes: 44 additions & 19 deletions pkg/openid/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,23 @@ type AuthorizationCodeParams struct {
Prompt string
RedirectURI string
Resource string
Scope []string
Scope scopes.Scopes
State string
UILocales string
}

// AuthParams converts AuthorizationCodeParams the actual parameters to be sent to the authorization server as part of the authorization code flow.
func (a AuthorizationCodeParams) AuthParams() AuthParams {
params := AuthParams{
// RequestParams converts AuthorizationCodeParams the actual parameters to be sent to the authorization server as part of the authorization code flow.
// This mandates required use of PKCE (RFC 7636), state and nonce.
func (a AuthorizationCodeParams) RequestParams() RequestParams {
params := RequestParams{
"client_id": a.ClientID,
"code_challenge": oauth2.S256ChallengeFromVerifier(a.CodeVerifier),
"code_challenge_method": "S256",
"nonce": a.Nonce,
"redirect_uri": a.RedirectURI,
"response_mode": "query",
"response_type": "code",
"scope": strings.Join(a.Scope, " "),
"scope": a.Scope.String(),
"state": a.State,
}

Expand Down Expand Up @@ -88,10 +89,10 @@ func (a AuthorizationCodeParams) Cookie() LoginCookie {
}
}

type AuthParams map[string]string
type RequestParams map[string]string

// AuthCodeOptions converts AuthParams to a slice of [oauth2.AuthCodeOption].
func (a AuthParams) AuthCodeOptions() []oauth2.AuthCodeOption {
// AuthCodeOptions converts RequestParams to a slice of [oauth2.AuthCodeOption].
func (a RequestParams) AuthCodeOptions() []oauth2.AuthCodeOption {
opts := make([]oauth2.AuthCodeOption, 0, len(a))

for key, val := range a {
Expand All @@ -101,8 +102,8 @@ func (a AuthParams) AuthCodeOptions() []oauth2.AuthCodeOption {
return opts
}

// URLValues converts AuthParams to a [url.Values].
func (a AuthParams) URLValues() url.Values {
// URLValues converts RequestParams to a [url.Values].
func (a RequestParams) URLValues() url.Values {
v := url.Values{}

for key, val := range a {
Expand All @@ -112,33 +113,57 @@ func (a AuthParams) URLValues() url.Values {
return v
}

// Merge merges two AuthParams into one.
// Conflicting keys are overridden by the given AuthParams.
func (a AuthParams) Merge(other AuthParams) AuthParams {
// With returns a new RequestParams with the given RequestParams added.
// Conflicting keys are overridden by the given RequestParams.
func (a RequestParams) With(other RequestParams) RequestParams {
for key, val := range other {
a[key] = val
}

return a
}

// ClientAuthParamsSecret returns a map of parameters to be sent to the authorization server when using a client secret for client authentication in RFC 6749, section 2.3.1.
// ClientAuthSecretParams returns a map of parameters to be sent to the authorization server when using a client secret for client authentication in RFC 6749, section 2.3.1.
// The target authorization server must support the "client_secret_post" client authentication method.
func ClientAuthParamsSecret(clientSecret string) AuthParams {
return AuthParams{
func ClientAuthSecretParams(clientSecret string) RequestParams {
return RequestParams{
"client_secret": clientSecret,
}
}

// ClientAuthParamsJwtBearer returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2.
// ClientAuthJwtBearerParams returns a map of parameters to be sent to the authorization server when using a JWT for client authentication in RFC 7523, section 2.2.
// The target authorization server must support the "private_key_jwt" client authentication method.
func ClientAuthParamsJwtBearer(clientAssertion string) AuthParams {
return AuthParams{
func ClientAuthJwtBearerParams(clientAssertion string) RequestParams {
return RequestParams{
"client_assertion": clientAssertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
}
}

// ExchangeAuthorizationCodeParams returns a map of parameters to be sent to the authorization server when exchanging
// an authorization code for token request as defined in RFC 6749, section 4.1.3.
//
// Additionally, PKCE (RFC 7636) is required for this request.
func ExchangeAuthorizationCodeParams(clientID, code, codeVerifier, redirectURI string) RequestParams {
return RequestParams{
"client_id": clientID,
"code": code,
"code_verifier": codeVerifier,
"grant_type": "authorization_code",
"redirect_uri": redirectURI,
}
}

// RefreshGrantParams returns a map of parameters to be sent to the authorization server when performing the refresh
// token grant as defined in RFC 6749, section 6.
func RefreshGrantParams(clientID, refreshToken string) RequestParams {
return RequestParams{
"client_id": clientID,
"grant_type": "refresh_token",
"refresh_token": refreshToken,
}
}

func StateMismatchError(queryParams url.Values, expectedState string) error {
actualState := queryParams.Get("state")

Expand Down

0 comments on commit c147a5a

Please sign in to comment.