Skip to content

Commit

Permalink
feat(headless): information on pending provider signup
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Feb 19, 2025
1 parent 7fa8822 commit 2542b3f
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 52 deletions.
7 changes: 7 additions & 0 deletions ChangeLog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ Note worthy changes
fields are present and required (``'*'``). This change is performed in a
backwards compatible manner.

- Headless: if, while signing up using a third-party provider account, there is
insufficient information received from the provider to automatically complete
the signup process, an additional step is needed to complete the missing data
before the user is fully signed up and authenticated. You can now perform a
``GET`` request to ``/_allauth/{client}/v1/auth/provider/signup`` to obtain
information on the pending signup.


Fixes
-----
Expand Down
9 changes: 5 additions & 4 deletions allauth/account/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,23 @@ def settings_check(app_configs, **kwargs):
)

# Mandatory email verification requires email
email_required = "email" in signup_fields and signup_fields["email"]["required"]
if (
app_settings.EMAIL_VERIFICATION
== app_settings.EmailVerificationMethod.MANDATORY
and not app_settings.EMAIL_REQUIRED
and not email_required
):
ret.append(
Critical(
msg="ACCOUNT_EMAIL_VERIFICATION = 'mandatory' requires ACCOUNT_EMAIL_REQUIRED = True"
msg="ACCOUNT_EMAIL_VERIFICATION = 'mandatory' requires 'email*' in ACCOUNT_SIGNUP_FIELDS"
)
)

if not app_settings.USER_MODEL_USERNAME_FIELD:
if app_settings.USERNAME_REQUIRED:
if "username" in signup_fields:
ret.append(
Critical(
msg="No ACCOUNT_USER_MODEL_USERNAME_FIELD, yet, ACCOUNT_USERNAME_REQUIRED = True"
msg="No ACCOUNT_USER_MODEL_USERNAME_FIELD, yet, ACCOUNT_SIGNUP_FIELDS contains 'username'"
)
)

Expand Down
17 changes: 9 additions & 8 deletions allauth/headless/account/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
from allauth.headless.base.response import APIResponse


def email_address_data(addr):
return {
"email": addr.email,
"verified": addr.verified,
"primary": addr.primary,
}


class RequestEmailVerificationResponse(APIResponse):
def __init__(self, request, verification_sent):
super().__init__(request, status=200 if verification_sent else 403)
Expand All @@ -22,14 +30,7 @@ def __init__(self, request, verification, stage):

class EmailAddressesResponse(APIResponse):
def __init__(self, request, email_addresses):
data = [
{
"email": addr.email,
"verified": addr.verified,
"primary": addr.primary,
}
for addr in email_addresses
]
data = [email_address_data(addr) for addr in email_addresses]
super().__init__(request, data=data)


Expand Down
9 changes: 5 additions & 4 deletions allauth/headless/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ def serialize_user(self, user) -> Dict[str, Any]:
verification).
"""
ret = {
"id": user.pk,
"display": user_display(user),
"has_usable_password": user.has_usable_password(),
}
email = EmailAddress.objects.get_primary_email(user)
if email:
ret["email"] = email
if user.pk:
ret["id"] = user.pk
email = EmailAddress.objects.get_primary_email(user)
if email:
ret["email"] = email
username = user_username(user)
if username:
ret["username"] = username
Expand Down
30 changes: 22 additions & 8 deletions allauth/headless/socialaccount/response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from allauth.headless.account.response import email_address_data
from allauth.headless.adapter import get_adapter
from allauth.headless.base.response import APIResponse
from allauth.headless.constants import Client, Flow
from allauth.socialaccount.adapter import (
Expand All @@ -7,6 +9,14 @@
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider


def _socialaccount_data(request, account):
return {
"uid": account.uid,
"provider": _provider_data(request, account.get_provider()),
"display": account.get_provider_account().to_str(),
}


def _provider_data(request, provider):
ret = {"id": provider.sub_id, "name": provider.name, "flows": []}
if provider.supports_redirect:
Expand Down Expand Up @@ -86,12 +96,16 @@ def get_config_data(request):

class SocialAccountsResponse(APIResponse):
def __init__(self, request, accounts):
data = [
{
"uid": account.uid,
"provider": _provider_data(request, account.get_provider()),
"display": account.get_provider_account().to_str(),
}
for account in accounts
]
data = [_socialaccount_data(request, account) for account in accounts]
super().__init__(request, data=data)


class SocialLoginResponse(APIResponse):
def __init__(self, request, sociallogin):
adapter = get_adapter()
data = {
"user": adapter.serialize_user(sociallogin.user),
"account": _socialaccount_data(request, sociallogin.account),
"email": [email_address_data(ea) for ea in sociallogin.email_addresses],
}
super().__init__(request, data=data)
72 changes: 47 additions & 25 deletions allauth/headless/socialaccount/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from http import HTTPStatus
from unittest.mock import patch

from django.urls import reverse
Expand Down Expand Up @@ -32,7 +33,7 @@ def test_valid_redirect(client, headless_reverse, db):
"process": AuthProcess.LOGIN,
},
)
assert resp.status_code == 302
assert resp.status_code == HTTPStatus.FOUND


def test_manage_providers(auth_client, user, headless_reverse, provider_id):
Expand All @@ -46,16 +47,16 @@ def test_manage_providers(auth_client, user, headless_reverse, provider_id):
headless_reverse("headless:socialaccount:manage_providers"),
)
data = resp.json()
assert data["status"] == 200
assert data["status"] == HTTPStatus.OK
assert len(data["data"]) == 2
resp = auth_client.delete(
headless_reverse("headless:socialaccount:manage_providers"),
data={"provider": account_to_del.provider, "account": account_to_del.uid},
content_type="application/json",
)
assert resp.status_code == 200
assert resp.status_code == HTTPStatus.OK
assert resp.json() == {
"status": 200,
"status": HTTPStatus.OK,
"data": [
{
"display": "Unittest Server",
Expand All @@ -79,9 +80,9 @@ def test_disconnect_bad_request(auth_client, user, headless_reverse, provider_id
data={"provider": provider_id, "account": "unknown"},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.status_code == HTTPStatus.BAD_REQUEST
assert resp.json() == {
"status": 400,
"status": HTTPStatus.BAD_REQUEST,
"errors": [{"code": "account_not_found", "message": "Unknown account."}],
}

Expand All @@ -96,9 +97,9 @@ def test_disconnect_not_allowed(auth_client, user, headless_reverse, provider_id
data={"provider": provider_id, "account": account.uid},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.status_code == HTTPStatus.BAD_REQUEST
assert resp.json() == {
"status": 400,
"status": HTTPStatus.BAD_REQUEST,
"errors": [
{"code": "no_password", "message": "Your account has no password set up."}
],
Expand All @@ -124,7 +125,7 @@ def test_valid_token(client, headless_reverse, db):
},
content_type="application/json",
)
assert resp.status_code == 200
assert resp.status_code == HTTPStatus.OK
assert EmailAddress.objects.filter(email="a@b.com", verified=True).exists()


Expand All @@ -141,10 +142,10 @@ def test_invalid_token(client, headless_reverse, db, google_provider_settings):
},
content_type="application/json",
)
assert resp.status_code == 400
assert resp.status_code == HTTPStatus.BAD_REQUEST
data = resp.json()
assert data == {
"status": 400,
"status": HTTPStatus.BAD_REQUEST,
"errors": [
{"message": "Invalid token.", "code": "invalid_token", "param": "token"}
],
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_valid_token_multiple_apps(
},
content_type="application/json",
)
assert resp.status_code == 200
assert resp.status_code == HTTPStatus.OK


def test_auth_error_no_headless_request(client, db, google_provider_settings, settings):
Expand Down Expand Up @@ -249,17 +250,18 @@ def test_token_signup_closed(client, headless_reverse, db):
},
content_type="application/json",
)
assert resp.status_code == 403
assert resp.status_code == HTTPStatus.FORBIDDEN
assert not EmailAddress.objects.filter(email="a@b.com", verified=True).exists()


def test_provider_signup(client, headless_reverse, db, settings):
settings.ACCOUNT_EMAIL_VERIFICATION = "mandatory"
settings.ACCOUNT_EMAIL_REQUIRED = True
settings.ACCOUNT_USERNAME_REQUIRED = False
account_uid = "123"
id_token = json.dumps(
{
"id": 123,
"id": account_uid,
}
)
resp = client.post(
Expand All @@ -273,17 +275,37 @@ def test_provider_signup(client, headless_reverse, db, settings):
},
content_type="application/json",
)
assert resp.status_code == 401
assert resp.status_code == HTTPStatus.UNAUTHORIZED
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
assert pending_flow["id"] == "provider_signup"

resp = client.get(headless_reverse("headless:socialaccount:provider_signup"))
assert resp.status_code == HTTPStatus.OK
assert resp.json() == {
"data": {
"email": [],
"account": {
"display": "Dummy",
"provider": {
"flows": ["provider_redirect", "provider_token"],
"id": "dummy",
"name": "Dummy",
},
"uid": account_uid,
},
"user": {"display": "user", "has_usable_password": False},
},
"status": HTTPStatus.OK,
}

resp = client.post(
headless_reverse("headless:socialaccount:provider_signup"),
data={
"email": "a@b.com",
},
content_type="application/json",
)
assert resp.status_code == 401
assert resp.status_code == HTTPStatus.UNAUTHORIZED
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
assert pending_flow["id"] == "verify_email"
assert EmailAddress.objects.filter(email="a@b.com").exists()
Expand All @@ -309,7 +331,7 @@ def test_signup_closed(client, headless_reverse, db, settings):
},
content_type="application/json",
)
assert resp.status_code == 401
assert resp.status_code == HTTPStatus.UNAUTHORIZED
pending_flow = [f for f in resp.json()["data"]["flows"] if f.get("is_pending")][0]
assert pending_flow["id"] == "provider_signup"
with patch(
Expand All @@ -323,7 +345,7 @@ def test_signup_closed(client, headless_reverse, db, settings):
},
content_type="application/json",
)
assert resp.status_code == 403
assert resp.status_code == HTTPStatus.FORBIDDEN


def test_connect(user, auth_client, sociallogin_setup_state, headless_reverse, db):
Expand All @@ -336,7 +358,7 @@ def test_connect(user, auth_client, sociallogin_setup_state, headless_reverse, d
"id": 123,
},
)
assert resp.status_code == 302
assert resp.status_code == HTTPStatus.FOUND
assert resp["location"] == "/foo"
assert SocialAccount.objects.filter(user=user, provider="dummy", uid="123").exists()

Expand All @@ -355,7 +377,7 @@ def test_connect_reauthentication_required(
"id": 123,
},
)
assert resp.status_code == 302
assert resp.status_code == HTTPStatus.FOUND
assert (
resp["location"] == "/foo?error=reauthentication_required&error_process=connect"
)
Expand All @@ -378,7 +400,7 @@ def test_connect_already_connected(
},
)
# We're redirected, and an error code is shown.
assert resp.status_code == 302
assert resp.status_code == HTTPStatus.FOUND
assert resp["location"] == "/foo?error=connected_other&error_process=connect"
assert not SocialAccount.objects.filter(
user=user, provider="dummy", uid="123"
Expand All @@ -404,7 +426,7 @@ def test_token_connect(user, auth_client, headless_reverse, db):
},
content_type="application/json",
)
assert resp.status_code == 200
assert resp.status_code == HTTPStatus.OK
assert SocialAccount.objects.filter(uid="123", user=user).exists()


Expand Down Expand Up @@ -433,9 +455,9 @@ def test_token_connect_already_connected(
content_type="application/json",
)
assert not SocialAccount.objects.filter(uid="123", user=user).exists()
assert resp.status_code == 400
assert resp.status_code == HTTPStatus.BAD_REQUEST
assert resp.json() == {
"status": 400,
"status": HTTPStatus.BAD_REQUEST,
"errors": [
{
"code": "connected_other",
Expand All @@ -453,4 +475,4 @@ def test_provider_signup_not_pending(client, headless_reverse, db, settings):
},
content_type="application/json",
)
assert resp.status_code == 409
assert resp.status_code == HTTPStatus.CONFLICT
8 changes: 7 additions & 1 deletion allauth/headless/socialaccount/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
SignupInput,
)
from allauth.headless.socialaccount.internal import complete_token_login
from allauth.headless.socialaccount.response import SocialAccountsResponse
from allauth.headless.socialaccount.response import (
SocialAccountsResponse,
SocialLoginResponse,
)
from allauth.socialaccount.adapter import (
get_adapter as get_socialaccount_adapter,
)
Expand All @@ -37,6 +40,9 @@ def handle(self, request, *args, **kwargs):
return ForbiddenResponse(request)
return super().handle(request, *args, **kwargs)

def get(self, request, *args, **kwargs):
return SocialLoginResponse(request, self.sociallogin)

def post(self, request, *args, **kwargs):
response = flows.signup.signup_by_form(
self.request, self.sociallogin, self.input
Expand Down
Loading

0 comments on commit 2542b3f

Please sign in to comment.