diff --git a/fastapi_azure_auth/__init__.py b/fastapi_azure_auth/__init__.py index 1a72d32..b3ddbc4 100644 --- a/fastapi_azure_auth/__init__.py +++ b/fastapi_azure_auth/__init__.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = '1.1.1' diff --git a/fastapi_azure_auth/auth.py b/fastapi_azure_auth/auth.py index f0bbaa7..bc47caa 100644 --- a/fastapi_azure_auth/auth.py +++ b/fastapi_azure_auth/auth.py @@ -1,3 +1,5 @@ +import base64 +import json import logging from typing import Any, Dict, Optional @@ -78,11 +80,20 @@ async def __call__(self, request: Request) -> dict[str, Any]: Extends call to also validate the token """ access_token = await super().__call__(request=request) + try: + # Extract header information of the token. + header = json.loads(base64.b64decode(access_token.split('.')[0])) # header, claims, signature + except Exception as error: + log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True) + raise InvalidAuth(detail='Invalid token format') + # Load new config if old await provider_config.load_config() - for index, key in enumerate(provider_config.signing_keys): + + # Use the `kid` from the header to find a matching signing key to use + if key := provider_config.signing_keys.get(header.get('kid')): try: - # Set strict in case defaults change + # We require and validate all fields in an Azure AD token options = { 'verify_signature': True, 'verify_aud': True, @@ -103,7 +114,7 @@ async def __call__(self, request: Request) -> dict[str, Any]: 'require_at_hash': False, 'leeway': 0, } - # Validate token and return claims + # Validate token token = jwt.decode( access_token, key=key, @@ -114,6 +125,7 @@ async def __call__(self, request: Request) -> dict[str, Any]: ) if not self.allow_guest_users and token['tid'] != provider_config.tenant_id: raise GuestUserException() + # Attach the user to the request. Can be accessed through `request.state.user` user: User = User(**token | {'claims': token}) request.state.user = user return token @@ -126,8 +138,6 @@ async def __call__(self, request: Request) -> dict[str, Any]: log.info('Token signature has expired. %s', error) raise InvalidAuth(detail='Token signature has expired') except JWTError as error: - if str(error) == 'Signature verification failed.' and index < len(provider_config.signing_keys) - 1: - continue log.warning('Invalid token. Error: %s', error, exc_info=True) raise InvalidAuth(detail='Unable to validate token') except Exception as error: diff --git a/fastapi_azure_auth/provider_config.py b/fastapi_azure_auth/provider_config.py index 09ea8ab..0d17a34 100644 --- a/fastapi_azure_auth/provider_config.py +++ b/fastapi_azure_auth/provider_config.py @@ -18,7 +18,7 @@ def __init__(self) -> None: self._config_timestamp: Optional[datetime] = None self.authorization_endpoint: str - self.signing_keys: list[KeyTypes] + self.signing_keys: dict[str, KeyTypes] self.token_endpoint: str self.end_session_endpoint: str self.issuer: str @@ -68,24 +68,24 @@ async def _load_openid_config(self) -> None: async with session.get(jwks_uri, timeout=10) as jwks_response: jwks_response.raise_for_status() keys = await jwks_response.json() - signing_certificates = [x['x5c'][0] for x in keys['keys'] if x.get('use', 'sig') == 'sig'] - self._load_keys(signing_certificates) + self._load_keys(keys['keys']) self.authorization_endpoint = openid_cfg['authorization_endpoint'] self.token_endpoint = openid_cfg['token_endpoint'] self.end_session_endpoint = openid_cfg['end_session_endpoint'] self.issuer = openid_cfg['issuer'] - def _load_keys(self, certificates: list[str]) -> None: + def _load_keys(self, keys: list[dict]) -> None: """ Create certificates based on signing keys and store them """ - new_keys = [] - for cert in certificates: - log.debug('Loading public key from certificate: %s', cert) - cert_obj = load_der_x509_certificate(base64.b64decode(cert), backend) - new_keys.append(cert_obj.public_key()) - self.signing_keys = new_keys + self.signing_keys = {} + for key in keys: + if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption + log.debug('Loading public key from certificate: %s', key) + cert_obj = load_der_x509_certificate(base64.b64decode(key['x5c'][0]), backend) + if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it. + self.signing_keys[kid] = cert_obj.public_key() provider_config = ProviderConfig() diff --git a/pyproject.toml b/pyproject.toml index 103dd79..7916282 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastapi-azure-auth" -version = "1.1.0" # Remember to change in __init__.py as well +version = "1.1.1" # Remember to change in __init__.py as well description = "Easy and secure implementation of Azure AD for your FastAPI APIs" authors = ["Jonas Krüger Svensson "] readme = "README.md" diff --git a/tests/conftest.py b/tests/conftest.py index 7a7f704..2da5465 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest from aioresponses import aioresponses -from tests.utils import build_openid_keys +from tests.utils import build_openid_keys, openid_configuration from fastapi_azure_auth.provider_config import provider_config @@ -17,54 +17,7 @@ def mock_openid(): with aioresponses() as mock: mock.get( 'https://login.microsoftonline.com/intility_tenant_id/v2.0/.well-known/openid-configuration', - payload={ - 'token_endpoint': 'https://login.microsoftonline.com/intility_tenant_id/token', - 'token_endpoint_auth_methods_supported': [ - 'client_secret_post', - 'private_key_jwt', - 'client_secret_basic', - ], - 'jwks_uri': 'https://login.microsoftonline.com/common/discovery/keys', - 'response_modes_supported': ['query', 'fragment', 'form_post'], - 'subject_types_supported': ['pairwise'], - 'id_token_signing_alg_values_supported': ['RS256'], - 'response_types_supported': ['code', 'id_token', 'code id_token', 'token id_token', 'token'], - 'scopes_supported': ['openid'], - 'issuer': 'https://sts.windows.net/intility_tenant_id/', - 'microsoft_multi_refresh_token': True, - 'authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/authorize', - 'device_authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/devicecode', - 'http_logout_supported': True, - 'frontchannel_logout_supported': True, - 'end_session_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/logout', - 'claims_supported': [ - 'sub', - 'iss', - 'cloud_instance_name', - 'cloud_instance_host_name', - 'cloud_graph_host_name', - 'msgraph_host', - 'aud', - 'exp', - 'iat', - 'auth_time', - 'acr', - 'amr', - 'nonce', - 'email', - 'given_name', - 'family_name', - 'nickname', - ], - 'check_session_iframe': 'https://login.microsoftonline.com/intility_tenant_idoauth2/checksession', - 'userinfo_endpoint': 'https://login.microsoftonline.com/intility_tenant_idopenid/userinfo', - 'kerberos_endpoint': 'https://login.microsoftonline.com/intility_tenant_idkerberos', - 'tenant_region_scope': 'EU', - 'cloud_instance_name': 'microsoftonline.com', - 'cloud_graph_host_name': 'graph.windows.net', - 'msgraph_host': 'graph.microsoft.com', - 'rbac_url': 'https://pas.windows.net', - }, + payload=openid_configuration(), ) yield mock @@ -87,6 +40,23 @@ def mock_openid_and_empty_keys(mock_openid): yield mock_openid +@pytest.fixture +def mock_openid_ok_then_empty(mock_openid): + mock_openid.get( + 'https://login.microsoftonline.com/common/discovery/keys', + payload=build_openid_keys(), + ) + mock_openid.get( + 'https://login.microsoftonline.com/common/discovery/keys', + payload=build_openid_keys(empty_keys=True), + ) + mock_openid.get( + 'https://login.microsoftonline.com/intility_tenant_id/v2.0/.well-known/openid-configuration', + payload=openid_configuration(), + ) + yield mock_openid + + @pytest.fixture def mock_openid_and_no_valid_keys(mock_openid): mock_openid.get( diff --git a/tests/test_validate_token.py b/tests/test_validate_token.py index f897946..20604d3 100644 --- a/tests/test_validate_token.py +++ b/tests/test_validate_token.py @@ -1,4 +1,5 @@ import time +from datetime import datetime, timedelta import pytest from demoproj.core.config import settings @@ -9,6 +10,8 @@ build_access_token_expired, build_access_token_guest, build_access_token_invalid_claims, + build_evil_access_token, + build_openid_keys, ) from fastapi_azure_auth.auth import AzureAuthorizationCodeBearer @@ -91,6 +94,7 @@ async def test_no_keys_to_decode_with(mock_openid_and_empty_keys): assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} +@pytest.mark.asyncio async def test_invalid_token_claims(mock_openid_and_keys): async with AsyncClient( app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()} @@ -99,14 +103,16 @@ async def test_invalid_token_claims(mock_openid_and_keys): assert response.json() == {'detail': 'Token contains invalid claims'} +@pytest.mark.asyncio async def test_no_valid_keys_for_token(mock_openid_and_no_valid_keys): async with AsyncClient( app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_invalid_claims()} ) as ac: response = await ac.get('api/v1/hello') - assert response.json() == {'detail': 'Unable to validate token'} + assert response.json() == {'detail': 'Unable to verify token, no signing keys found'} +@pytest.mark.asyncio async def test_expired_token(mock_openid_and_keys): async with AsyncClient( app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token_expired()} @@ -115,6 +121,42 @@ async def test_expired_token(mock_openid_and_keys): assert response.json() == {'detail': 'Token signature has expired'} +@pytest.mark.asyncio +async def test_evil_token(mock_openid_and_keys): + """Kid matches what we expect, but it's not signed correctly""" + async with AsyncClient( + app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_evil_access_token()} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Unable to validate token'} + + +@pytest.mark.asyncio +async def test_malformed_token(mock_openid_and_keys): + """A short token, that only has a broken header""" + async with AsyncClient( + app=app, base_url='http://test', headers={'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsInR5cI6IkpXVCJ9'} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Invalid token format'} + + +@pytest.mark.asyncio +async def test_only_header(mock_openid_and_keys): + """Only header token, with a matching kid, so the rest of the logic will be called, but can't be validated""" + async with AsyncClient( + app=app, + base_url='http://test', + headers={ + 'Authorization': 'Bearer eyJhbGciOiJSUzI1NiIsImtpZCI6InJlYWwgdGh1bWJ' + 'wcmludCIsInR5cCI6IkpXVCIsIng1dCI6ImFub3RoZXIgdGh1bWJwcmludCJ9' + }, # {'kid': 'real thumbprint', 'x5t': 'another thumbprint'} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.json() == {'detail': 'Unable to validate token'} + + +@pytest.mark.asyncio async def test_exception_raised(mock_openid_and_keys, mocker): mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol')) async with AsyncClient( @@ -122,3 +164,27 @@ async def test_exception_raised(mock_openid_and_keys, mocker): ) as ac: response = await ac.get('api/v1/hello') assert response.json() == {'detail': 'Unable to process token'} + + +@pytest.mark.asyncio +async def test_change_of_keys_works(mock_openid_ok_then_empty, freezer): + """ + * Do a successful request. + * Set time to 25 hours later, so that a new provider config has to be fetched + * Ensure new keys returned is an empty list, so the next request shouldn't work. + * Generate a new, valid token + * Do request + """ + async with AsyncClient( + app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} + ) as ac: + response = await ac.get('api/v1/hello') + assert response.status_code == 200 + + freezer.move_to(datetime.now() + timedelta(hours=25)) # The keys fetched are now outdated + + async with AsyncClient( + app=app, base_url='http://test', headers={'Authorization': 'Bearer ' + build_access_token()} + ) as ac: + second_resonse = await ac.get('api/v1/hello') + assert second_resonse.json() == {'detail': 'Unable to verify token, no signing keys found'} diff --git a/tests/utils.py b/tests/utils.py index a86e246..e37727c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,6 +51,13 @@ def build_access_token(): return do_build_access_token(tenant_id='intility_tenant_id') +def build_evil_access_token(): + """ + Build an access token, coming from the tenant ID we expect + """ + return do_build_access_token(tenant_id='intility_tenant_id', evil=True) + + def build_access_token_guest(): """ Build an access token, but as a guest user. @@ -72,7 +79,7 @@ def build_access_token_expired(): return do_build_access_token(tenant_id='intility_tenant_id', expired=True) -def do_build_access_token(tenant_id=None, aud=None, expired=False): +def do_build_access_token(tenant_id=None, aud=None, expired=False, evil=False): """ Build the access token and encode it with the signing key. """ @@ -106,14 +113,16 @@ def do_build_access_token(tenant_id=None, aud=None, expired=False): 'uti': 'abcdefghijkl-mnopqrstu', 'ver': '1.0', } + signing_key = signing_key_a if evil else signing_key_b return jwt.encode( claims, - signing_key_b.private_bytes( + signing_key.private_bytes( crypto_serialization.Encoding.PEM, crypto_serialization.PrivateFormat.PKCS8, crypto_serialization.NoEncryption(), ), algorithm='RS256', + headers={'kid': 'real thumbprint', 'x5t': 'another thumbprint'}, ) @@ -126,7 +135,7 @@ def build_openid_keys(empty_keys=False, no_valid_keys=False): elif no_valid_keys: return { 'keys': [ - { + { # this key is not used 'kty': 'RSA', 'use': 'sig', 'kid': 'dummythumbprint', @@ -156,8 +165,8 @@ def build_openid_keys(empty_keys=False, no_valid_keys=False): { 'kty': 'RSA', 'use': 'sig', - 'kid': 'dummythumbprint', - 'x5t': 'dummythumbprint', + 'kid': 'real thumbprint', + 'x5t': 'real thumbprint2', 'n': 'somebase64encodedmodulus', 'e': 'somebase64encodedexponent', 'x5c': [ @@ -168,5 +177,56 @@ def build_openid_keys(empty_keys=False, no_valid_keys=False): } +def openid_configuration(): + return { + 'token_endpoint': 'https://login.microsoftonline.com/intility_tenant_id/token', + 'token_endpoint_auth_methods_supported': [ + 'client_secret_post', + 'private_key_jwt', + 'client_secret_basic', + ], + 'jwks_uri': 'https://login.microsoftonline.com/common/discovery/keys', + 'response_modes_supported': ['query', 'fragment', 'form_post'], + 'subject_types_supported': ['pairwise'], + 'id_token_signing_alg_values_supported': ['RS256'], + 'response_types_supported': ['code', 'id_token', 'code id_token', 'token id_token', 'token'], + 'scopes_supported': ['openid'], + 'issuer': 'https://sts.windows.net/intility_tenant_id/', + 'microsoft_multi_refresh_token': True, + 'authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/authorize', + 'device_authorization_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/devicecode', + 'http_logout_supported': True, + 'frontchannel_logout_supported': True, + 'end_session_endpoint': 'https://login.microsoftonline.com/intility_tenant_idoauth2/logout', + 'claims_supported': [ + 'sub', + 'iss', + 'cloud_instance_name', + 'cloud_instance_host_name', + 'cloud_graph_host_name', + 'msgraph_host', + 'aud', + 'exp', + 'iat', + 'auth_time', + 'acr', + 'amr', + 'nonce', + 'email', + 'given_name', + 'family_name', + 'nickname', + ], + 'check_session_iframe': 'https://login.microsoftonline.com/intility_tenant_idoauth2/checksession', + 'userinfo_endpoint': 'https://login.microsoftonline.com/intility_tenant_idopenid/userinfo', + 'kerberos_endpoint': 'https://login.microsoftonline.com/intility_tenant_idkerberos', + 'tenant_region_scope': 'EU', + 'cloud_instance_name': 'microsoftonline.com', + 'cloud_graph_host_name': 'graph.windows.net', + 'msgraph_host': 'graph.microsoft.com', + 'rbac_url': 'https://pas.windows.net', + } + + signing_key_a, signing_cert_a = generate_key_and_cert() signing_key_b, signing_cert_b = generate_key_and_cert()