From 7637fe0da962c80798ec8a813fcd3cd563c16c1e Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 4 Oct 2024 21:47:16 +1000 Subject: [PATCH] feat: claims interface --- token/jwt/claims_id_token.go | 94 +++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index e63c325b..7bddfe33 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -5,7 +5,6 @@ package jwt import ( "bytes" - "encoding/json" "errors" "fmt" "time" @@ -22,14 +21,15 @@ type IDTokenClaims struct { JTI string `json:"jti"` Issuer string `json:"iss"` Subject string `json:"sub"` - Audience []string `json:"aud,omitempty"` - Nonce string `json:"nonce,omitempty"` - ExpirationTime *NumericDate `json:"exp,omitempty"` - IssuedAt *NumericDate `json:"iat,omitempty"` - RequestedAt *NumericDate `json:"rat,omitempty"` + Audience []string `json:"aud"` + ExpirationTime *NumericDate `json:"exp"` + IssuedAt *NumericDate `json:"iat"` AuthTime *NumericDate `json:"auth_time,omitempty"` + RequestedAt *NumericDate `json:"rat,omitempty"` + Nonce string `json:"nonce,omitempty"` AuthenticationContextClassReference string `json:"acr,omitempty"` AuthenticationMethodsReferences []string `json:"amr,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"` CodeHash string `json:"c_hash,omitempty"` StateHash string `json:"s_hash,omitempty"` @@ -171,12 +171,6 @@ func (c *IDTokenClaims) GetRequestedAtSafe() time.Time { return c.RequestedAt.UTC() } -func (c *IDTokenClaims) MarshalJSON() (data []byte, err error) { - claims := c.ToMapClaims() - - return json.Marshal(claims) -} - func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { claims := MapClaims{} @@ -204,8 +198,6 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { c.Subject, ok = value.(string) case ClaimAudience: c.Audience, ok = toStringSlice(value) - case ClaimNonce: - c.Nonce, ok = value.(string) case ClaimExpirationTime: if c.ExpirationTime, err = toNumericDate(value); err == nil { ok = true @@ -214,24 +206,30 @@ func (c *IDTokenClaims) UnmarshalJSON(data []byte) error { if c.IssuedAt, err = toNumericDate(value); err == nil { ok = true } - case ClaimRequestedAt: - if c.RequestedAt, err = toNumericDate(value); err == nil { - ok = true - } case ClaimAuthenticationTime: if c.AuthTime, err = toNumericDate(value); err == nil { ok = true } + case ClaimRequestedAt: + if c.RequestedAt, err = toNumericDate(value); err == nil { + ok = true + } + case ClaimNonce: + c.Nonce, ok = value.(string) case ClaimAuthenticationContextClassReference: c.AuthenticationContextClassReference, ok = value.(string) case ClaimAuthenticationMethodsReference: c.AuthenticationMethodsReferences, ok = toStringSlice(value) + case ClaimAuthorizedParty: + c.AuthorizedParty, ok = value.(string) case ClaimAccessTokenHash: c.AccessTokenHash, ok = value.(string) case ClaimCodeHash: c.CodeHash, ok = value.(string) case ClaimStateHash: c.StateHash, ok = value.(string) + case ClaimExtra: + c.Extra, ok = value.(map[string]any) default: if c.Extra == nil { c.Extra = make(map[string]any) @@ -255,9 +253,9 @@ func (c *IDTokenClaims) ToMap() map[string]any { var ret = Copy(c.Extra) if c.JTI != "" { - ret[consts.ClaimJWTID] = c.JTI + ret[ClaimJWTID] = c.JTI } else { - ret[consts.ClaimJWTID] = uuid.New().String() + ret[ClaimJWTID] = uuid.New().String() } if c.Issuer != "" { @@ -273,69 +271,75 @@ func (c *IDTokenClaims) ToMap() map[string]any { } if len(c.Audience) > 0 { - ret[consts.ClaimAudience] = c.Audience + ret[ClaimAudience] = c.Audience } else { delete(ret, ClaimAudience) } - if len(c.Nonce) > 0 { - ret[consts.ClaimNonce] = c.Nonce + if c.ExpirationTime != nil { + ret[ClaimExpirationTime] = c.ExpirationTime.Unix() } else { - delete(ret, consts.ClaimNonce) + delete(ret, ClaimExpirationTime) } - if c.ExpirationTime != nil { - ret[consts.ClaimExpirationTime] = c.ExpirationTime.Unix() + if c.IssuedAt != nil { + ret[ClaimIssuedAt] = c.IssuedAt.Unix() } else { - delete(ret, consts.ClaimExpirationTime) + delete(ret, ClaimIssuedAt) } - if c.IssuedAt != nil { - ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + if c.AuthTime != nil { + ret[ClaimAuthenticationTime] = c.AuthTime.Unix() } else { - delete(ret, consts.ClaimIssuedAt) + delete(ret, ClaimAuthenticationTime) } if c.RequestedAt != nil { - ret[consts.ClaimRequestedAt] = c.RequestedAt.Unix() + ret[ClaimRequestedAt] = c.RequestedAt.Unix() } else { - delete(ret, consts.ClaimRequestedAt) + delete(ret, ClaimRequestedAt) } - if c.AuthTime != nil { - ret[consts.ClaimAuthenticationTime] = c.AuthTime.Unix() + if len(c.Nonce) > 0 { + ret[ClaimNonce] = c.Nonce } else { - delete(ret, consts.ClaimAuthenticationTime) + delete(ret, ClaimNonce) } if len(c.AuthenticationContextClassReference) > 0 { - ret[consts.ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference + ret[ClaimAuthenticationContextClassReference] = c.AuthenticationContextClassReference } else { - delete(ret, consts.ClaimAuthenticationContextClassReference) + delete(ret, ClaimAuthenticationContextClassReference) } if len(c.AuthenticationMethodsReferences) > 0 { - ret[consts.ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences + ret[ClaimAuthenticationMethodsReference] = c.AuthenticationMethodsReferences + } else { + delete(ret, ClaimAuthenticationMethodsReference) + } + + if len(c.AuthorizedParty) > 0 { + ret[ClaimAuthorizedParty] = c.AuthorizedParty } else { - delete(ret, consts.ClaimAuthenticationMethodsReference) + delete(ret, ClaimAuthorizedParty) } if len(c.AccessTokenHash) > 0 { - ret[consts.ClaimAccessTokenHash] = c.AccessTokenHash + ret[ClaimAccessTokenHash] = c.AccessTokenHash } else { - delete(ret, consts.ClaimAccessTokenHash) + delete(ret, ClaimAccessTokenHash) } if len(c.CodeHash) > 0 { - ret[consts.ClaimCodeHash] = c.CodeHash + ret[ClaimCodeHash] = c.CodeHash } else { - delete(ret, consts.ClaimCodeHash) + delete(ret, ClaimCodeHash) } if len(c.StateHash) > 0 { - ret[consts.ClaimStateHash] = c.StateHash + ret[ClaimStateHash] = c.StateHash } else { - delete(ret, consts.ClaimStateHash) + delete(ret, ClaimStateHash) } return ret