Skip to content

Commit

Permalink
Introduce a common base class for OAuth2 implementations.
Browse files Browse the repository at this point in the history
Initial step for Colin-b#72
  • Loading branch information
Raphael Krupinski committed Jan 25, 2024
1 parent a7b6198 commit 2e6efa1
Showing 1 changed file with 44 additions and 54 deletions.
98 changes: 44 additions & 54 deletions httpx_auth/authentication.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import base64
import os
import uuid
Expand Down Expand Up @@ -90,6 +91,30 @@ class OAuth2:
token_cache = oauth2_tokens.TokenMemoryCache()


class OAuthAuthBase(abc.ABC, httpx.Auth):
state: Optional[str] = None
early_expiry: float

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
)
self._update_user_request(request, token)
yield request

@abc.abstractmethod
def request_new_token(self) -> Union[tuple[str, str], tuple[str, str, int]]:
pass # pragma: no cover

@abc.abstractmethod
def _update_user_request(self, request: httpx.Request, token: str) -> None:
pass # pragma: no cover


class SupportMultiAuth:
"""Inherit from this class to be able to use your class with httpx_auth provided authentication classes."""

Expand Down Expand Up @@ -136,7 +161,7 @@ def __init__(self, kwargs):
)


class OAuth2ResourceOwnerPasswordCredentials(httpx.Auth, SupportMultiAuth):
class OAuth2ResourceOwnerPasswordCredentials(OAuthAuthBase, SupportMultiAuth):
"""
Resource Owner Password Credentials Grant
Expand Down Expand Up @@ -168,6 +193,8 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs):
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as body parameters in the token URL.
"""
super().__init__()

self.token_url = token_url
if not self.token_url:
raise Exception("Token URL is mandatory.")
Expand Down Expand Up @@ -205,16 +232,8 @@ def __init__(self, token_url: str, username: str, password: str, **kwargs):
all_parameters_in_url = _add_parameters(self.token_url, self.data)
self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest()

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
)
def _update_user_request(self, request: httpx.Request, token: str) -> None:
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
Expand All @@ -237,7 +256,7 @@ def _configure_client(self, client: httpx.Client):
client.timeout = self.timeout


class OAuth2ClientCredentials(httpx.Auth, SupportMultiAuth):
class OAuth2ClientCredentials(OAuthAuthBase, SupportMultiAuth):
"""
Client Credentials Grant
Expand Down Expand Up @@ -276,6 +295,8 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
if not self.client_secret:
raise Exception("client_secret is mandatory.")

super().__init__()

self.header_name = kwargs.pop("header_name", None) or "Authorization"
self.header_value = kwargs.pop("header_value", None) or "Bearer {token}"
if "{token}" not in self.header_value:
Expand All @@ -299,16 +320,8 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
all_parameters_in_url = _add_parameters(self.token_url, self.data)
self.state = sha512(all_parameters_in_url.encode("unicode_escape")).hexdigest()

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
)
def _update_user_request(self, request: httpx.Request, token: str) -> None:
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
Expand All @@ -330,7 +343,7 @@ def _configure_client(self, client: httpx.Client):
client.timeout = self.timeout


class OAuth2AuthorizationCode(httpx.Auth, SupportMultiAuth, BrowserAuth):
class OAuth2AuthorizationCode(OAuthAuthBase, SupportMultiAuth, BrowserAuth):
"""
Authorization Code Grant
Expand Down Expand Up @@ -389,7 +402,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
if not self.token_url:
raise Exception("Token URL is mandatory.")

BrowserAuth.__init__(self, kwargs)
super().__init__(kwargs)

self.header_name = kwargs.pop("header_name", None) or "Authorization"
self.header_value = kwargs.pop("header_value", None) or "Bearer {token}"
Expand Down Expand Up @@ -447,16 +460,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
}
self.token_data.update(kwargs)

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
)
def _update_user_request(self, request: httpx.Request, token: str) -> None:
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
# Request code
Expand Down Expand Up @@ -486,7 +491,7 @@ def _configure_client(self, client: httpx.Client):
client.timeout = self.timeout


class OAuth2AuthorizationCodePKCE(httpx.Auth, SupportMultiAuth, BrowserAuth):
class OAuth2AuthorizationCodePKCE(OAuthAuthBase, SupportMultiAuth, BrowserAuth):
"""
Proof Key for Code Exchange
Expand Down Expand Up @@ -543,7 +548,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
if not self.token_url:
raise Exception("Token URL is mandatory.")

BrowserAuth.__init__(self, kwargs)
super().__init__(kwargs)

self.client = kwargs.pop("client", None)

Expand Down Expand Up @@ -612,16 +617,8 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
}
self.token_data.update(kwargs)

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
)
def _update_user_request(self, request: httpx.Request, token: str) -> None:
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple:
# Request code
Expand Down Expand Up @@ -682,7 +679,7 @@ def generate_code_challenge(verifier: bytes) -> bytes:
return base64.urlsafe_b64encode(digest).rstrip(b"=")


class OAuth2Implicit(httpx.Auth, SupportMultiAuth, BrowserAuth):
class OAuth2Implicit(OAuthAuthBase, SupportMultiAuth, BrowserAuth):
"""
Implicit Grant
Expand Down Expand Up @@ -732,7 +729,7 @@ def __init__(self, authorization_url: str, **kwargs):
if not self.authorization_url:
raise Exception("Authorization URL is mandatory.")

BrowserAuth.__init__(self, kwargs)
super().__init__(kwargs)

self.header_name = kwargs.pop("header_name", None) or "Authorization"
self.header_value = kwargs.pop("header_value", None) or "Bearer {token}"
Expand Down Expand Up @@ -778,18 +775,11 @@ def __init__(self, authorization_url: str, **kwargs):
self.redirect_uri_port,
)

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=oauth2_authentication_responses_server.request_new_grant,
grant_details=self.grant_details,
)
def _update_user_request(self, request: httpx.Request, token: str) -> None:
request.headers[self.header_name] = self.header_value.format(token=token)
yield request

def request_new_token(self) -> tuple[str, str]:
return oauth2_authentication_responses_server.request_new_grant(self.grant_details)

class AzureActiveDirectoryImplicit(OAuth2Implicit):
"""
Expand Down

0 comments on commit 2e6efa1

Please sign in to comment.