Skip to content

Commit

Permalink
feat(account): Keep track of authentication methods used
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Dec 1, 2023
1 parent f784cdf commit e39aed5
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 6 deletions.
40 changes: 40 additions & 0 deletions allauth/account/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import time


AUTHENTICATION_METHODS_SESSION_KEY = "account_authentication_methods"


def record_authentication(request, method, **extra_data):
"""Here we keep a log of all authentication methods used within the current
session. Important to note is that having entries here does not imply that
a user is fully signed in. For example, consider a case where a user
authenticates using a password, but fails to complete the 2FA challenge.
Or, a user successfully signs in into an inactive account or one that still
needs verification. In such cases, ``request.user`` is still anonymous, yet,
we do have an entry here.
Example data::
{'method': 'password',
'at': 1701423602.7184925,
'username': 'john.doe'}
{'method': 'socialaccount',
'at': 1701423567.6368647,
'provider': 'amazon',
'uid': 'amzn1.account.K2LI23KL2LK2'}
{'method': 'mfa',
'at': 1701423602.6392953,
'id': 1,
'type': 'totp'}
"""
methods = request.session.get(AUTHENTICATION_METHODS_SESSION_KEY, [])
data = {
"method": method,
"at": time.time(),
**extra_data,
}
methods.append(data)
request.session[AUTHENTICATION_METHODS_SESSION_KEY] = methods
12 changes: 10 additions & 2 deletions allauth/account/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from django.utils.safestring import mark_safe
from django.utils.translation import gettext, gettext_lazy as _, pgettext

from allauth.account.authentication import record_authentication

from ..utils import (
build_absolute_uri,
get_username_max_length,
Expand Down Expand Up @@ -200,13 +202,19 @@ def clean(self):
return self.cleaned_data

def login(self, request, redirect_url=None):
email = self.user_credentials().get("email")
credentials = self.user_credentials()
extra_data = {
field: credentials.get(field)
for field in ["email", "username"]
if field in credentials
}
record_authentication(request, method="password", **extra_data)
ret = perform_login(
request,
self.user,
email_verification=app_settings.EMAIL_VERIFICATION,
redirect_url=redirect_url,
email=email,
email=credentials.get("email"),
)
remember = app_settings.SESSION_REMEMBER
if remember is None:
Expand Down
2 changes: 2 additions & 0 deletions allauth/account/reauthentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def resume_request(request):


def record_authentication(request, user):
# TODO: This is a different/independent mechanism from
# ``authentication.record_authentication()``. We need to unify this.
request.session[AUTHENTICATED_AT_SESSION_KEY] = time.time()


Expand Down
13 changes: 12 additions & 1 deletion allauth/account/tests/test_login.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from unittest.mock import patch
from unittest.mock import ANY, patch

import django
from django.conf import settings
Expand All @@ -9,6 +9,7 @@
from django.urls import NoReverseMatch, reverse

from allauth.account import app_settings
from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.account.forms import LoginForm
from allauth.account.models import EmailAddress
from allauth.tests import TestCase
Expand Down Expand Up @@ -46,6 +47,16 @@ def test_username_containing_at(self):
self.assertRedirects(
resp, settings.LOGIN_REDIRECT_URL, fetch_redirect_response=False
)
self.assertEqual(
self.client.session[AUTHENTICATION_METHODS_SESSION_KEY],
[
{
"at": ANY,
"username": "@raymond.penners",
"method": "password",
}
],
)

def _create_user(self, username="john", password="doe", **kwargs):
user = get_user_model().objects.create(
Expand Down
3 changes: 2 additions & 1 deletion allauth/mfa/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from allauth.mfa import totp
from allauth.mfa.adapter import get_adapter
from allauth.mfa.models import Authenticator
from allauth.mfa.utils import post_authentication


class AuthenticateForm(forms.Form):
Expand Down Expand Up @@ -44,7 +45,7 @@ def clean_code(self):
raise forms.ValidationError(get_adapter().error_messages["incorrect_code"])

def save(self):
self.authenticator.record_usage()
post_authentication(context.request, self.authenticator)


class ActivateTOTPForm(forms.Form):
Expand Down
42 changes: 41 additions & 1 deletion allauth/mfa/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import ANY, patch

import django
from django.conf import settings
Expand All @@ -7,6 +7,7 @@
import pytest
from pytest_django.asserts import assertFormError

from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.account.models import EmailAddress
from allauth.mfa import app_settings
from allauth.mfa.adapter import get_adapter
Expand Down Expand Up @@ -130,6 +131,10 @@ def test_totp_login(client, user_with_totp, user_password, totp_validation_bypas
)
assert resp.status_code == 302
assert resp["location"] == settings.LOGIN_REDIRECT_URL
assert client.session[AUTHENTICATION_METHODS_SESSION_KEY] == [
{"method": "password", "at": ANY, "username": user_with_totp.username},
{"method": "mfa", "at": ANY, "id": ANY, "type": Authenticator.Type.TOTP},
]


def test_download_recovery_codes(auth_client, user_with_recovery_codes, user_password):
Expand Down Expand Up @@ -169,6 +174,41 @@ def test_generate_recovery_codes(auth_client, user_with_recovery_codes, user_pas
assert not rc.validate_code(prev_code)


def test_recovery_codes_login(
client, user_with_totp, user_with_recovery_codes, user_password
):
resp = client.post(
reverse("account_login"),
{"login": user_with_totp.username, "password": user_password},
)
assert resp.status_code == 302
assert resp["location"] == reverse("mfa_authenticate")
resp = client.get(reverse("mfa_authenticate"))
assert resp.context["request"].user.is_anonymous
resp = client.post(reverse("mfa_authenticate"), {"code": "123"})
assert resp.context["form"].errors == {
"code": [get_adapter().error_messages["incorrect_code"]]
}
rc = Authenticator.objects.get(
user=user_with_recovery_codes, type=Authenticator.Type.RECOVERY_CODES
)
resp = client.post(
reverse("mfa_authenticate"),
{"code": rc.wrap().get_unused_codes()[0]},
)
assert resp.status_code == 302
assert resp["location"] == settings.LOGIN_REDIRECT_URL
assert client.session[AUTHENTICATION_METHODS_SESSION_KEY] == [
{"method": "password", "at": ANY, "username": user_with_totp.username},
{
"method": "mfa",
"at": ANY,
"id": ANY,
"type": Authenticator.Type.RECOVERY_CODES,
},
]


def test_add_email_not_allowed(auth_client, user_with_totp):
resp = auth_client.post(
reverse("account_email"),
Expand Down
10 changes: 10 additions & 0 deletions allauth/mfa/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from allauth.account.authentication import record_authentication
from allauth.mfa.adapter import get_adapter
from allauth.mfa.models import Authenticator

Expand All @@ -17,3 +18,12 @@ def is_mfa_enabled(user, types=None):
if types is not None:
qs = qs.filter(type__in=types)
return qs.exists()


def post_authentication(request, authenticator):
authenticator.record_usage()
extra_data = {
"id": authenticator.pk,
"type": authenticator.type,
}
record_authentication(request, "mfa", **extra_data)
9 changes: 9 additions & 0 deletions allauth/socialaccount/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.authentication import record_authentication
from allauth.account.reauthentication import reauthenticate_then_callback
from allauth.account.utils import (
assess_unique_email,
Expand Down Expand Up @@ -195,6 +196,14 @@ def _add_social_account(request, sociallogin):
def complete_social_login(request, sociallogin):
assert not sociallogin.is_existing
sociallogin.lookup()
record_authentication(
request,
"socialaccount",
**{
"provider": sociallogin.account.provider,
"uid": sociallogin.account.uid,
}
)
try:
get_adapter().pre_social_login(request, sociallogin)
signals.pre_social_login.send(
Expand Down
25 changes: 24 additions & 1 deletion allauth/socialaccount/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
import requests
import warnings
from unittest.mock import Mock, patch
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse

from django.conf import settings
Expand All @@ -15,6 +15,7 @@
from django.utils.http import urlencode

import allauth.app_settings
from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.account.models import EmailAddress
from allauth.account.utils import user_email, user_username
from allauth.socialaccount import app_settings
Expand Down Expand Up @@ -81,6 +82,17 @@ def test_login(self):
provider_account.get_profile_url()
provider_account.get_brand()
provider_account.to_str()
self.assertEqual(
self.client.session[AUTHENTICATION_METHODS_SESSION_KEY],
[
{
"at": ANY,
"provider": self.provider_id,
"method": "socialaccount",
"uid": account.uid,
}
],
)
return account

@override_settings(
Expand Down Expand Up @@ -211,6 +223,17 @@ def test_login(self):
resp_mock,
)
self.assertRedirects(resp, reverse("socialaccount_signup"))
self.assertEqual(
self.client.session[AUTHENTICATION_METHODS_SESSION_KEY],
[
{
"at": ANY,
"provider": self.provider_id,
"method": "socialaccount",
"uid": ANY,
}
],
)

@override_settings(SOCIALACCOUNT_AUTO_SIGNUP=False)
def test_login_with_pkce_disabled(self):
Expand Down

0 comments on commit e39aed5

Please sign in to comment.