Skip to content

Commit

Permalink
Breaking: Use httpx generator-based auth-flow protocol in place of in…
Browse files Browse the repository at this point in the history
…ternal Clients.
  • Loading branch information
rafalkrupinski committed Mar 7, 2024
1 parent 164a47f commit 55c9385
Show file tree
Hide file tree
Showing 28 changed files with 247 additions and 230 deletions.
31 changes: 27 additions & 4 deletions httpx_auth/_authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing
from typing import Generator

import httpx
from httpx import Request, Response


class _MultiAuth(httpx.Auth):
Expand All @@ -9,11 +11,32 @@ class _MultiAuth(httpx.Auth):
def __init__(self, *authentication_modes):
self.authentication_modes = authentication_modes

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
def sync_auth_flow(
self, request: Request
) -> typing.Generator[Request, Response, None]:
for authentication_mode in self.authentication_modes:
# auth_flow may yield one or more requests, the last of which is the user request with added auth headers
flow = authentication_mode.sync_auth_flow(request)
req = next(flow)
while True:
if req is request:
break
resp = yield req
req = flow.send(resp)
yield request

async def async_auth_flow(
self, request: Request
) -> typing.AsyncGenerator[Request, Response]:
for authentication_mode in self.authentication_modes:
next(authentication_mode.auth_flow(request))
# auth_flow may yield one or more requests, the last of which is the user request with added auth headers
flow = authentication_mode.async_auth_flow(request)
req = await anext(flow)
while True:
if req is request:
break
resp = yield req
req = await flow.asend(resp)
yield request

def __add__(self, other) -> "_MultiAuth":
Expand Down
63 changes: 25 additions & 38 deletions httpx_auth/_oauth2/authorization_code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Generator
from hashlib import sha512
from typing import Iterable, Union

Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
:param code_field_name: Field name containing the code. code by default.
:param username: Username in case basic authentication should be used to retrieve token.
:param password: User password in case basic authentication should be used to retrieve token.
:param client: httpx.Client instance that will be used to request the token.
:param headers: Additional headers to set when requesting or refreshing token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL and as body parameters in the token URL.
Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
username = kwargs.pop("username", None)
password = kwargs.pop("password", None)
self.auth = (username, password) if username and password else None
self.client = kwargs.pop("client", None)
self.token_headers = kwargs.pop("headers", {})

# As described in https://tools.ietf.org/html/rfc6749#section-4.1.2
code_field_name = kwargs.pop("code_field_name", "code")
Expand Down Expand Up @@ -136,7 +137,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
self.refresh_token,
)

def request_new_token(self) -> tuple:
def request_new_token(
self,
) -> Generator[
httpx.Request, httpx.Response, Union[tuple[str, str], tuple[str, str, int]]
]:
# Request code
state, code = authentication_responses_server.request_new_grant(
self.code_grant_details
Expand All @@ -145,46 +150,30 @@ def request_new_token(self) -> tuple:
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.3
self.token_data["code"] = code

client = self.client or httpx.Client()
self._configure_client(client)
try:
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in, refresh_token = request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, client
)
finally:
# Close client only if it was created by this module
if self.client is None:
client.close()
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in, refresh_token = yield from request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, self.token_headers
)
# Handle both Access and Bearer tokens
return (
(self.state, token, expires_in, refresh_token)
if expires_in
else (self.state, token)
)

def refresh_token(self, refresh_token: str) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
# As described in https://tools.ietf.org/html/rfc6749#section-6
self.refresh_data["refresh_token"] = refresh_token
token, expires_in, refresh_token = request_new_grant_with_post(
self.token_url,
self.refresh_data,
self.token_field_name,
client,
)
finally:
# Close client only if it was created by this module
if self.client is None:
client.close()
def refresh_token(
self, refresh_token: str
) -> Generator[httpx.Request, httpx.Response, tuple[str, str, int, str]]:
# As described in https://tools.ietf.org/html/rfc6749#section-6
self.refresh_data["refresh_token"] = refresh_token
token, expires_in, refresh_token = yield from request_new_grant_with_post(
self.token_url,
self.refresh_data,
self.token_field_name,
self.token_headers,
)
return self.state, token, expires_in, refresh_token

def _configure_client(self, client: httpx.Client):
client.auth = self.auth
client.timeout = self.timeout


class OktaAuthorizationCode(OAuth2AuthorizationCode):
"""
Expand Down Expand Up @@ -220,8 +209,7 @@ def __init__(self, instance: str, client_id: str, **kwargs):
:param header_value: Format used to send the token value.
"{token}" must be present as it will be replaced by the actual token.
Token will be sent as "Bearer {token}" by default.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param headers: Additional headers to set when requesting or refreshing token.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL.
Usual parameters are:
Expand Down Expand Up @@ -276,8 +264,7 @@ def __init__(
:param header_value: Format used to send the token value.
"{token}" must be present as it will be replaced by the actual token.
Token will be sent as "Bearer {token}" by default.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param headers: Additional headers to set when requesting or refreshing token.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL.
"""
Expand Down
60 changes: 25 additions & 35 deletions httpx_auth/_oauth2/authorization_code_pkce.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import os
from collections.abc import Generator
from hashlib import sha256, sha512
from typing import Union

import httpx

Expand Down Expand Up @@ -50,7 +52,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
Default to 30 seconds to ensure token will not expire between the time of retrieval and the time the request
reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry.
:param code_field_name: Field name containing the code. code by default.
:param client: httpx.Client instance that will be used to request the token.
:param headers: Additional headers to set when requesting or refreshing token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL and as body parameters in the token URL.
Expand All @@ -69,7 +71,7 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):

BrowserAuth.__init__(self, kwargs)

self.client = kwargs.pop("client", None)
self.token_headers = kwargs.pop("headers", {})

header_name = kwargs.pop("header_name", None) or "Authorization"
header_value = kwargs.pop("header_value", None) or "Bearer {token}"
Expand Down Expand Up @@ -140,7 +142,11 @@ def __init__(self, authorization_url: str, token_url: str, **kwargs):
self, state, early_expiry, header_name, header_value, self.refresh_token
)

def request_new_token(self) -> tuple:
def request_new_token(
self,
) -> Generator[
httpx.Request, httpx.Response, Union[tuple[str, str, int, str], tuple[str, str]]
]:
# Request code
state, code = authentication_responses_server.request_new_grant(
self.code_grant_details
Expand All @@ -149,45 +155,30 @@ def request_new_token(self) -> tuple:
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.3
self.token_data["code"] = code

client = self.client or httpx.Client()
self._configure_client(client)
try:
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in, refresh_token = request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, client
)
finally:
# Close client only if it was created by this module
if self.client is None:
client.close()
# As described in https://tools.ietf.org/html/rfc6749#section-4.1.4
token, expires_in, refresh_token = yield from request_new_grant_with_post(
self.token_url, self.token_data, self.token_field_name, self.token_headers
)
# Handle both Access and Bearer tokens
return (
(self.state, token, expires_in, refresh_token)
if expires_in
else (self.state, token)
)

def refresh_token(self, refresh_token: str) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
# As described in https://tools.ietf.org/html/rfc6749#section-6
self.refresh_data["refresh_token"] = refresh_token
token, expires_in, refresh_token = request_new_grant_with_post(
self.token_url,
self.refresh_data,
self.token_field_name,
client,
)
finally:
# Close client only if it was created by this module
if self.client is None:
client.close()
def refresh_token(
self, refresh_token: str
) -> Generator[httpx.Request, httpx.Response, tuple[str, str, int, str]]:
# As described in https://tools.ietf.org/html/rfc6749#section-6
self.refresh_data["refresh_token"] = refresh_token
token, expires_in, refresh_token = yield from request_new_grant_with_post(
self.token_url,
self.refresh_data,
self.token_field_name,
self.token_headers,
)
return self.state, token, expires_in, refresh_token

def _configure_client(self, client: httpx.Client):
client.timeout = self.timeout

@staticmethod
def generate_code_verifier() -> bytes:
"""
Expand Down Expand Up @@ -256,8 +247,7 @@ def __init__(self, instance: str, client_id: str, **kwargs):
:param header_value: Format used to send the token value.
"{token}" must be present as it will be replaced by the actual token.
Token will be sent as "Bearer {token}" by default.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param headers: Additional headers to set when requesting or refreshing token.
:param kwargs: all additional authorization parameters that should be put as query parameter
in the authorization URL and as body parameters in the token URL.
Usual parameters are:
Expand Down
26 changes: 7 additions & 19 deletions httpx_auth/_oauth2/client_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
:param early_expiry: Number of seconds before actual token expiry where token will be considered as expired.
Default to 30 seconds to ensure token will not expire between the time of retrieval and the time the request
reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry.
:param client: httpx.Client instance that will be used to request the token.
Use it to provide a custom proxying rule for instance.
:param headers: Additional headers to set when requesting or refreshing token.
:param kwargs: all additional authorization parameters that should be put as query parameter in the token URL.
"""
self.token_url = token_url
Expand All @@ -58,7 +57,7 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
# Time is expressed in seconds
self.timeout = int(kwargs.pop("timeout", None) or 60)

self.client = kwargs.pop("client", None)
self.token_headers = kwargs.pop("headers", {})

# As described in https://tools.ietf.org/html/rfc6749#section-4.4.2
self.data = {"grant_type": "client_credentials"}
Expand All @@ -78,24 +77,13 @@ def __init__(self, token_url: str, client_id: str, client_secret: str, **kwargs)
)

def request_new_token(self) -> tuple:
client = self.client or httpx.Client()
self._configure_client(client)
try:
# As described in https://tools.ietf.org/html/rfc6749#section-4.4.3
token, expires_in, _ = request_new_grant_with_post(
self.token_url, self.data, self.token_field_name, client
)
finally:
# Close client only if it was created by this module
if self.client is None:
client.close()
# As described in https://tools.ietf.org/html/rfc6749#section-4.4.3
token, expires_in, _ = yield from request_new_grant_with_post(
self.token_url, self.data, self.token_field_name, self.token_headers
)
# Handle both Access and Bearer tokens
return (self.state, token, expires_in) if expires_in else (self.state, token)

def _configure_client(self, client: httpx.Client):
client.auth = (self.client_id, self.client_secret)
client.timeout = self.timeout


class OktaClientCredentials(OAuth2ClientCredentials):
"""
Expand Down Expand Up @@ -131,7 +119,7 @@ def __init__(
:param early_expiry: Number of seconds before actual token expiry where token will be considered as expired.
Default to 30 seconds to ensure token will not expire between the time of retrieval and the time the request
reaches the actual server. Set it to 0 to deactivate this feature and use the same token until actual expiry.
:param client: httpx.Client instance that will be used to request the token.
:param headers: Additional headers to set when requesting or refreshing token.
Use it to provide a custom proxying rule for instance.
:param kwargs: all additional authorization parameters that should be put as query parameter in the token URL.
"""
Expand Down
16 changes: 11 additions & 5 deletions httpx_auth/_oauth2/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from collections.abc import Mapping
from typing import Callable, Generator, Optional, Union
from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode

Expand Down Expand Up @@ -69,9 +70,9 @@ def _content_from_response(response: httpx.Response) -> dict:


def request_new_grant_with_post(
url: str, data, grant_name: str, client: httpx.Client
) -> (str, int, str):
response = client.post(url, data=data)
url: str, data, grant_name: str, headers: Mapping[str, str]
) -> Generator[httpx.Request, httpx.Response, tuple[str, int, str]]:
response = yield httpx.Request("post", url, data=data, headers=headers)

if response.is_error:
# As described in https://tools.ietf.org/html/rfc6749#section-5.2
Expand Down Expand Up @@ -106,11 +107,12 @@ def __init__(
self.header_name = header_name
self.header_value = header_value
self.refresh_token = refresh_token
self.requires_response_body = True

def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
token = OAuth2.token_cache.get_token(
token = yield from OAuth2.token_cache.get_token(
self.state,
early_expiry=self.early_expiry,
on_missing_token=self.request_new_token,
Expand All @@ -120,7 +122,11 @@ def auth_flow(
yield request

@abc.abstractmethod
def request_new_token(self) -> Union[tuple[str, str], tuple[str, str, int]]:
def request_new_token(
self,
) -> Generator[
httpx.Request, httpx.Response, Union[tuple[str, str], tuple[str, str, int]]
]:
pass # pragma: no cover

def _update_user_request(self, request: httpx.Request, token: str) -> None:
Expand Down
8 changes: 7 additions & 1 deletion httpx_auth/_oauth2/implicit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from collections.abc import Generator
from hashlib import sha512

import httpx
Expand Down Expand Up @@ -109,7 +110,12 @@ def __init__(self, authorization_url: str, **kwargs):
header_value,
)

def request_new_token(self) -> tuple[str, str]:
def request_new_token(
self,
) -> Generator[httpx.Request, httpx.Response, tuple[str, str]]:
# make this function an empty generator
yield from ()

return authentication_responses_server.request_new_grant(self.grant_details)


Expand Down
Loading

0 comments on commit 55c9385

Please sign in to comment.