Skip to content

Commit

Permalink
Merge pull request #10 from Intility/improve_signature_loop
Browse files Browse the repository at this point in the history
Signature loop improvement
  • Loading branch information
JonasKs authored Aug 17, 2021
2 parents 8122b19 + 41043ab commit a8f108e
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 72 deletions.
2 changes: 1 addition & 1 deletion fastapi_azure_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.0'
__version__ = '1.1.1'
20 changes: 15 additions & 5 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64
import json
import logging
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -78,11 +80,20 @@ async def __call__(self, request: Request) -> dict[str, Any]:
Extends call to also validate the token
"""
access_token = await super().__call__(request=request)
try:
# Extract header information of the token.
header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature
except Exception as error:
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
raise InvalidAuth(detail='Invalid token format')

# Load new config if old
await provider_config.load_config()
for index, key in enumerate(provider_config.signing_keys):

# Use the `kid` from the header to find a matching signing key to use
if key := provider_config.signing_keys.get(header.get('kid')):
try:
# Set strict in case defaults change
# We require and validate all fields in an Azure AD token
options = {
'verify_signature': True,
'verify_aud': True,
Expand All @@ -103,7 +114,7 @@ async def __call__(self, request: Request) -> dict[str, Any]:
'require_at_hash': False,
'leeway': 0,
}
# Validate token and return claims
# Validate token
token = jwt.decode(
access_token,
key=key,
Expand All @@ -114,6 +125,7 @@ async def __call__(self, request: Request) -> dict[str, Any]:
)
if not self.allow_guest_users and token['tid'] != provider_config.tenant_id:
raise GuestUserException()
# Attach the user to the request. Can be accessed through `request.state.user`
user: User = User(**token | {'claims': token})
request.state.user = user
return token
Expand All @@ -126,8 +138,6 @@ async def __call__(self, request: Request) -> dict[str, Any]:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired')
except JWTError as error:
if str(error) == 'Signature verification failed.' and index < len(provider_config.signing_keys) - 1:
continue
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token')
except Exception as error:
Expand Down
20 changes: 10 additions & 10 deletions fastapi_azure_auth/provider_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self) -> None:
self._config_timestamp: Optional[datetime] = None

self.authorization_endpoint: str
self.signing_keys: list[KeyTypes]
self.signing_keys: dict[str, KeyTypes]
self.token_endpoint: str
self.end_session_endpoint: str
self.issuer: str
Expand Down Expand Up @@ -68,24 +68,24 @@ async def _load_openid_config(self) -> None:
async with session.get(jwks_uri, timeout=10) as jwks_response:
jwks_response.raise_for_status()
keys = await jwks_response.json()
signing_certificates = [x['x5c'][0] for x in keys['keys'] if x.get('use', 'sig') == 'sig']
self._load_keys(signing_certificates)
self._load_keys(keys['keys'])

self.authorization_endpoint = openid_cfg['authorization_endpoint']
self.token_endpoint = openid_cfg['token_endpoint']
self.end_session_endpoint = openid_cfg['end_session_endpoint']
self.issuer = openid_cfg['issuer']

def _load_keys(self, certificates: list[str]) -> None:
def _load_keys(self, keys: list[dict]) -> None:
"""
Create certificates based on signing keys and store them
"""
new_keys = []
for cert in certificates:
log.debug('Loading public key from certificate: %s', cert)
cert_obj = load_der_x509_certificate(base64.b64decode(cert), backend)
new_keys.append(cert_obj.public_key())
self.signing_keys = new_keys
self.signing_keys = {}
for key in keys:
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
log.debug('Loading public key from certificate: %s', key)
cert_obj = load_der_x509_certificate(base64.b64decode(key['x5c'][0]), backend)
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
self.signing_keys[kid] = cert_obj.public_key()


provider_config = ProviderConfig()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fastapi-azure-auth"
version = "1.1.0" # Remember to change in __init__.py as well
version = "1.1.1" # Remember to change in __init__.py as well
description = "Easy and secure implementation of Azure AD for your FastAPI APIs"
authors = ["Jonas Krüger Svensson <jonas.svensson@intility.no>"]
readme = "README.md"
Expand Down
68 changes: 19 additions & 49 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from aioresponses import aioresponses
from tests.utils import build_openid_keys
from tests.utils import build_openid_keys, openid_configuration

from fastapi_azure_auth.provider_config import provider_config

Expand All @@ -17,54 +17,7 @@ def mock_openid():
with aioresponses() as mock:
mock.get(
'https://login.microsoftonline.com/intility_tenant_id/v2.0/.well-known/openid-configuration',
payload={
'token_endpoint': 'https://login.microsoftonline.com/intility_tenant_id/token',
'token_endpoint_auth_methods_supported': [
'client_secret_post',
'private_key_jwt',
'client_secret_basic',
],
'jwks_uri': 'https://login.microsoftonline.com/common/discovery/keys',
'response_modes_supported': ['query', 'fragment', 'form_post'],
'subject_types_supported': ['pairwise'],
'id_token_signing_alg_values_supported': ['RS256'],
'response_types_supported': ['code', 'id_token', 'code id_token', 'token id_token', 'token'],
'scopes_supported': ['openid'],
'issuer': 'https://sts.windows.net/intility_tenant_id/',
'microsoft_multi_refresh_token': True,
'authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/authorize',
'device_authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/devicecode',
'http_logout_supported': True,
'frontchannel_logout_supported': True,
'end_session_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/logout',
'claims_supported': [
'sub',
'iss',
'cloud_instance_name',
'cloud_instance_host_name',
'cloud_graph_host_name',
'msgraph_host',
'aud',
'exp',
'iat',
'auth_time',
'acr',
'amr',
'nonce',
'email',
'given_name',
'family_name',
'nickname',
],
'check_session_iframe': 'https://login.microsoftonline.com/intility_tenant_idoauth2/checksession',
'userinfo_endpoint': 'https://login.microsoftonline.com/intility_tenant_idopenid/userinfo',
'kerberos_endpoint': 'https://login.microsoftonline.com/intility_tenant_idkerberos',
'tenant_region_scope': 'EU',
'cloud_instance_name': 'microsoftonline.com',
'cloud_graph_host_name': 'graph.windows.net',
'msgraph_host': 'graph.microsoft.com',
'rbac_url': 'https://pas.windows.net',
},
payload=openid_configuration(),
)
yield mock

Expand All @@ -87,6 +40,23 @@ def mock_openid_and_empty_keys(mock_openid):
yield mock_openid


@pytest.fixture
def mock_openid_ok_then_empty(mock_openid):
mock_openid.get(
'https://login.microsoftonline.com/common/discovery/keys',
payload=build_openid_keys(),
)
mock_openid.get(
'https://login.microsoftonline.com/common/discovery/keys',
payload=build_openid_keys(empty_keys=True),
)
mock_openid.get(
'https://login.microsoftonline.com/intility_tenant_id/v2.0/.well-known/openid-configuration',
payload=openid_configuration(),
)
yield mock_openid


@pytest.fixture
def mock_openid_and_no_valid_keys(mock_openid):
mock_openid.get(
Expand Down
68 changes: 67 additions & 1 deletion tests/test_validate_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from datetime import datetime, timedelta

import pytest
from demoproj.core.config import settings
Expand All @@ -9,6 +10,8 @@
build_access_token_expired,
build_access_token_guest,
build_access_token_invalid_claims,
build_evil_access_token,
build_openid_keys,
)

from fastapi_azure_auth.auth import AzureAuthorizationCodeBearer
Expand Down Expand Up @@ -91,6 +94,7 @@ async def test_no_keys_to_decode_with(mock_openid_and_empty_keys):
assert response.json() == {'detail': 'Unable to verify token, no signing keys found'}


@pytest.mark.asyncio
async def test_invalid_token_claims(mock_openid_and_keys):
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}
Expand All @@ -99,14 +103,16 @@ async def test_invalid_token_claims(mock_openid_and_keys):
assert response.json() == {'detail': 'Token contains invalid claims'}


@pytest.mark.asyncio
async def test_no_valid_keys_for_token(mock_openid_and_no_valid_keys):
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Unable to validate token'}
assert response.json() == {'detail': 'Unable to verify token, no signing keys found'}


@pytest.mark.asyncio
async def test_expired_token(mock_openid_and_keys):
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_expired()}
Expand All @@ -115,10 +121,70 @@ async def test_expired_token(mock_openid_and_keys):
assert response.json() == {'detail': 'Token signature has expired'}


@pytest.mark.asyncio
async def test_evil_token(mock_openid_and_keys):
"""Kid matches what we expect, but it's not signed correctly"""
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_evil_access_token()}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Unable to validate token'}


@pytest.mark.asyncio
async def test_malformed_token(mock_openid_and_keys):
"""A short token, that only has a broken header"""
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Invalid token format'}


@pytest.mark.asyncio
async def test_only_header(mock_openid_and_keys):
"""Only header token, with a matching kid, so the rest of the logic will be called, but can't be validated"""
async with AsyncClient(
app=app,
base_url='http://test',
headers={
'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6InJlYWwgdGh1bWJ'
'wcmludCIsInR5cCI6IkpXVCIsIng1dCI6ImFub3RoZXIgdGh1bWJwcmludCJ9'
}, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Unable to validate token'}


@pytest.mark.asyncio
async def test_exception_raised(mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_expired()}
) as ac:
response = await ac.get('api/v1/hello')
assert response.json() == {'detail': 'Unable to process token'}


@pytest.mark.asyncio
async def test_change_of_keys_works(mock_openid_ok_then_empty, freezer):
"""
* Do a successful request.
* Set time to 25 hours later, so that a new provider config has to be fetched
* Ensure new keys returned is an empty list, so the next request shouldn't work.
* Generate a new, valid token
* Do request
"""
async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()}
) as ac:
response = await ac.get('api/v1/hello')
assert response.status_code == 200

freezer.move_to(datetime.now() + timedelta(hours=25)) # The keys fetched are now outdated

async with AsyncClient(
app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()}
) as ac:
second_resonse = await ac.get('api/v1/hello')
assert second_resonse.json() == {'detail': 'Unable to verify token, no signing keys found'}
Loading

0 comments on commit a8f108e

Please sign in to comment.