diff --git a/consent/manager.go b/consent/manager.go index f09c803c06b..577fffa27f1 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -6,6 +6,7 @@ package consent import ( "context" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" @@ -65,6 +66,8 @@ type ( GetDeviceUserAuthRequest(ctx context.Context, challenge string) (*flow.DeviceUserAuthRequest, error) HandleDeviceUserAuthRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledDeviceUserAuthRequest) (*flow.DeviceUserAuthRequest, error) VerifyAndInvalidateDeviceUserAuthRequest(ctx context.Context, verifier string) (*flow.HandledDeviceUserAuthRequest, error) + + Transaction(context.Context, func(ctx context.Context, c *pop.Connection) error) error } ManagerProvider interface { diff --git a/consent/strategy_default.go b/consent/strategy_default.go index ed84da79e76..a740ee3e3af 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/gobuffalo/pop/v6" "github.com/gorilla/sessions" "github.com/hashicorp/go-retryablehttp" "github.com/pborman/uuid" @@ -39,8 +40,6 @@ import ( "github.com/ory/x/urlx" ) -type ctxKey int - const ( DeviceVerificationPath = "/oauth2/device/verify" CookieAuthenticationSIDName = "sid" @@ -1170,21 +1169,11 @@ func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest( ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") defer otelx.End(span, &err) - return s.handleOAuth2AuthorizationRequest(ctx, w, r, req, nil) -} - -func (s *DefaultStrategy) handleOAuth2AuthorizationRequest( - ctx context.Context, - w http.ResponseWriter, - r *http.Request, - req fosite.AuthorizeRequester, - f *flow.Flow, -) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) if loginVerifier == "" && consentVerifier == "" { // ok, we need to process this request and redirect to the original endpoint - return nil, nil, s.requestAuthentication(ctx, w, r, req, f) + return nil, nil, s.requestAuthentication(ctx, w, r, req, nil) } else if loginVerifier != "" { f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier) if err != nil { @@ -1208,7 +1197,10 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ctx context.Context, w http.ResponseWriter, r *http.Request, -) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { +) (_ *flow.AcceptOAuth2ConsentRequest, _ *flow.Flow, err error) { + ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, "DefaultStrategy.HandleOAuth2AuthorizationRequest") + defer otelx.End(span, &err) + deviceVerifier := strings.TrimSpace(r.URL.Query().Get("device_verifier")) loginVerifier := strings.TrimSpace(r.URL.Query().Get("login_verifier")) consentVerifier := strings.TrimSpace(r.URL.Query().Get("consent_verifier")) @@ -1246,16 +1238,32 @@ func (s *DefaultStrategy) HandleOAuth2DeviceAuthorizationRequest( ar.RequestedAudience = fosite.Arguments(deviceFlow.RequestedAudience) } - // TODO(nsklikas): wrap these 2 function calls in a transaction (one persists the flow and the other invalidates the user_code) - consentSession, f, err := s.handleOAuth2AuthorizationRequest(ctx, w, r, ar, deviceFlow) - if err != nil { - return nil, nil, err - } - err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(r.Context(), string(f.DeviceCodeRequestID), f.ID) - if err != nil { - return nil, nil, err + if loginVerifier == "" && consentVerifier == "" { + // ok, we need to process this request and redirect to the authentication endpoint + return nil, nil, s.requestAuthentication(ctx, w, r, ar, deviceFlow) + } else if loginVerifier != "" { + f, err := s.verifyAuthentication(ctx, w, r, ar, loginVerifier) + if err != nil { + return nil, nil, err + } + + // ok, we need to process this request and redirect to consent endpoint + return nil, f, s.requestConsent(ctx, w, r, ar, f) } + var consentSession *flow.AcceptOAuth2ConsentRequest + var f *flow.Flow + + err = s.r.ConsentManager().Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + consentSession, f, err = s.verifyConsent(ctx, w, r, consentVerifier) + if err != nil { + return err + } + err = s.r.OAuth2Storage().UpdateAndInvalidateUserCodeSessionByRequestID(ctx, string(f.DeviceCodeRequestID), f.ID) + + return err + }) + return consentSession, f, err } diff --git a/oauth2/handler.go b/oauth2/handler.go index d611885a056..32e12e621ec 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -747,14 +747,12 @@ func (h *Handler) performOAuth2DeviceVerificationFlow(w http.ResponseWriter, r * return } - // TODO(nsklikas): We need to add a db transaction here req, err := h.r.OAuth2Storage().GetDeviceCodeSessionByRequestID(ctx, f.DeviceCodeRequestID.String(), &Session{}) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.Writer().WriteError(w, r, err) return } - // TODO(nsklika): Can we refactor this so we don't have to pass in the session? session, err := h.updateSessionWithRequest(ctx, consentSession, f, r, req, req.GetSession().(*Session)) if err != nil { h.r.Writer().WriteError(w, r, err)