diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b48b38..c697fdf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,17 @@ repos: + - repo: local + hooks: + - id: async-generator + name: check async code + entry: python tools/generate_async.py -a --fix + language: python + types: [ python ] + pass_filenames: false + additional_dependencies: + - "ruff==0.7.4" + - "black==24.10.0" + - "docstrfmt==1.9.0" + - "niquests>=3.10,<4" - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: @@ -24,7 +37,7 @@ repos: files: ^(.*\.toml)$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.4 hooks: - id: ruff args: [ --exit-non-zero-on-fix, --fix ] diff --git a/CHANGES.rst b/CHANGES.rst index 0c0fa39..f8f2707 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,16 @@ prawcore follows `semantic versioning `_. Unreleased ---------- +**Added** + +- Asynchronous interfaces from our synchronous ones. Find :class:`AsyncRequestor` for + :class:`Requestor` and :class:`AsyncSession` for :class:`Session`. + +**Changed** + +- Switch HTTP Client Requests for the compatible Niquests. Support for HTTP/2+ + introduced with mirrored sync / async interfaces. + 2.4.0 (2023/10/01) ------------------ diff --git a/README.rst b/README.rst index 640f929..f4d0e46 100644 --- a/README.rst +++ b/README.rst @@ -85,6 +85,32 @@ Save the above as ``trophies.py`` and then execute via: Additional examples can be found at: https://github.com/praw-dev/prawcore/tree/main/examples +Or! with async/await! + +.. code-block:: python + + #!/usr/bin/env python + import os + import pprint + import asyncio + import prawcore + + + async def main(): + authenticator = prawcore.AsyncTrustedAuthenticator( + prawcore.Requestor("YOUR_VALID_USER_AGENT"), + os.environ["PRAWCORE_CLIENT_ID"], + os.environ["PRAWCORE_CLIENT_SECRET"], + ) + authorizer = prawcore.AsyncReadOnlyAuthorizer(authenticator) + await authorizer.refresh() + + async with prawcore.async_session(authorizer) as session: + pprint.pprint(await session.request("GET", "/api/v1/user/bboe/trophies")) + + + asyncio.run(main()) + Depending on prawcore --------------------- diff --git a/examples/caching_requestor.py b/examples/caching_requestor.py index 6d0dc84..c483d3d 100755 --- a/examples/caching_requestor.py +++ b/examples/caching_requestor.py @@ -10,12 +10,12 @@ import os import sys -import requests +import niquests import prawcore -class CachingSession(requests.Session): +class CachingSession(niquests.Session): """Cache GETs in memory. Toy example of custom session to showcase the ``session`` parameter of diff --git a/prawcore/__init__.py b/prawcore/__init__.py index 65fce52..1d53208 100644 --- a/prawcore/__init__.py +++ b/prawcore/__init__.py @@ -2,6 +2,18 @@ import logging +from ._async.auth import ( + AsyncAuthorizer, + AsyncDeviceIDAuthorizer, + AsyncImplicitAuthorizer, + AsyncReadOnlyAuthorizer, + AsyncScriptAuthorizer, + AsyncTrustedAuthenticator, + AsyncUntrustedAuthenticator, +) +from ._async.requestor import AsyncRequestor +from ._async.sessions import AsyncSession +from ._async.sessions import session as async_session from .auth import ( Authorizer, DeviceIDAuthorizer, diff --git a/prawcore/_async/__init__.py b/prawcore/_async/__init__.py new file mode 100644 index 0000000..0f2b88c --- /dev/null +++ b/prawcore/_async/__init__.py @@ -0,0 +1,2 @@ +# this part should be autogenerated +# mirror of our synchronous part. diff --git a/prawcore/_async/auth.py b/prawcore/_async/auth.py new file mode 100644 index 0000000..b736b1f --- /dev/null +++ b/prawcore/_async/auth.py @@ -0,0 +1,477 @@ +"""Provides Authentication and Authorization classes.""" + +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Awaitable, Callable + +from niquests import Request +from niquests.status_codes import codes + +from .. import const +from ..exceptions import InvalidInvocation, OAuthException, ResponseException + +if TYPE_CHECKING: + from niquests.models import Response + + from .requestor import AsyncRequestor + + +class AsyncBaseAuthenticator(ABC): + """Provide the base authenticator object that stores OAuth2 credentials.""" + + @abstractmethod + def _auth(self) -> tuple[str, str]: + pass + + def __init__( + self, + requestor: AsyncRequestor, + client_id: str, + redirect_uri: str | None = None, + ): + """Represent a single authentication to Reddit's API. + + :param requestor: An instance of :class:`.AsyncRequestor`. + :param client_id: The OAuth2 client ID to use with the session. + :param redirect_uri: The redirect URI exactly as specified in your OAuth + application settings on Reddit. This parameter is required if you want to + use the :meth:`~.Authorizer.authorize_url` method, or the + :meth:`~.Authorizer.authorize` method of the :class:`.AsyncAuthorizer` class + (default: ``None``). + + """ + self._requestor = requestor + self.client_id = client_id + self.redirect_uri = redirect_uri + + async def _post( + self, url: str, success_status: int = codes["ok"], **data: Any + ) -> Response: + response = await self._requestor.request( + "post", + url, + auth=self._auth(), + data=sorted(data.items()), + headers={"Connection": "close"}, + ) + if response.status_code != success_status: + raise ResponseException(response) + return response + + def authorize_url( + self, duration: str, scopes: list[str], state: str, implicit: bool = False + ) -> str: + """Return the URL used out-of-band to grant access to your application. + + :param duration: Either ``"permanent"`` or ``"temporary"``. ``"temporary"`` + authorizations generate access tokens that last only 1 hour. ``"permanent"`` + authorizations additionally generate a refresh token that can be + indefinitely used to generate new hour-long access tokens. Only + ``"temporary"`` can be specified if ``implicit`` is set to ``True``. + :param scopes: A list of OAuth scopes to request authorization for. + :param state: A string that will be reflected in the callback to + ``redirect_uri``. Elements must be printable ASCII characters in the range + ``0x20`` through ``0x7E`` inclusive. This value should be temporarily unique + to the client for whom the URL was generated. + :param implicit: Use the implicit grant flow (default: ``False``). This flow is + only available for ``UntrustedAuthenticators``. + + :returns: URL to be used out-of-band for granting access to your application. + + :raises: :class:`.InvalidInvocation` if ``redirect_uri`` is not provided, if + ``implicit`` is ``True`` and an authenticator other than + :class:`.AsyncUntrustedAuthenticator` is used, or ``implicit`` is ``True`` + and ``duration`` is ``"permanent"``. + + """ + if self.redirect_uri is None: + msg = "redirect URI not provided" + raise InvalidInvocation(msg) + if implicit and not isinstance(self, AsyncUntrustedAuthenticator): + msg = "Only AsyncUntrustedAuthenticator instances can use the implicit grant flow." + raise InvalidInvocation(msg) + if implicit and duration != "temporary": + msg = "The implicit grant flow only supports temporary access tokens." + raise InvalidInvocation(msg) + + params = { + "client_id": self.client_id, + "duration": duration, + "redirect_uri": self.redirect_uri, + "response_type": "token" if implicit else "code", + "scope": " ".join(scopes), + "state": state, + } + url = self._requestor.reddit_url + const.AUTHORIZATION_PATH + request = Request("GET", url, params=params) + return request.prepare().url + + async def revoke_token(self, token: str, token_type: str | None = None): + """Ask Reddit to revoke the provided token. + + :param token: The access or refresh token to revoke. + :param token_type: When provided, hint to Reddit what the token type is for a + possible efficiency gain. The value can be either ``"access_token"`` or + ``"refresh_token"``. + + """ + data = {"token": token} + if token_type is not None: + data["token_type_hint"] = token_type + url = self._requestor.reddit_url + const.REVOKE_TOKEN_PATH + await self._post(url, **data) + + +class AsyncBaseAuthorizer: + """Superclass for OAuth2 authorization tokens and scopes.""" + + AUTHENTICATOR_CLASS: tuple | type = AsyncBaseAuthenticator + + def __init__(self, authenticator: AsyncBaseAuthenticator): + """Represent a single authorization to Reddit's API. + + :param authenticator: An instance of :class:`.AsyncBaseAuthenticator`. + + """ + self._authenticator = authenticator + self._clear_access_token() + self._validate_authenticator() + + def _clear_access_token(self): + self._expiration_timestamp: float + self.access_token: str | None = None + self.scopes: set[str] | None = None + + async def _request_token(self, **data: Any): + url = self._authenticator._requestor.reddit_url + const.ACCESS_TOKEN_PATH + pre_request_time = time.time() + response = await self._authenticator._post(url=url, **data) + payload = response.json() + if "error" in payload: # Why are these OKAY responses? + raise OAuthException( + response, payload["error"], payload.get("error_description") + ) + + self._expiration_timestamp = pre_request_time - 10 + payload["expires_in"] + self.access_token = payload["access_token"] + if "refresh_token" in payload: + self.refresh_token = payload["refresh_token"] + self.scopes = set(payload["scope"].split(" ")) + + def _validate_authenticator(self): + if not isinstance(self._authenticator, self.AUTHENTICATOR_CLASS): + msg = "Must use an authenticator of type" + if isinstance(self.AUTHENTICATOR_CLASS, type): + msg += f" {self.AUTHENTICATOR_CLASS.__name__}." + else: + msg += ( + f" {' or '.join([i.__name__ for i in self.AUTHENTICATOR_CLASS])}." + ) + raise InvalidInvocation(msg) + + def is_valid(self) -> bool: + """Return whether the :class`.Authorizer` is ready to authorize requests. + + A ``True`` return value does not guarantee that the ``access_token`` is actually + valid on the server side. + + """ + return ( + self.access_token is not None and time.time() < self._expiration_timestamp + ) + + async def revoke(self): + """Revoke the current Authorization.""" + if self.access_token is None: + msg = "no token available to revoke" + raise InvalidInvocation(msg) + + await self._authenticator.revoke_token(self.access_token, "access_token") + self._clear_access_token() + + +class AsyncTrustedAuthenticator(AsyncBaseAuthenticator): + """Store OAuth2 authentication credentials for web, or script type apps.""" + + RESPONSE_TYPE: str = "code" + + def __init__( + self, + requestor: AsyncRequestor, + client_id: str, + client_secret: str, + redirect_uri: str | None = None, + ): + """Represent a single authentication to Reddit's API. + + :param requestor: An instance of :class:`.AsyncRequestor`. + :param client_id: The OAuth2 client ID to use with the session. + :param client_secret: The OAuth2 client secret to use with the session. + :param redirect_uri: The redirect URI exactly as specified in your OAuth + application settings on Reddit. This parameter is required if you want to + use the :meth:`~.Authorizer.authorize_url` method, or the + :meth:`~.Authorizer.authorize` method of the :class:`.AsyncAuthorizer` class + (default: ``None``). + + """ + super().__init__(requestor, client_id, redirect_uri) + self.client_secret = client_secret + + def _auth(self) -> tuple[str, str]: + return self.client_id, self.client_secret + + +class AsyncUntrustedAuthenticator(AsyncBaseAuthenticator): + """Store OAuth2 authentication credentials for installed applications.""" + + def _auth(self) -> tuple[str, str]: + return self.client_id, "" + + +class AsyncAuthorizer(AsyncBaseAuthorizer): + """Manages OAuth2 authorization tokens and scopes.""" + + def __init__( + self, + authenticator: AsyncBaseAuthenticator, + *, + post_refresh_callback: ( + Callable[[AsyncAuthorizer], Awaitable[None]] | None + ) = None, + pre_refresh_callback: ( + Callable[[AsyncAuthorizer], Awaitable[None]] | None + ) = None, + refresh_token: str | None = None, + ): + """Represent a single authorization to Reddit's API. + + :param authenticator: An instance of a subclass of + :class:`.AsyncBaseAuthenticator`. + :param post_refresh_callback: When a single-argument function is passed, the + function will be called prior to refreshing the access and refresh tokens. + The argument to the callback is the :class:`.AsyncAuthorizer` instance. This + callback can be used to inspect and modify the attributes of the + :class:`.AsyncAuthorizer`. + :param pre_refresh_callback: When a single-argument function is passed, the + function will be called after refreshing the access and refresh tokens. The + argument to the callback is the :class:`.AsyncAuthorizer` instance. This + callback can be used to inspect and modify the attributes of the + :class:`.AsyncAuthorizer`. + :param refresh_token: Enables the ability to refresh the authorization. + + """ + super().__init__(authenticator) + self._post_refresh_callback = post_refresh_callback + self._pre_refresh_callback = pre_refresh_callback + self.refresh_token = refresh_token + + async def authorize(self, code: str): + """Obtain and set authorization tokens based on ``code``. + + :param code: The code obtained by an out-of-band authorization request to + Reddit. + + """ + if self._authenticator.redirect_uri is None: + msg = "redirect URI not provided" + raise InvalidInvocation(msg) + await self._request_token( + code=code, + grant_type="authorization_code", + redirect_uri=self._authenticator.redirect_uri, + ) + + async def refresh(self): + """Obtain a new access token from the refresh_token.""" + if self._pre_refresh_callback: + await self._pre_refresh_callback(self) + if self.refresh_token is None: + msg = "refresh token not provided" + raise InvalidInvocation(msg) + await self._request_token( + grant_type="refresh_token", refresh_token=self.refresh_token + ) + if self._post_refresh_callback: + await self._post_refresh_callback(self) + + async def revoke(self, only_access: bool = False): + """Revoke the current Authorization. + + :param only_access: When explicitly set to ``True``, do not evict the refresh + token if one is set. + + Revoking a refresh token will in-turn revoke all access tokens associated with + that authorization. + + """ + if only_access or self.refresh_token is None: + await super().revoke() + else: + await self._authenticator.revoke_token(self.refresh_token, "refresh_token") + self._clear_access_token() + self.refresh_token = None + + +class AsyncImplicitAuthorizer(AsyncBaseAuthorizer): + """Manages implicit installed-app type authorizations.""" + + AUTHENTICATOR_CLASS = AsyncUntrustedAuthenticator + + def __init__( + self, + authenticator: AsyncUntrustedAuthenticator, + access_token: str, + expires_in: int, + scope: str, + ): + """Represent a single implicit authorization to Reddit's API. + + :param authenticator: An instance of :class:`.AsyncUntrustedAuthenticator`. + :param access_token: The access_token obtained from Reddit via callback to the + authenticator's ``redirect_uri``. + :param expires_in: The number of seconds the ``access_token`` is valid for. The + origin of this value was returned from Reddit via callback to the + authenticator's redirect uri. Note, you may need to subtract an offset + before passing in this number to account for a delay between when Reddit + prepared the response, and when you make this function call. + :param scope: A space-delimited string of Reddit OAuth2 scope names as returned + from Reddit in the callback to the authenticator's redirect uri. + + """ + super().__init__(authenticator) + self._expiration_timestamp = time.time() + expires_in + self.access_token = access_token + self.scopes = set(scope.split(" ")) + + +class AsyncReadOnlyAuthorizer(AsyncAuthorizer): + """Manages authorizations that are not associated with a Reddit account. + + While the ``"*"`` scope will be available, some endpoints simply will not work due + to the lack of an associated Reddit account. + + """ + + AUTHENTICATOR_CLASS = AsyncTrustedAuthenticator + + def __init__( + self, + authenticator: AsyncBaseAuthenticator, + scopes: list[str] | None = None, + ): + """Represent a ReadOnly authorization to Reddit's API. + + :param scopes: A list of OAuth scopes to request authorization for (default: + ``None``). The scope ``"*"`` is requested when the default argument is used. + + """ + super().__init__(authenticator) + self._scopes = scopes + + async def refresh(self): + """Obtain a new ReadOnly access token.""" + additional_kwargs = {} + if self._scopes: + additional_kwargs["scope"] = " ".join(self._scopes) + await self._request_token(grant_type="client_credentials", **additional_kwargs) + + +class AsyncScriptAuthorizer(AsyncAuthorizer): + """Manages personal-use script type authorizations. + + Only users who are listed as developers for the application will be granted access + tokens. + + """ + + AUTHENTICATOR_CLASS = AsyncTrustedAuthenticator + + def __init__( + self, + authenticator: AsyncBaseAuthenticator, + username: str | None, + password: str | None, + two_factor_callback: Callable | None = None, + scopes: list[str] | None = None, + ): + """Represent a single personal-use authorization to Reddit's API. + + :param authenticator: An instance of :class:`.AsyncTrustedAuthenticator`. + :param username: The Reddit username of one of the application's developers. + :param password: The password associated with ``username``. + :param two_factor_callback: A function that returns OTPs (One-Time Passcodes), + also known as 2FA auth codes. If this function is provided, prawcore will + call it when authenticating. + :param scopes: A list of OAuth scopes to request authorization for (default: + ``None``). The scope ``"*"`` is requested when the default argument is used. + + """ + super().__init__(authenticator) + self._password = password + self._scopes = scopes + self._two_factor_callback = two_factor_callback + self._username = username + + async def refresh(self): + """Obtain a new personal-use script type access token.""" + additional_kwargs = {} + if self._scopes: + additional_kwargs["scope"] = " ".join(self._scopes) + two_factor_code = self._two_factor_callback and self._two_factor_callback() + if two_factor_code: + additional_kwargs["otp"] = two_factor_code + await self._request_token( + grant_type="password", + username=self._username, + password=self._password, + **additional_kwargs, + ) + + +class AsyncDeviceIDAuthorizer(AsyncBaseAuthorizer): + """Manages app-only OAuth2 for 'installed' applications. + + While the ``"*"`` scope will be available, some endpoints simply will not work due + to the lack of an associated Reddit account. + + """ + + AUTHENTICATOR_CLASS = (AsyncTrustedAuthenticator, AsyncUntrustedAuthenticator) + + def __init__( + self, + authenticator: AsyncBaseAuthenticator, + device_id: str | None = None, + scopes: list[str] | None = None, + ): + """Represent an app-only OAuth2 authorization for 'installed' apps. + + :param authenticator: An instance of :class:`.AsyncUntrustedAuthenticator` or + :class:`.AsyncTrustedAuthenticator`. + :param device_id: A unique ID (20-30 character ASCII string) (default: + ``None``). ``device_id`` is set to ``"DO_NOT_TRACK_THIS_DEVICE"`` when the + default argument is used. For more information about this parameter, see: + https://github.com/reddit/reddit/wiki/OAuth2#application-only-oauth + :param scopes: A list of OAuth scopes to request authorization for (default: + ``None``). The scope ``"*"`` is requested when the default argument is used. + + """ + if device_id is None: + device_id = "DO_NOT_TRACK_THIS_DEVICE" + super().__init__(authenticator) + self._device_id = device_id + self._scopes = scopes + + async def refresh(self): + """Obtain a new access token.""" + additional_kwargs = {} + if self._scopes: + additional_kwargs["scope"] = " ".join(self._scopes) + grant_type = "https://oauth.reddit.com/grants/installed_client" + await self._request_token( + grant_type=grant_type, + device_id=self._device_id, + **additional_kwargs, + ) diff --git a/prawcore/_async/rate_limit.py b/prawcore/_async/rate_limit.py new file mode 100644 index 0000000..762a2b2 --- /dev/null +++ b/prawcore/_async/rate_limit.py @@ -0,0 +1,105 @@ +"""Provide the AsyncRateLimiter class.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Mapping + +if TYPE_CHECKING: + from niquests.models import Response + +log = logging.getLogger(__package__) + + +class AsyncRateLimiter: + """Facilitates the rate limiting of requests to Reddit. + + Rate limits are controlled based on feedback from requests to Reddit. + + """ + + def __init__(self, *, window_size: int): + """Create an instance of the RateLimit class.""" + self.remaining: float | None = None + self.next_request_timestamp: float | None = None + self.reset_timestamp: float | None = None + self.used: int | None = None + self.window_size: int = window_size + + async def call( + self, + request_function: Callable[[Any], Awaitable[Response]], + set_header_callback: Callable[[], Awaitable[dict[str, str]]], + *args: Any, + **kwargs: Any, + ) -> Response: + """Rate limit the call to ``request_function``. + + :param request_function: A function call that returns an HTTP response object. + :param set_header_callback: A callback function used to set the request headers. + This callback is called after any necessary sleep time occurs. + :param args: The positional arguments to ``request_function``. + :param kwargs: The keyword arguments to ``request_function``. + + """ + await self.delay() + kwargs["headers"] = await set_header_callback() + response = await request_function(*args, **kwargs) + self.update(response.headers) + return response + + async def delay(self): + """Sleep for an amount of time to remain under the rate limit.""" + if self.next_request_timestamp is None: + return + sleep_seconds = self.next_request_timestamp - time.time() + if sleep_seconds <= 0: + return + message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to call" + log.debug(message) + await asyncio.sleep(sleep_seconds) + + def update(self, response_headers: Mapping[str, str]): + """Update the state of the rate limiter based on the response headers. + + This method should only be called following an HTTP request to Reddit. + + Response headers that do not contain ``x-ratelimit`` fields will be treated as a + single request. This behavior is to error on the safe-side as such responses + should trigger exceptions that indicate invalid behavior. + + """ + if "x-ratelimit-remaining" not in response_headers: + if self.remaining is not None: + self.remaining -= 1 + self.used += 1 + return + + now = time.time() + + seconds_to_reset = int(response_headers["x-ratelimit-reset"]) + self.remaining = float(response_headers["x-ratelimit-remaining"]) + self.used = int(response_headers["x-ratelimit-used"]) + self.reset_timestamp = now + seconds_to_reset + + if self.remaining <= 0: + self.next_request_timestamp = self.reset_timestamp + return + + self.next_request_timestamp = min( + self.reset_timestamp, + now + + min( + max( + seconds_to_reset + - ( + self.window_size + - (self.window_size / (self.remaining + self.used) * self.used) + ), + 0, + ), + 10, + ), + ) diff --git a/prawcore/_async/requestor.py b/prawcore/_async/requestor.py new file mode 100644 index 0000000..21543d8 --- /dev/null +++ b/prawcore/_async/requestor.py @@ -0,0 +1,74 @@ +"""Provides the HTTP request handling interface.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import niquests + +from ..const import TIMEOUT +from ..exceptions import InvalidInvocation, RequestException + +if TYPE_CHECKING: + from niquests import AsyncSession, Response + + +class AsyncRequestor: + """Requestor provides an interface to HTTP requests.""" + + def __getattr__(self, attribute: str) -> Any: + """Pass all undefined attributes to the ``_http`` attribute.""" + if attribute.startswith("__"): + raise AttributeError + return getattr(self._http, attribute) + + def __init__( + self, + user_agent: str, + oauth_url: str = "https://oauth.reddit.com", + reddit_url: str = "https://www.reddit.com", + session: AsyncSession | None = None, + timeout: float = TIMEOUT, + ): + """Create an instance of the AsyncRequestor class. + + :param user_agent: The user-agent for your application. Please follow Reddit's + user-agent guidelines: https://github.com/reddit/reddit/wiki/API#rules + :param oauth_url: The URL used to make OAuth requests to the Reddit site + (default: ``"https://oauth.reddit.com"``). + :param reddit_url: The URL used when obtaining access tokens (default: + ``"https://www.reddit.com"``). + :param session: A session instance to handle requests, compatible with + ``niquests.AsyncSession()`` (default: ``None``). + :param timeout: How many seconds to wait for the server to send data before + giving up (default: ``prawcore.const.TIMEOUT``). + + """ + # Imported locally to avoid an import cycle, with __init__ + from .. import __version__ + + if user_agent is None or len(user_agent) < 7: + msg = "user_agent is not descriptive" + raise InvalidInvocation(msg) + + self._http = session or niquests.AsyncSession() + self._http.headers["User-Agent"] = f"{user_agent} prawcore/{__version__}" + + self.oauth_url = oauth_url + self.reddit_url = reddit_url + self.timeout = timeout + + async def close(self): + """Call close on the underlying session.""" + await self._http.close() + + async def request( + self, *args: Any, timeout: float | None = None, **kwargs: Any + ) -> Response: + """Issue the HTTP request capturing any errors that may occur.""" + try: + return await self._http.request( + *args, timeout=timeout or self.timeout, **kwargs + ) + except Exception as exc: # noqa: BLE001 + raise RequestException(exc, args, kwargs) from exc diff --git a/prawcore/_async/sessions.py b/prawcore/_async/sessions.py new file mode 100644 index 0000000..749c207 --- /dev/null +++ b/prawcore/_async/sessions.py @@ -0,0 +1,375 @@ +"""prawcore.sessions: Provides prawcore.Session and prawcore.session.""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from abc import ABC, abstractmethod +from copy import deepcopy +from pprint import pformat +from typing import TYPE_CHECKING, Any, BinaryIO, TextIO +from urllib.parse import urljoin + +from niquests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout +from niquests.status_codes import codes + +from ..const import TIMEOUT, WINDOW_SIZE +from ..exceptions import ( + BadJSON, + BadRequest, + Conflict, + InvalidInvocation, + NotFound, + Redirect, + RequestException, + ServerError, + SpecialError, + TooLarge, + TooManyRequests, + UnavailableForLegalReasons, + URITooLong, +) +from ..util import authorization_error_class +from .auth import AsyncBaseAuthorizer +from .rate_limit import AsyncRateLimiter + +if TYPE_CHECKING: + from niquests.models import Response + + from .auth import AsyncAuthorizer + from .requestor import AsyncRequestor + +log = logging.getLogger(__package__) + + +class AsyncRetryStrategy(ABC): + """An abstract class for scheduling request retries. + + The strategy controls both the number and frequency of retry attempts. + + Instances of this class are immutable. + + """ + + @abstractmethod + def _sleep_seconds(self) -> float | None: + pass + + async def sleep(self): + """Sleep until we are ready to attempt the request.""" + sleep_seconds = self._sleep_seconds() + if sleep_seconds is not None: + message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to retry" + log.debug(message) + await asyncio.sleep(sleep_seconds) + + +class AsyncSession: + """The low-level connection interface to Reddit's API.""" + + RETRY_EXCEPTIONS = (ChunkedEncodingError, ConnectionError, ReadTimeout) + RETRY_STATUSES = { + 520, + 522, + codes["bad_gateway"], + codes["gateway_timeout"], + codes["internal_server_error"], + codes["request_timeout"], + codes["service_unavailable"], + } + STATUS_EXCEPTIONS = { + codes["bad_gateway"]: ServerError, + codes["bad_request"]: BadRequest, + codes["conflict"]: Conflict, + codes["found"]: Redirect, + codes["forbidden"]: authorization_error_class, + codes["gateway_timeout"]: ServerError, + codes["internal_server_error"]: ServerError, + codes["media_type"]: SpecialError, + codes["moved_permanently"]: Redirect, + codes["not_found"]: NotFound, + codes["request_entity_too_large"]: TooLarge, + codes["request_uri_too_large"]: URITooLong, + codes["service_unavailable"]: ServerError, + codes["too_many_requests"]: TooManyRequests, + codes["unauthorized"]: authorization_error_class, + codes[ + "unavailable_for_legal_reasons" + ]: UnavailableForLegalReasons, # Cloudflare's status (not named in niquests) + 520: ServerError, + 522: ServerError, + } + SUCCESS_STATUSES = {codes["accepted"], codes["created"], codes["ok"]} + + @staticmethod + def _log_request( + data: list[tuple[str, str]] | None, + method: str, + params: dict[str, int], + url: str, + ): + log.debug("Fetching: %s %s at %s", method, url, time.time()) + log.debug("Data: %s", pformat(data)) + log.debug("Params: %s", pformat(params)) + + @property + def _requestor(self) -> AsyncRequestor: + return self._authorizer._authenticator._requestor + + async def __aenter__(self) -> AsyncSession: # noqa: PYI034 + """Allow this object to be used as a context manager.""" + return self + + async def __aexit__(self, *_args): + """Allow this object to be used as a context manager.""" + await self.close() + + def __init__( + self, + authorizer: AsyncBaseAuthorizer | None, + window_size: int = WINDOW_SIZE, + ): + """Prepare the connection to Reddit's API. + + :param authorizer: An instance of :class:`.AsyncAuthorizer`. + :param window_size: The size of the rate limit reset window in seconds. + + """ + if not isinstance(authorizer, AsyncBaseAuthorizer): + msg = f"invalid AsyncAuthorizer: {authorizer}" + raise InvalidInvocation(msg) + self._authorizer = authorizer + self._rate_limiter = AsyncRateLimiter(window_size=window_size) + self._retry_strategy_class = AsyncFiniteRetryStrategy + + async def _do_retry( + self, + data: list[tuple[str, Any]], + files: dict[str, BinaryIO | TextIO], + json: dict[str, Any], + method: str, + params: dict[str, int], + response: Response | None, + retry_strategy_state: AsyncFiniteRetryStrategy, + saved_exception: Exception | None, + timeout: float, + url: str, + ) -> dict[str, Any] | str | None: + status = repr(saved_exception) if saved_exception else response.status_code + log.warning("Retrying due to %s status: %s %s", status, method, url) + return await self._request_with_retries( + data=data, + files=files, + json=json, + method=method, + params=params, + timeout=timeout, + url=url, + retry_strategy_state=retry_strategy_state.consume_available_retry(), + # noqa: E501 + ) + + async def _make_request( + self, + data: list[tuple[str, Any]], + files: dict[str, BinaryIO | TextIO], + json: dict[str, Any], + method: str, + params: dict[str, Any], + retry_strategy_state: AsyncFiniteRetryStrategy, + timeout: float, + url: str, + ) -> tuple[Response, None] | tuple[None, Exception]: + try: + response = await self._rate_limiter.call( + self._requestor.request, + self._set_header_callback, + method, + url, + allow_redirects=False, + data=data, + files=files, + json=json, + params=params, + timeout=timeout, + ) + log.debug( + "Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s", + response.status_code, + response.headers.get("content-length"), + response.headers.get("x-ratelimit-reset"), + response.headers.get("x-ratelimit-remaining"), + response.headers.get("x-ratelimit-used"), + time.time(), + ) + return response, None + except RequestException as exception: + if ( + not retry_strategy_state.should_retry_on_failure() + or not isinstance( # noqa: E501 + exception.original_exception, self.RETRY_EXCEPTIONS + ) + ): + raise + return None, exception.original_exception + + async def _request_with_retries( + self, + data: list[tuple[str, Any]], + files: dict[str, BinaryIO | TextIO], + json: dict[str, Any], + method: str, + params: dict[str, Any], + timeout: float, + url: str, + retry_strategy_state: AsyncFiniteRetryStrategy | None = None, + ) -> dict[str, Any] | str | None: + if retry_strategy_state is None: + retry_strategy_state = self._retry_strategy_class() + + await retry_strategy_state.sleep() + self._log_request(data, method, params, url) + response, saved_exception = await self._make_request( + data, + files, + json, + method, + params, + retry_strategy_state, + timeout, + url, + ) + + do_retry = False + if response is not None and response.status_code == codes["unauthorized"]: + self._authorizer._clear_access_token() + if hasattr(self._authorizer, "refresh"): + do_retry = True + + if retry_strategy_state.should_retry_on_failure() and ( + do_retry or response is None or response.status_code in self.RETRY_STATUSES + ): + return await self._do_retry( + data, + files, + json, + method, + params, + response, + retry_strategy_state, + saved_exception, + timeout, + url, + ) + if response.status_code in self.STATUS_EXCEPTIONS: + raise self.STATUS_EXCEPTIONS[response.status_code](response) + if response.status_code == codes["no_content"]: + return None + assert ( + response.status_code in self.SUCCESS_STATUSES + ), f"Unexpected status code: {response.status_code}" + if response.headers.get("content-length") == "0": + return "" + try: + return response.json() + except ValueError: + raise BadJSON(response) from None + + async def _set_header_callback(self) -> dict[str, str]: + if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"): + await self._authorizer.refresh() + return {"Authorization": f"bearer {self._authorizer.access_token}"} + + async def close(self): + """Close the session and perform any clean up.""" + await self._requestor.close() + + async def request( + self, + method: str, + path: str, + data: dict[str, Any] | None = None, + files: dict[str, BinaryIO | TextIO] | None = None, + json: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: float = TIMEOUT, + ) -> dict[str, Any] | str | None: + """Return the json content from the resource at ``path``. + + :param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``. + :param path: The path of the request. This path will be combined with the + ``oauth_url`` of the AsyncRequestor. + :param data: Dictionary, bytes, or file-like object to send in the body of the + request. + :param files: Dictionary, mapping ``filename`` to file-like object. + :param json: Object to be serialized to JSON in the body of the request. + :param params: The query parameters to send with the request. + :param timeout: Specifies a particular timeout, in seconds. + + Automatically refreshes the access token if it becomes invalid and a refresh + token is available. + + :raises: :class:`.InvalidInvocation` in such a case if a refresh token is not + available. + + """ + params = deepcopy(params) or {} + params["raw_json"] = 1 + if isinstance(data, dict): + data = deepcopy(data) + data["api_type"] = "json" + data = sorted(data.items()) + if isinstance(json, dict): + json = deepcopy(json) + json["api_type"] = "json" + url = urljoin(self._requestor.oauth_url, path) + return await self._request_with_retries( + data=data, + files=files, + json=json, + method=method, + params=params, + timeout=timeout, + url=url, + ) + + +def session( + authorizer: AsyncAuthorizer = None, + window_size: int = WINDOW_SIZE, +) -> AsyncSession: + """Return a :class:`.AsyncSession` instance. + + :param authorizer: An instance of :class:`.AsyncAuthorizer`. + :param window_size: The size of the rate limit reset window in seconds. + + """ + return AsyncSession(authorizer=authorizer, window_size=window_size) + + +class AsyncFiniteRetryStrategy(AsyncRetryStrategy): + """A ``RetryStrategy`` that retries requests a finite number of times.""" + + def __init__(self, retries: int = 3): + """Initialize the strategy. + + :param retries: Number of times to attempt a request (default: ``3``). + + """ + self._retries = retries + + def _sleep_seconds(self) -> float | None: + if self._retries < 3: + base = 0 if self._retries == 2 else 2 + return base + 2 * random.random() # noqa: S311 + return None + + def consume_available_retry(self) -> AsyncFiniteRetryStrategy: + """Allow one fewer retry.""" + return type(self)(self._retries - 1) + + def should_retry_on_failure(self) -> bool: + """Return ``True`` if and only if the strategy will allow another retry.""" + return self._retries > 1 diff --git a/prawcore/auth.py b/prawcore/auth.py index 2c977bc..7556d41 100644 --- a/prawcore/auth.py +++ b/prawcore/auth.py @@ -6,14 +6,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable -from requests import Request -from requests.status_codes import codes +from niquests import Request +from niquests.status_codes import codes from . import const from .exceptions import InvalidInvocation, OAuthException, ResponseException if TYPE_CHECKING: - from requests.models import Response + from niquests.models import Response from prawcore.requestor import Requestor diff --git a/prawcore/exceptions.py b/prawcore/exceptions.py index f29e153..3eb8acf 100644 --- a/prawcore/exceptions.py +++ b/prawcore/exceptions.py @@ -6,7 +6,7 @@ from urllib.parse import urlparse if TYPE_CHECKING: - from requests.models import Response + from niquests.models import Response class PrawcoreException(Exception): # noqa: N818 @@ -23,7 +23,7 @@ class OAuthException(PrawcoreException): def __init__(self, response: Response, error: str, description: str | None = None): """Initialize a OAuthException instance. - :param response: A ``requests.response`` instance. + :param response: A ``niquests.response`` instance. :param error: The error type returned by Reddit. :param description: A description of the error when provided. @@ -67,7 +67,7 @@ class ResponseException(PrawcoreException): def __init__(self, response: Response): """Initialize a ResponseException instance. - :param response: A ``requests.response`` instance. + :param response: A ``niquests.response`` instance. """ self.response = response @@ -113,7 +113,7 @@ class Redirect(ResponseException): def __init__(self, response: Response): """Initialize a Redirect exception instance. - :param response: A ``requests.response`` instance containing a location header. + :param response: A ``niquests.response`` instance containing a location header. """ path = urlparse(response.headers["location"]).path @@ -139,7 +139,7 @@ class SpecialError(ResponseException): def __init__(self, response: Response): """Initialize a SpecialError exception instance. - :param response: A ``requests.response`` instance containing a message and a + :param response: A ``niquests.response`` instance containing a message and a list of special errors. """ @@ -162,7 +162,7 @@ class TooManyRequests(ResponseException): def __init__(self, response: Response): """Initialize a TooManyRequests exception instance. - :param response: A ``requests.response`` instance that may contain a retry-after + :param response: A ``niquests.response`` instance that may contain a retry-after header and a message. """ diff --git a/prawcore/rate_limit.py b/prawcore/rate_limit.py index 4ddbb08..1cb55df 100644 --- a/prawcore/rate_limit.py +++ b/prawcore/rate_limit.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping if TYPE_CHECKING: - from requests.models import Response + from niquests.models import Response log = logging.getLogger(__package__) diff --git a/prawcore/requestor.py b/prawcore/requestor.py index eaf6269..673c843 100644 --- a/prawcore/requestor.py +++ b/prawcore/requestor.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Any -import requests +import niquests from .const import TIMEOUT from .exceptions import InvalidInvocation, RequestException if TYPE_CHECKING: - from requests.models import Response, Session + from niquests import Response, Session class Requestor: @@ -39,7 +39,7 @@ def __init__( :param reddit_url: The URL used when obtaining access tokens (default: ``"https://www.reddit.com"``). :param session: A session instance to handle requests, compatible with - ``requests.Session()`` (default: ``None``). + ``niquests.Session()`` (default: ``None``). :param timeout: How many seconds to wait for the server to send data before giving up (default: ``prawcore.const.TIMEOUT``). @@ -51,7 +51,7 @@ def __init__( msg = "user_agent is not descriptive" raise InvalidInvocation(msg) - self._http = session or requests.Session() + self._http = session or niquests.Session() self._http.headers["User-Agent"] = f"{user_agent} prawcore/{__version__}" self.oauth_url = oauth_url @@ -69,4 +69,4 @@ def request( try: return self._http.request(*args, timeout=timeout or self.timeout, **kwargs) except Exception as exc: # noqa: BLE001 - raise RequestException(exc, args, kwargs) from None + raise RequestException(exc, args, kwargs) from exc diff --git a/prawcore/sessions.py b/prawcore/sessions.py index 480624f..2b85807 100644 --- a/prawcore/sessions.py +++ b/prawcore/sessions.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, BinaryIO, TextIO from urllib.parse import urljoin -from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout -from requests.status_codes import codes +from niquests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout +from niquests.status_codes import codes from .auth import BaseAuthorizer from .const import TIMEOUT, WINDOW_SIZE @@ -35,7 +35,7 @@ from .util import authorization_error_class if TYPE_CHECKING: - from requests.models import Response + from niquests.models import Response from .auth import Authorizer from .requestor import Requestor @@ -96,7 +96,7 @@ class Session: codes["unauthorized"]: authorization_error_class, codes[ "unavailable_for_legal_reasons" - ]: UnavailableForLegalReasons, # Cloudflare's status (not named in requests) + ]: UnavailableForLegalReasons, # Cloudflare's status (not named in niquests) 520: ServerError, 522: ServerError, } diff --git a/prawcore/util.py b/prawcore/util.py index ddd7338..54340c4 100644 --- a/prawcore/util.py +++ b/prawcore/util.py @@ -7,7 +7,7 @@ from .exceptions import Forbidden, InsufficientScope, InvalidToken if TYPE_CHECKING: - from requests.models import Response + from niquests.models import Response _auth_error_mapping = { 403: Forbidden, diff --git a/pyproject.toml b/pyproject.toml index dc4ed8f..dbb0b7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.12" ] dependencies = [ - "requests >=2.6.0, <3.0" + "niquests >=3.10, <4.0" ] dynamic = ["version", "description"] keywords = ["praw", "reddit", "api"] @@ -45,7 +45,7 @@ lint = [ test = [ "betamax >=0.8, <0.9", "pytest ==7.*", - "urllib3 ==1.*" + "pytest-asyncio>=0.20,<0.25" ] [project.urls] @@ -60,6 +60,11 @@ line-length = 88 profile = 'black' skip_glob = '.venv*' +[tool.pytest.ini_options] +# this avoids pytest loading betamax+Requests at boot. +# this allows us to patch betamax and makes it use Niquests instead. +addopts = "-p no:pytest-betamax" + [tool.ruff] target-version = "py38" include = [ @@ -78,7 +83,8 @@ ignore = [ "E501", # line-length "PLR0913", # too many arguments "PLR2004", # Magic value used in comparison, - "S101" # use of assert + "S101", # use of assert + "ISC001" # ruff deprecation warning ] select = [ "A", # flake8-builtins diff --git a/tests/conftest.py b/tests/conftest.py index 49286a9..e11487a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,12 +3,44 @@ import os import socket import time +import asyncio from base64 import b64encode -from sys import platform +from sys import platform, modules + +import requests +import niquests +import urllib3 + +# betamax is tied to Requests +# and Niquests is almost entirely compatible with it. +# we can fool it without effort. +modules["requests"] = niquests +modules["requests.adapters"] = niquests.adapters +modules["requests.models"] = niquests.models +modules["requests.exceptions"] = niquests.exceptions +modules["requests.packages.urllib3"] = urllib3 + +# niquests no longer have a compat submodule +# but betamax need it. no worries, as betamax +# explicitly need requests, we'll give it to him. +modules["requests.compat"] = requests.compat + +# doing the import now will make betamax working with Niquests! +# no extra effort. +import betamax + +# the base mock does not implement close(), which is required +# for our HTTP client. No biggy. +betamax.mock_response.MockHTTPResponse.close = lambda _: None import pytest from prawcore import Requestor, TrustedAuthenticator, UntrustedAuthenticator +from prawcore import ( + AsyncRequestor, + AsyncTrustedAuthenticator, + AsyncUntrustedAuthenticator, +) @pytest.fixture(autouse=True) @@ -22,6 +54,17 @@ def _sleep(*_, **__): monkeypatch.setattr(time, "sleep", value=_sleep) +@pytest.fixture(autouse=True) +def patch_async_sleep(monkeypatch): + """Auto patch sleep to speed up tests.""" + + async def _sleep(*_, **__): + """Dud sleep function.""" + pass + + monkeypatch.setattr(asyncio, "sleep", value=_sleep) + + @pytest.fixture def image_path(): """Return path to image.""" @@ -39,6 +82,11 @@ def requestor(): return Requestor("prawcore:test (by /u/bboe)") +@pytest.fixture +def async_requestor(): + return AsyncRequestor("prawcore:test (by /u/bboe)") + + @pytest.fixture def trusted_authenticator(requestor): """Return a TrustedAuthenticator instance.""" @@ -49,12 +97,28 @@ def trusted_authenticator(requestor): ) +@pytest.fixture +def async_trusted_authenticator(async_requestor): + """Return a TrustedAuthenticator instance.""" + return AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + + @pytest.fixture def untrusted_authenticator(requestor): """Return an UntrustedAuthenticator instance.""" return UntrustedAuthenticator(requestor, pytest.placeholders.client_id) +@pytest.fixture +def async_untrusted_authenticator(async_requestor): + """Return an UntrustedAuthenticator instance.""" + return AsyncUntrustedAuthenticator(async_requestor, pytest.placeholders.client_id) + + def env_default(key): """Return environment variable or placeholder string.""" return os.environ.get( diff --git a/tests/integration/asynchronous/__init__.py b/tests/integration/asynchronous/__init__.py new file mode 100644 index 0000000..bbb4cda --- /dev/null +++ b/tests/integration/asynchronous/__init__.py @@ -0,0 +1,122 @@ +"""prawcore Integration test suite.""" + +from __future__ import annotations + +import base64 +import io +import os + +import betamax +import pytest +from betamax.cassette import Cassette, Interaction +from betamax.util import body_io + +from niquests import PreparedRequest, Response + +from niquests.adapters import AsyncHTTPAdapter +from niquests.utils import _swap_context + +try: + from urllib3 import AsyncHTTPResponse, HTTPHeaderDict + from urllib3.backend._async import AsyncLowLevelResponse +except ImportError: + from urllib3_future import AsyncHTTPResponse, HTTPHeaderDict + from urllib3_future.backend._async import AsyncLowLevelResponse + +CASSETTES_PATH = "tests/integration/cassettes" +existing_cassettes = set() +used_cassettes = set() + + +@pytest.mark.asyncio +class AsyncIntegrationTest: + """Base class for prawcore integration tests.""" + + @pytest.fixture(autouse=True) + def inject_fake_async_response(self, cassette_name, monkeypatch): + """betamax does not support Niquests async capabilities. This fixture is made to compensate for this missing feature.""" + cassette_base_dir = os.path.join(os.path.dirname(__file__), "..", "cassettes") + cassette = Cassette( + cassette_name, + serialization_format="json", + cassette_library_dir=cassette_base_dir, + ) + cassette.match_options.update({"method"}) + + def patch_add_urllib3_response(serialized, response, headers): + """This function is patched so that we can construct a proper async dummy response.""" + if "base64_string" in serialized["body"]: + body = io.BytesIO( + base64.b64decode(serialized["body"]["base64_string"].encode()) + ) + else: + body = body_io(**serialized["body"]) + + async def fake_inner_read( + *args, + ) -> tuple[bytes, bool, HTTPHeaderDict | None]: + """Fake the async iter socket read from AsyncHTTPConnection down in urllib3-future.""" + nonlocal body + return body.getvalue(), True, None + + fake_llr = AsyncLowLevelResponse( + method="GET", # hardcoded, but we don't really care. It does not impact the tests. + status=response.status_code, + reason=response.reason, + headers=headers, + body=fake_inner_read, + version=11, + ) + + h = AsyncHTTPResponse( + body, + status=response.status_code, + reason=response.reason, + headers=headers, + original_response=fake_llr, + ) + + response.raw = h + + monkeypatch.setattr( + betamax.util, "add_urllib3_response", patch_add_urllib3_response + ) + + async def fake_send(_, *args, **kwargs) -> Response: + nonlocal cassette + + prep_request: PreparedRequest = args[0] + + interaction: Interaction | None = cassette.find_match(prep_request) + + if interaction: + # betamax can generate a requests.Response + # from a matched interaction. + # three caveats: + # first: not async compatible + # second: we need to output niquests.AsyncResponse first + # third: the underlying HTTPResponse is sync bound + + resp = interaction.as_response() + # Niquests have two kind of responses in async mode. + # A) Response (in case stream=False) + # B) AsyncResponse (in case stream=True) + _swap_context(resp) + + return resp + + raise Exception("no match in cassettes for this request.") + + AsyncHTTPAdapter.send = fake_send + + @pytest.fixture + def cassette_name(self, request): + """Return the name of the cassette to use.""" + marker = request.node.get_closest_marker("cassette_name") + if marker is None: + return ( + f"{request.cls.__name__}.{request.node.name}" + if request.cls + else request.node.name + ) + return marker.args[0] diff --git a/tests/integration/asynchronous/test_authenticator.py b/tests/integration/asynchronous/test_authenticator.py new file mode 100644 index 0000000..252ec9c --- /dev/null +++ b/tests/integration/asynchronous/test_authenticator.py @@ -0,0 +1,41 @@ +"""Test for subclasses of prawcore.auth.BaseAuthenticator class.""" + +import pytest + +import prawcore + +from . import AsyncIntegrationTest + + +class TestTrustedAuthenticator(AsyncIntegrationTest): + async def test_revoke_token(self, async_requestor): + authenticator = prawcore.AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + await authenticator.revoke_token("dummy token") + + async def test_revoke_token__with_access_token_hint(self, async_requestor): + authenticator = prawcore.AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + await authenticator.revoke_token("dummy token", "access_token") + + async def test_revoke_token__with_refresh_token_hint(self, async_requestor): + authenticator = prawcore.AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + await authenticator.revoke_token("dummy token", "refresh_token") + + +class TestUntrustedAuthenticator(AsyncIntegrationTest): + async def test_revoke_token(self, async_requestor): + authenticator = prawcore.AsyncUntrustedAuthenticator( + async_requestor, pytest.placeholders.client_id + ) + await authenticator.revoke_token("dummy token") diff --git a/tests/integration/asynchronous/test_authorizer.py b/tests/integration/asynchronous/test_authorizer.py new file mode 100644 index 0000000..05c392f --- /dev/null +++ b/tests/integration/asynchronous/test_authorizer.py @@ -0,0 +1,280 @@ +"""Test for prawcore.auth.Authorizer classes.""" + +import pytest + +import prawcore + +from . import AsyncIntegrationTest + + +class TestAuthorizer(AsyncIntegrationTest): + async def test_authorize__with_invalid_code(self, async_trusted_authenticator): + async_trusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + with pytest.raises(prawcore.OAuthException): + await authorizer.authorize("invalid code") + assert not authorizer.is_valid() + + async def test_authorize__with_permanent_grant(self, async_trusted_authenticator): + async_trusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + await authorizer.authorize(pytest.placeholders.permanent_grant_code) + + assert authorizer.access_token is not None + assert authorizer.refresh_token is not None + assert isinstance(authorizer.scopes, set) + assert len(authorizer.scopes) > 0 + assert authorizer.is_valid() + + async def test_authorize__with_temporary_grant(self, async_trusted_authenticator): + async_trusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + await authorizer.authorize(pytest.placeholders.temporary_grant_code) + + assert authorizer.access_token is not None + assert authorizer.refresh_token is None + assert isinstance(authorizer.scopes, set) + assert len(authorizer.scopes) > 0 + assert authorizer.is_valid() + + async def test_refresh(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert isinstance(authorizer.scopes, set) + assert len(authorizer.scopes) > 0 + assert authorizer.is_valid() + + @pytest.mark.cassette_name("TestAuthorizer.test_refresh") + async def test_refresh__post_refresh_callback(self, async_trusted_authenticator): + async def callback(authorizer): + assert authorizer.refresh_token != pytest.placeholders.refresh_token + authorizer.refresh_token = "manually_updated" + + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, + post_refresh_callback=callback, + refresh_token=pytest.placeholders.refresh_token, + ) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.refresh_token == "manually_updated" + assert isinstance(authorizer.scopes, set) + assert len(authorizer.scopes) > 0 + assert authorizer.is_valid() + + @pytest.mark.cassette_name("TestAuthorizer.test_refresh") + async def test_refresh__pre_refresh_callback(self, async_trusted_authenticator): + async def callback(authorizer): + assert authorizer.refresh_token is None + authorizer.refresh_token = pytest.placeholders.refresh_token + + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, pre_refresh_callback=callback + ) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert isinstance(authorizer.scopes, set) + assert len(authorizer.scopes) > 0 + assert authorizer.is_valid() + + async def test_refresh__with_invalid_token(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token="INVALID_TOKEN" + ) + with pytest.raises(prawcore.ResponseException): + await authorizer.refresh() + assert not authorizer.is_valid() + + async def test_revoke__access_token_with_refresh_set( + self, async_trusted_authenticator + ): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + await authorizer.refresh() + await authorizer.revoke(only_access=True) + + assert authorizer.access_token is None + assert authorizer.refresh_token is not None + assert authorizer.scopes is None + assert not authorizer.is_valid() + + await authorizer.refresh() + + assert authorizer.is_valid() + + async def test_revoke__access_token_without_refresh_set( + self, async_trusted_authenticator + ): + async_trusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + await authorizer.authorize(pytest.placeholders.temporary_grant_code) + await authorizer.revoke() + + assert authorizer.access_token is None + assert authorizer.refresh_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + + async def test_revoke__refresh_token_with_access_set( + self, async_trusted_authenticator + ): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + await authorizer.refresh() + await authorizer.revoke() + + assert authorizer.access_token is None + assert authorizer.refresh_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + + async def test_revoke__refresh_token_without_access_set( + self, async_trusted_authenticator + ): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + await authorizer.revoke() + + assert authorizer.access_token is None + assert authorizer.refresh_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + + +class TestDeviceIDAuthorizer(AsyncIntegrationTest): + async def test_refresh(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncDeviceIDAuthorizer(async_untrusted_authenticator) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == {"*"} + assert authorizer.is_valid() + + async def test_refresh__with_scopes_and_trusted_authenticator( + self, async_requestor, async_untrusted_authenticator + ): + scope_list = {"adsedit", "adsread", "creddits", "history"} + authorizer = prawcore.AsyncDeviceIDAuthorizer( + prawcore.AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ), + scopes=scope_list, + ) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == scope_list + assert authorizer.is_valid() + + async def test_refresh__with_short_device_id(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncDeviceIDAuthorizer( + async_untrusted_authenticator, "a" * 19 + ) + with pytest.raises(prawcore.OAuthException): + await authorizer.refresh() + + +class TestReadOnlyAuthorizer(AsyncIntegrationTest): + async def test_refresh(self, async_trusted_authenticator): + authorizer = prawcore.AsyncReadOnlyAuthorizer(async_trusted_authenticator) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == {"*"} + assert authorizer.is_valid() + + async def test_refresh__with_scopes(self, async_trusted_authenticator): + scope_list = {"adsedit", "adsread", "creddits", "history"} + authorizer = prawcore.AsyncReadOnlyAuthorizer( + async_trusted_authenticator, scopes=scope_list + ) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == scope_list + assert authorizer.is_valid() + + +class TestScriptAuthorizer(AsyncIntegrationTest): + async def test_refresh(self, async_trusted_authenticator): + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, + pytest.placeholders.username, + pytest.placeholders.password, + ) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == {"*"} + assert authorizer.is_valid() + + async def test_refresh__with_invalid_otp(self, async_trusted_authenticator): + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, + pytest.placeholders.username, + pytest.placeholders.password, + lambda: "fake", + ) + with pytest.raises(prawcore.OAuthException): + await authorizer.refresh() + assert not authorizer.is_valid() + + async def test_refresh__with_invalid_username_or_password( + self, async_trusted_authenticator + ): + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, pytest.placeholders.username, "invalidpassword" + ) + with pytest.raises(prawcore.OAuthException): + await authorizer.refresh() + assert not authorizer.is_valid() + + async def test_refresh__with_scopes(self, async_trusted_authenticator): + scope_list = {"adsedit", "adsread", "creddits", "history"} + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, + pytest.placeholders.username, + pytest.placeholders.password, + scopes=scope_list, + ) + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == scope_list + assert authorizer.is_valid() + + async def test_refresh__with_valid_otp(self, async_trusted_authenticator): + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, + pytest.placeholders.username, + pytest.placeholders.password, + lambda: "000000", + ) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + await authorizer.refresh() + + assert authorizer.access_token is not None + assert authorizer.scopes == {"*"} + assert authorizer.is_valid() diff --git a/tests/integration/asynchronous/test_sessions.py b/tests/integration/asynchronous/test_sessions.py new file mode 100644 index 0000000..6629cd0 --- /dev/null +++ b/tests/integration/asynchronous/test_sessions.py @@ -0,0 +1,322 @@ +"""Test for prawcore.Sessions module.""" + +import logging +from json import dumps + +import pytest +import pytest_asyncio + +import prawcore + +from . import AsyncIntegrationTest + + +class TestSession(AsyncIntegrationTest): + @pytest_asyncio.fixture + async def async_readonly_authorizer(self, async_trusted_authenticator): + authorizer = prawcore.AsyncReadOnlyAuthorizer(async_trusted_authenticator) + await authorizer.refresh() + return authorizer + + @pytest_asyncio.fixture + async def async_script_authorizer(self, async_trusted_authenticator): + authorizer = prawcore.AsyncScriptAuthorizer( + async_trusted_authenticator, + pytest.placeholders.username, + pytest.placeholders.password, + ) + await authorizer.refresh() + return authorizer + + async def test_request__accepted(self, async_script_authorizer, caplog): + caplog.set_level(logging.DEBUG) + session = prawcore.AsyncSession(async_script_authorizer) + await session.request("POST", "api/read_all_messages") + found_message = False + for package, level, message in caplog.record_tuples: + if ( + package == "prawcore._async" + and level == logging.DEBUG + and "Response: 202 (2 bytes)" in message + ): + found_message = True + assert found_message, f"'Response: 202 (2 bytes)' in {caplog.record_tuples}" + + async def test_request__bad_gateway(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + assert exception_info.value.response.status_code == 502 + + async def test_request__bad_json(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + with pytest.raises(prawcore.BadJSON) as exception_info: + await session.request("GET", "/") + assert len(exception_info.value.response.content) == 92 + + async def test_request__bad_request(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + with pytest.raises(prawcore.BadRequest) as exception_info: + await session.request( + "PUT", "/api/v1/me/friends/spez", data='{"note": "prawcore"}' + ) + assert "reason" in exception_info.value.response.json() + + async def test_request__cloudflare_connection_timed_out( + self, async_readonly_authorizer + ): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + await session.request("GET", "/") + await session.request("GET", "/") + assert exception_info.value.response.status_code == 522 + + async def test_request__cloudflare_unknown_error(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + await session.request("GET", "/") + await session.request("GET", "/") + assert exception_info.value.response.status_code == 520 + + async def test_request__conflict(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + previous = "f0214574-430d-11e7-84ca-1201093304fa" + with pytest.raises(prawcore.Conflict) as exception_info: + await session.request( + "POST", + "/r/ThirdRealm/api/wiki/edit", + data={ + "content": "New text", + "page": "index", + "previous": previous, + }, + ) + assert exception_info.value.response.status_code == 409 + + async def test_request__created(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + response = await session.request("PUT", "/api/v1/me/friends/spez", data="{}") + assert "name" in response + + async def test_request__forbidden(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + with pytest.raises(prawcore.Forbidden): + await session.request("GET", "/user/spez/gilded/given") + + async def test_request__gateway_timeout(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + assert exception_info.value.response.status_code == 504 + + async def test_request__get(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + params = {"limit": 100} + response = await session.request("GET", "/", params=params) + assert isinstance(response, dict) + assert len(params) == 1 + assert response["kind"] == "Listing" + + async def test_request__internal_server_error(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + assert exception_info.value.response.status_code == 500 + + async def test_request__no_content(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + response = await session.request("DELETE", "/api/v1/me/friends/spez") + assert response is None + + async def test_request__not_found(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + with pytest.raises(prawcore.NotFound): + await session.request("GET", "/r/reddit_api_test/wiki/invalid") + + async def test_request__okay_with_0_byte_content(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + data = {"model": dumps({"name": "redditdev"})} + path = f"/api/multi/user/{pytest.placeholders.username}/m/praw_x5g968f66a/r/redditdev" + response = await session.request("DELETE", path, data=data) + assert response == "" + + @pytest.mark.recorder_kwargs(match_requests_on=["method", "uri", "body"]) + async def test_request__patch(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + json = {"lang": "ja", "num_comments": 123} + response = await session.request("PATCH", "/api/v1/me/prefs", json=json) + assert response["lang"] == "ja" + assert response["num_comments"] == 123 + + async def test_request__post(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + data = { + "kind": "self", + "sr": "reddit_api_test", + "text": "Test!", + "title": "A Test from PRAWCORE.", + } + key_count = len(data) + response = await session.request("POST", "/api/submit", data=data) + assert "a_test_from_prawcore" in response["json"]["data"]["url"] + assert key_count == len(data) # Ensure data is untouched + + @pytest.mark.recorder_kwargs(match_requests_on=["uri", "method"]) + async def test_request__post__with_files(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + data = {"upload_type": "header"} + with open("tests/integration/files/white-square.png", "rb") as fp: + files = {"file": fp} + response = await session.request( + "POST", + "/r/reddit_api_test/api/upload_sr_img", + data=data, + files=files, + ) + assert "img_src" in response + + async def test_request__raw_json(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + response = await session.request( + "GET", + "/r/reddit_api_test/comments/45xjdr/want_raw_json_test/", + ) + assert ( + "WANT_RAW_JSON test: < > &" + == response[0]["data"]["children"][0]["data"]["title"] + ) + + async def test_request__redirect(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.Redirect) as exception_info: + await session.request("GET", "/r/random") + assert exception_info.value.path.startswith("/r/") + + async def test_request__redirect_301(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.Redirect) as exception_info: + await session.request("GET", "t/bird") + assert exception_info.value.path == "/r/t:bird/" + + async def test_request__service_unavailable(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + with pytest.raises(prawcore.ServerError) as exception_info: + await session.request("GET", "/") + await session.request("GET", "/") + await session.request("GET", "/") + assert exception_info.value.response.status_code == 503 + + async def test_request__too__many_requests__with_retry_headers( + self, async_readonly_authorizer + ): + session = prawcore.AsyncSession(async_readonly_authorizer) + session._requestor._http.headers.update( + {"User-Agent": "python-requests/2.25.1"} + ) + with pytest.raises(prawcore.TooManyRequests) as exception_info: + await session.request("GET", "/api/v1/me") + assert exception_info.value.response.status_code == 429 + assert exception_info.value.response.headers.get("retry-after") + assert exception_info.value.response.reason == "Too Many Requests" + assert str(exception_info.value).startswith( + "received 429 HTTP response. Please wait at least" + ) + assert exception_info.value.message.startswith("\n") + + async def test_request__too__many_requests__without_retry_headers( + self, async_requestor + ): + async_requestor._http.headers.update({"User-Agent": "python-requests/2.25.1"}) + authorizer = prawcore.AsyncReadOnlyAuthorizer( + prawcore.AsyncTrustedAuthenticator( + async_requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + ) + with pytest.raises(prawcore.exceptions.ResponseException) as exception_info: + await authorizer.refresh() + assert exception_info.value.response.status_code == 429 + assert not exception_info.value.response.headers.get("retry-after") + assert exception_info.value.response.reason == "Too Many Requests" + assert exception_info.value.response.json() == { + "message": "Too Many Requests", + "error": 429, + } + + @pytest.mark.recorder_kwargs(match_requests_on=["uri", "method"]) + async def test_request__too_large(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + data = {"upload_type": "header"} + with open("tests/integration/files/too_large.jpg", "rb") as fp: + files = {"file": fp} + with pytest.raises(prawcore.TooLarge) as exception_info: + await session.request( + "POST", + "/r/reddit_api_test/api/upload_sr_img", + data=data, + files=files, + ) + assert exception_info.value.response.status_code == 413 + + async def test_request__unavailable_for_legal_reasons( + self, async_readonly_authorizer + ): + session = prawcore.AsyncSession(async_readonly_authorizer) + exception_class = prawcore.UnavailableForLegalReasons + with pytest.raises(exception_class) as exception_info: + await session.request("GET", "/") + assert exception_info.value.response.status_code == 451 + + async def test_request__unsupported_media_type(self, async_script_authorizer): + session = prawcore.AsyncSession(async_script_authorizer) + exception_class = prawcore.SpecialError + data = { + "content": "type: submission\naction: upvote", + "page": "config/automoderator", + } + with pytest.raises(exception_class) as exception_info: + await session.request("POST", "r/ttft/api/wiki/edit/", data=data) + assert exception_info.value.response.status_code == 415 + + async def test_request__uri_too_long(self, async_readonly_authorizer): + session = prawcore.AsyncSession(async_readonly_authorizer) + path_start = "/api/morechildren?link_id=t3_n7r3uz&children=" + with open("tests/integration/files/comment_ids.txt") as fp: + ids = fp.read() + with pytest.raises(prawcore.URITooLong) as exception_info: + await session.request("GET", (path_start + ids)[:9996]) + assert exception_info.value.response.status_code == 414 + + async def test_request__with_insufficient_scope(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + await authorizer.refresh() + session = prawcore.AsyncSession(authorizer) + with pytest.raises(prawcore.InsufficientScope): + await session.request( + "GET", + "/api/v1/me", + ) + + async def test_request__with_invalid_access_token( + self, async_untrusted_authenticator + ): + authorizer = prawcore.AsyncImplicitAuthorizer( + async_untrusted_authenticator, None, 0, "" + ) + session = prawcore.AsyncSession(authorizer) + session._authorizer.access_token = "invalid" + with pytest.raises(prawcore.InvalidToken): + r = await session.request("get", "/") + + async def test_request__with_invalid_access_token__retry( + self, async_readonly_authorizer + ): + session = prawcore.AsyncSession(async_readonly_authorizer) + session._authorizer.access_token += "invalid" + response = await session.request("GET", "/") + assert isinstance(response, dict) diff --git a/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_with_refresh_set.json b/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_with_refresh_set.json index aa93251..bfea8c7 100644 --- a/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_with_refresh_set.json +++ b/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_with_refresh_set.json @@ -49,7 +49,7 @@ "close" ], "Content-Length": [ - "184" + "182" ], "Content-Type": [ "application/json; charset=UTF-8" diff --git a/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_without_refresh_set.json b/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_without_refresh_set.json index 026aaae..62c35c7 100644 --- a/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_without_refresh_set.json +++ b/tests/integration/cassettes/TestAuthorizer.test_revoke__access_token_without_refresh_set.json @@ -49,7 +49,7 @@ "close" ], "Content-Length": [ - "124" + "123" ], "Content-Type": [ "application/json; charset=UTF-8" diff --git a/tests/integration/cassettes/TestAuthorizer.test_revoke__refresh_token_with_access_set.json b/tests/integration/cassettes/TestAuthorizer.test_revoke__refresh_token_with_access_set.json index 23edb78..377ecdc 100644 --- a/tests/integration/cassettes/TestAuthorizer.test_revoke__refresh_token_with_access_set.json +++ b/tests/integration/cassettes/TestAuthorizer.test_revoke__refresh_token_with_access_set.json @@ -49,7 +49,7 @@ "close" ], "Content-Length": [ - "184" + "182" ], "Content-Type": [ "application/json; charset=UTF-8" diff --git a/tests/integration/cassettes/TestScriptAuthorizer.test_refresh__with_scopes.json b/tests/integration/cassettes/TestScriptAuthorizer.test_refresh__with_scopes.json index f950fc5..b9e4ee9 100644 --- a/tests/integration/cassettes/TestScriptAuthorizer.test_refresh__with_scopes.json +++ b/tests/integration/cassettes/TestScriptAuthorizer.test_refresh__with_scopes.json @@ -49,7 +49,7 @@ "close" ], "Content-Length": [ - "147" + "148" ], "Content-Type": [ "application/json; charset=UTF-8" diff --git a/tests/integration/cassettes/TestSession.test_request__patch.json b/tests/integration/cassettes/TestSession.test_request__patch.json index 34b7ad9..a72c6cf 100644 --- a/tests/integration/cassettes/TestSession.test_request__patch.json +++ b/tests/integration/cassettes/TestSession.test_request__patch.json @@ -49,7 +49,7 @@ "keep-alive" ], "Content-Length": [ - "116" + "86" ], "Content-Type": [ "application/json; charset=UTF-8" diff --git a/tests/unit/asynchronous/__init__.py b/tests/unit/asynchronous/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/asynchronous/test_authenticator.py b/tests/unit/asynchronous/test_authenticator.py new file mode 100644 index 0000000..df517b0 --- /dev/null +++ b/tests/unit/asynchronous/test_authenticator.py @@ -0,0 +1,90 @@ +"""Test for subclasses of prawcore._async.auth.AsyncBaseAuthenticator class.""" + +import pytest + +import prawcore + +from .. import UnitTest + + +class TestTrustedAuthenticator(UnitTest): + @pytest.fixture + def async_trusted_authenticator(self, async_trusted_authenticator): + async_trusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + return async_trusted_authenticator + + def test_authorize_url(self, async_trusted_authenticator): + url = async_trusted_authenticator.authorize_url( + "permanent", ["identity", "read"], "a_state" + ) + assert f"client_id={pytest.placeholders.client_id}" in url + assert "duration=permanent" in url + assert "response_type=code" in url + assert "scope=identity+read" in url + assert "state=a_state" in url + + def test_authorize_url__fail_with_implicit(self, async_trusted_authenticator): + with pytest.raises(prawcore.InvalidInvocation): + async_trusted_authenticator.authorize_url( + "temporary", ["identity", "read"], "a_state", implicit=True + ) + + def test_authorize_url__fail_without_redirect_uri( + self, async_trusted_authenticator + ): + async_trusted_authenticator.redirect_uri = None + with pytest.raises(prawcore.InvalidInvocation): + async_trusted_authenticator.authorize_url( + "permanent", + ["identity"], + "...", + ) + + +class TestUntrustedAuthenticator(UnitTest): + @pytest.fixture + def async_untrusted_authenticator(self, async_untrusted_authenticator): + async_untrusted_authenticator.redirect_uri = pytest.placeholders.redirect_uri + return async_untrusted_authenticator + + def test_authorize_url__code(self, async_untrusted_authenticator): + url = async_untrusted_authenticator.authorize_url( + "permanent", ["identity", "read"], "a_state" + ) + assert f"client_id={pytest.placeholders.client_id}" in url + assert "duration=permanent" in url + assert "response_type=code" in url + assert "scope=identity+read" in url + assert "state=a_state" in url + + def test_authorize_url__fail_with_token_and_permanent( + self, async_untrusted_authenticator + ): + with pytest.raises(prawcore.InvalidInvocation): + async_untrusted_authenticator.authorize_url( + "permanent", + ["identity", "read"], + "a_state", + implicit=True, + ) + + def test_authorize_url__fail_without_redirect_uri( + self, async_untrusted_authenticator + ): + async_untrusted_authenticator.redirect_uri = None + with pytest.raises(prawcore.InvalidInvocation): + async_untrusted_authenticator.authorize_url( + "temporary", + ["identity"], + "...", + ) + + def test_authorize_url__token(self, async_untrusted_authenticator): + url = async_untrusted_authenticator.authorize_url( + "temporary", ["identity", "read"], "a_state", implicit=True + ) + assert f"client_id={pytest.placeholders.client_id}" in url + assert "duration=temporary" in url + assert "response_type=token" in url + assert "scope=identity+read" in url + assert "state=a_state" in url diff --git a/tests/unit/asynchronous/test_authorizer.py b/tests/unit/asynchronous/test_authorizer.py new file mode 100644 index 0000000..fa2c678 --- /dev/null +++ b/tests/unit/asynchronous/test_authorizer.py @@ -0,0 +1,118 @@ +"""Test for prawcore._async.auth.AsyncAuthorizer classes.""" + +import pytest + +import prawcore + +from .. import UnitTest + + +class AsyncInvalidAuthenticator(prawcore._async.auth.AsyncBaseAuthenticator): + def _auth(self): + pass + + +@pytest.mark.asyncio +class TestAuthorizer(UnitTest): + async def test_authorize__fail_without_redirect_uri( + self, async_trusted_authenticator + ): + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + with pytest.raises(prawcore.InvalidInvocation): + await authorizer.authorize("dummy code") + assert not authorizer.is_valid() + + async def test_initialize(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert authorizer.refresh_token is None + assert not authorizer.is_valid() + + async def test_initialize__with_refresh_token(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert pytest.placeholders.refresh_token == authorizer.refresh_token + assert not authorizer.is_valid() + + async def test_initialize__with_untrusted_authenticator(self): + authenticator = prawcore.AsyncUntrustedAuthenticator(None, None) + authorizer = prawcore.AsyncAuthorizer(authenticator) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert authorizer.refresh_token is None + assert not authorizer.is_valid() + + async def test_refresh__without_refresh_token(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + with pytest.raises(prawcore.InvalidInvocation): + await authorizer.refresh() + assert not authorizer.is_valid() + + async def test_revoke__without_access_token(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer( + async_trusted_authenticator, refresh_token=pytest.placeholders.refresh_token + ) + with pytest.raises(prawcore.InvalidInvocation): + await authorizer.revoke(only_access=True) + + async def test_revoke__without_any_token(self, async_trusted_authenticator): + authorizer = prawcore.AsyncAuthorizer(async_trusted_authenticator) + with pytest.raises(prawcore.InvalidInvocation): + await authorizer.revoke() + + +@pytest.mark.asyncio +class TestDeviceIDAuthorizer(UnitTest): + async def test_initialize(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncDeviceIDAuthorizer(async_untrusted_authenticator) + assert authorizer.access_token is None + assert authorizer.scopes is None + assert not authorizer.is_valid() + + async def test_initialize__with_invalid_authenticator(self): + authenticator = prawcore.AsyncAuthorizer( + AsyncInvalidAuthenticator(None, None, None) + ) + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncDeviceIDAuthorizer(authenticator) + + +@pytest.mark.asyncio +class TestImplicitAuthorizer(UnitTest): + async def test_initialize(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncImplicitAuthorizer( + async_untrusted_authenticator, "fake token", 1, "modposts read" + ) + assert authorizer.access_token == "fake token" + assert authorizer.scopes == {"modposts", "read"} + assert authorizer.is_valid() + + async def test_initialize__with_trusted_authenticator( + self, async_trusted_authenticator + ): + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncImplicitAuthorizer( + async_trusted_authenticator, None, None, None + ) + + +@pytest.mark.asyncio +class TestReadOnlyAuthorizer(UnitTest): + async def test_initialize__with_untrusted_authenticator( + self, async_untrusted_authenticator + ): + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncReadOnlyAuthorizer(async_untrusted_authenticator) + + +@pytest.mark.asyncio +class TestScriptAuthorizer(UnitTest): + async def test_initialize__with_untrusted_authenticator( + self, async_untrusted_authenticator + ): + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncScriptAuthorizer(async_untrusted_authenticator, None, None) diff --git a/tests/unit/asynchronous/test_rate_limit.py b/tests/unit/asynchronous/test_rate_limit.py new file mode 100644 index 0000000..84b347a --- /dev/null +++ b/tests/unit/asynchronous/test_rate_limit.py @@ -0,0 +1,130 @@ +"""Test for prawcore.AsyncSessions module.""" + +from copy import copy +from unittest.mock import patch + +import pytest + +from prawcore._async.rate_limit import AsyncRateLimiter + +from .. import UnitTest + + +@pytest.mark.asyncio +class TestRateLimiter(UnitTest): + @pytest.fixture + def rate_limiter(self): + rate_limiter = AsyncRateLimiter(window_size=600) + rate_limiter.next_request_timestamp = 100 + return rate_limiter + + @staticmethod + def _headers(remaining, used, reset): + return { + "x-ratelimit-remaining": str(float(remaining)), + "x-ratelimit-used": str(used), + "x-ratelimit-reset": str(reset), + } + + @patch("time.time") + @patch("asyncio.sleep") + async def test_delay(self, mock_sleep, mock_time, rate_limiter): + mock_time.return_value = 1 + await rate_limiter.delay() + assert mock_time.called + mock_sleep.assert_called_with(99) + + @patch("time.time") + @patch("time.sleep") + async def test_delay__no_sleep_when_time_in_past( + self, mock_sleep, mock_time, rate_limiter + ): + mock_time.return_value = 101 + await rate_limiter.delay() + assert mock_time.called + assert not mock_sleep.called + + @patch("time.sleep") + async def test_delay__no_sleep_when_time_is_not_set(self, mock_sleep, rate_limiter): + await rate_limiter.delay() + assert not mock_sleep.called + + @patch("time.time") + @patch("time.sleep") + async def test_delay__no_sleep_when_times_match( + self, mock_sleep, mock_time, rate_limiter + ): + mock_time.return_value = 100 + await rate_limiter.delay() + assert mock_time.called + assert not mock_sleep.called + + @patch("time.time") + async def test_update__compute_delay_with_no_previous_info( + self, mock_time, rate_limiter + ): + mock_time.return_value = 100 + rate_limiter.update(self._headers(60, 100, 60)) + assert rate_limiter.remaining == 60 + assert rate_limiter.used == 100 + assert rate_limiter.next_request_timestamp == 100 + + @patch("time.time") + async def test_update__compute_delay_with_single_client( + self, mock_time, rate_limiter + ): + rate_limiter.remaining = 61 + rate_limiter.window_size = 150 + mock_time.return_value = 100 + rate_limiter.update(self._headers(50, 100, 60)) + assert rate_limiter.remaining == 50 + assert rate_limiter.used == 100 + assert rate_limiter.next_request_timestamp == 110 + + @patch("time.time") + async def test_update__compute_delay_with_six_clients( + self, mock_time, rate_limiter + ): + rate_limiter.remaining = 66 + rate_limiter.window_size = 180 + mock_time.return_value = 100 + rate_limiter.update(self._headers(60, 100, 72)) + assert rate_limiter.remaining == 60 + assert rate_limiter.used == 100 + assert rate_limiter.next_request_timestamp == 104.5 + + @patch("time.time") + async def test_update__delay_full_time_with_negative_remaining( + self, mock_time, rate_limiter + ): + mock_time.return_value = 37 + rate_limiter.remaining = -1 + rate_limiter.update(self._headers(0, 100, 13)) + assert rate_limiter.remaining == 0 + assert rate_limiter.used == 100 + assert rate_limiter.next_request_timestamp == 50 + + @patch("time.time") + async def test_update__delay_full_time_with_zero_remaining( + self, mock_time, rate_limiter + ): + mock_time.return_value = 37 + rate_limiter.remaining = 0 + rate_limiter.update(self._headers(0, 100, 13)) + assert rate_limiter.remaining == 0 + assert rate_limiter.used == 100 + assert rate_limiter.next_request_timestamp == 50 + + async def test_update__no_change_without_headers(self, rate_limiter): + prev = copy(rate_limiter) + rate_limiter.update({}) + assert prev.remaining == rate_limiter.remaining + assert prev.used == rate_limiter.used + assert rate_limiter.next_request_timestamp == prev.next_request_timestamp + + async def test_update__values_change_without_headers(self, rate_limiter): + rate_limiter.remaining = 10 + rate_limiter.used = 99 + rate_limiter.update({}) + assert rate_limiter.remaining == 9 + assert rate_limiter.used == 100 diff --git a/tests/unit/asynchronous/test_requestor.py b/tests/unit/asynchronous/test_requestor.py new file mode 100644 index 0000000..1cd77a4 --- /dev/null +++ b/tests/unit/asynchronous/test_requestor.py @@ -0,0 +1,78 @@ +"""Test for prawcore._async.requestor.AsyncRequestor class.""" + +import pickle +from inspect import signature +from unittest.mock import Mock, patch + +import pytest + +import prawcore +from prawcore import RequestException + +from .. import UnitTest + + +@pytest.mark.asyncio +class TestRequestor(UnitTest): + async def test_initialize(self, async_requestor): + assert ( + async_requestor._http.headers["User-Agent"] + == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" + ) + + async def test_initialize__failures(self): + for agent in [None, "shorty"]: + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncRequestor(agent) + + async def test_pickle(self, async_requestor): + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + pickle.loads(pickle.dumps(async_requestor, protocol=protocol)) + + async def test_request__session_timeout_default(self, async_requestor): + requestor_signature = signature(async_requestor._http.request) + assert ( + str(requestor_signature.parameters["timeout"]) + == "timeout: 'TimeoutType | None' = 120" + ) + + async def test_request__use_custom_session(self): + async def override() -> str: + return "ASYNC OVERRIDE" + + expected_return = await override() + custom_header = "CUSTOM SESSION HEADER" + headers = {"session_header": custom_header} + attrs = {"request.return_value": override(), "headers": headers} + session = Mock(**attrs) + + requestor = prawcore.AsyncRequestor( + "prawcore:test (by /u/bboe)", session=session + ) + + assert ( + requestor._http.headers["User-Agent"] + == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" + ) + assert requestor._http.headers["session_header"] == custom_header + + assert (await requestor.request("https://reddit.com")) == expected_return + + @patch("requests.AsyncSession") + async def test_request__wrap_request_exceptions(self, mock_session): + exception = Exception("prawcore wrap_request_exceptions") + session_instance = mock_session.return_value + session_instance.request.side_effect = exception + requestor = prawcore.AsyncRequestor("prawcore:test (by /u/bboe)") + with pytest.raises(prawcore.RequestException) as exception_info: + await requestor.request("get", "http://a.b", data="bar") + assert isinstance(exception_info.value, RequestException) + assert exception is exception_info.value.original_exception + assert exception_info.value.request_args == ("get", "http://a.b") + assert exception_info.value.request_kwargs == {"data": "bar"} + + async def test_getattr_async_requestor(self, async_requestor): + """This test is added to cover one line of code in the async requestor that was indirectly used by betamax""" + adapters = getattr(async_requestor, "adapters", None) + + assert adapters is not None diff --git a/tests/unit/asynchronous/test_sessions.py b/tests/unit/asynchronous/test_sessions.py new file mode 100644 index 0000000..b55c6f5 --- /dev/null +++ b/tests/unit/asynchronous/test_sessions.py @@ -0,0 +1,108 @@ +"""Test for prawcore.AsyncSessions module.""" + +import logging +from unittest.mock import Mock, patch + +import pytest +from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout + +import prawcore +from prawcore.exceptions import RequestException + +from .. import UnitTest + + +class AsyncInvalidAuthorizer(prawcore.AsyncAuthorizer): + def __init__(self, requestor): + super(AsyncInvalidAuthorizer, self).__init__( + prawcore.AsyncTrustedAuthenticator( + requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + ) + + def is_valid(self): + return False + + +@pytest.mark.asyncio +class TestSession(UnitTest): + @pytest.fixture + def async_readonly_authorizer(self, async_trusted_authenticator): + return prawcore.AsyncReadOnlyAuthorizer(async_trusted_authenticator) + + async def test_close(self, async_readonly_authorizer): + await prawcore.AsyncSession(async_readonly_authorizer).close() + + async def test_context_manager(self, async_readonly_authorizer): + async with prawcore.AsyncSession(async_readonly_authorizer) as session: + assert isinstance(session, prawcore.AsyncSession) + + async def test_init__with_device_id_authorizer(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncDeviceIDAuthorizer(async_untrusted_authenticator) + prawcore.AsyncSession(authorizer) + + async def test_init__with_implicit_authorizer(self, async_untrusted_authenticator): + authorizer = prawcore.AsyncImplicitAuthorizer( + async_untrusted_authenticator, None, 0, "" + ) + prawcore.AsyncSession(authorizer) + + async def test_init__without_authenticator(self): + with pytest.raises(prawcore.InvalidInvocation): + prawcore.AsyncSession(None) + + @patch("requests.AsyncSession") + @pytest.mark.parametrize( + "exception", + [ChunkedEncodingError(), ConnectionError(), ReadTimeout()], + ids=["ChunkedEncodingError", "ConnectionError", "ReadTimeout"], + ) + async def test_request__retry(self, mock_session, exception, caplog): + caplog.set_level(logging.WARNING) + session_instance = mock_session.return_value + # Handle Auth + response_dict = {"access_token": "", "expires_in": 99, "scope": ""} + + async def override(): + return Mock(headers={}, json=lambda: response_dict, status_code=200) + + session_instance.request.return_value = override() + requestor = prawcore.AsyncRequestor("prawcore:test (by /u/bboe)") + authenticator = prawcore.AsyncTrustedAuthenticator( + requestor, + pytest.placeholders.client_id, + pytest.placeholders.client_secret, + ) + authorizer = prawcore.AsyncReadOnlyAuthorizer(authenticator) + await authorizer.refresh() + session_instance.request.reset_mock() + # Fail on subsequent request + session_instance.request.side_effect = exception + + with pytest.raises(RequestException) as exception_info: + await prawcore.AsyncSession(authorizer).request("GET", "/") + assert ( + "prawcore._async", + logging.WARNING, + f"Retrying due to {exception.__class__.__name__}() status: GET " + "https://oauth.reddit.com/", + ) in caplog.record_tuples + assert isinstance(exception_info.value, RequestException) + assert exception is exception_info.value.original_exception + assert session_instance.request.call_count == 3 + + async def test_request__with_invalid_authorizer(self, async_requestor): + session = prawcore.AsyncSession(AsyncInvalidAuthorizer(async_requestor)) + with pytest.raises(prawcore.InvalidInvocation): + await session.request("get", "/") + + +@pytest.mark.asyncio +class TestSessionFunction(UnitTest): + async def test_session(self, async_requestor): + assert isinstance( + prawcore.async_session(AsyncInvalidAuthorizer(async_requestor)), + prawcore.AsyncSession, + ) diff --git a/tests/unit/test_requestor.py b/tests/unit/test_requestor.py index 56628ef..65a5cbe 100644 --- a/tests/unit/test_requestor.py +++ b/tests/unit/test_requestor.py @@ -30,7 +30,10 @@ def test_pickle(self, requestor): def test_request__session_timeout_default(self, requestor): requestor_signature = signature(requestor._http.request) - assert str(requestor_signature.parameters["timeout"]) == "timeout=None" + assert ( + str(requestor_signature.parameters["timeout"]) + == "timeout: 'TimeoutType | None' = 120" + ) def test_request__use_custom_session(self): override = "REQUEST OVERRIDDEN" diff --git a/tools/generate_async.py b/tools/generate_async.py new file mode 100755 index 0000000..1786d35 --- /dev/null +++ b/tools/generate_async.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python +"""Program to automatically generate asynchronous code from the synchronous counterpart. + +This script is made solely for PRAW. Do not attempt to use this outside of PRAW use +cases. + +The following heavily relies on proper typing in the package. Bad typing WILL lead to +BAD generated async code. + +Not everything is to be blindly made awaitable. This algorithm permit us to scan which +element invoke blocking I/O code and link dependants. + +REQUIRES: git, black, docstrfmt, and ruff installed. + +exit codes: + - 0: The asynchronous code is up-to-date + - 1: The asynchronous code is outdated + - 2: Generated async is invalid (syntax error) + - 3: Third-party tool not found (black, ruff or git) + - 4: Other errors (check stderr) + +""" +from __future__ import annotations + +import os.path +import shutil +import sys +import inspect +import importlib +import ast +import logging +import typing +from os import getcwd +from tempfile import TemporaryDirectory +from pathlib import Path +import re +from types import ModuleType +from typing import Callable +from subprocess import Popen, PIPE, CalledProcessError +import argparse + +import niquests + +# Use the local PRAW over the site-package PRAW. +if "/tools" in getcwd(): + sys.path.insert(0, "..") +else: + sys.path.insert(0, ".") + +import prawcore + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +explain_handler = logging.StreamHandler() +explain_handler.setFormatter( + logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") +) +logger.addHandler(explain_handler) + + +def black_reformat(tmp_dir: str) -> list[str]: + """Simply run black and retrieve which file were reformatted if any.""" + process = Popen( + f"black prawcore/_async", shell=True, stdout=PIPE, stderr=PIPE, cwd=tmp_dir + ) + stdout, stderr = process.communicate() + + stdout, stderr = stdout.decode(), stderr.decode() + + if stderr: + reformatted = [] + + for line in stderr.splitlines(keepends=False): + if line.startswith("reformatted"): + reformatted.append(line) + + return reformatted + + return [] + + +def docstrfmt_reformat(tmp_dir: str) -> list[str]: + process = Popen( + f"docstrfmt prawcore/_async/*.py", + shell=True, + stdout=PIPE, + stderr=PIPE, + cwd=tmp_dir, + ) + stdout, stderr = process.communicate() + + stdout, stderr = stdout.decode(), stderr.decode() + + if stderr: + reformatted = [] + + for line in stderr.splitlines(keepends=False): + if line.startswith("Reformatted"): + reformatted.append(line) + + return reformatted + + return [] + + +def ruff_reformat(tmp_dir: str) -> list[str]: + """Run Ruff linter and extract unfixable errors if any.""" + process = Popen( + f"ruff check --fix", shell=True, stdout=PIPE, stderr=PIPE, cwd=tmp_dir + ) + stdout, stderr = process.communicate() + + stdout, stderr = stdout.decode(), stderr.decode() + + if stderr: + raise IOError(stderr) + + if stdout: + reformatted = [] + + for line in stdout.splitlines(keepends=False): + reformatted.append(line) + + return reformatted + + return [] + + +def git_diff(file_main: str, file_generated: str) -> list[str]: + """Simple git diff between two files.""" + process = Popen( + f"git diff --no-index {file_main} {file_generated}", + shell=True, + stdout=PIPE, + stderr=PIPE, + ) + stdout, stderr = process.communicate() + + stdout, stderr = stdout.decode(), stderr.decode() + + if stderr: + raise IOError(stderr) + + if stdout: + patch_lines = [] + + for line in stdout.splitlines(keepends=False): + patch_lines.append(line) + + return patch_lines + + return [] + + +def preg_replace( + src_str: str, + pattern_in: str, + insert: str, + after: str | None = None, + before: str | None = None, + callback_ignore_if: Callable[[list[str]], bool] | None = None, + extraneous_space: bool = True, +) -> str: + """Smart regex replace function. It does not replace twice if replacement was already applied.""" + if extraneous_space: + insert += " " + + def _inner_repl(m: re.Match[str]) -> str: + full_string = m.string[m.span()[0] : m.span()[1]] + + if insert in m.string[m.span()[0] - len(insert) : m.span()[0] + len(insert)]: + return full_string + + if callback_ignore_if is not None: + if callback_ignore_if(list(m.groups())) is True: + return full_string + + if after is not None: + sub_idx = full_string.index(after) + sub_idx += len(after) + return full_string[:sub_idx] + insert + full_string[sub_idx:] + elif before is not None: + sub_idx = full_string.index(before) + return full_string[:sub_idx] + insert + full_string[sub_idx:] + + if insert in full_string: + return full_string + + first_printable_idx = 0 + + for c in full_string: + if c.isprintable() and c.isspace() is False: + first_printable_idx = full_string.index(c) + break + + return ( + full_string[:first_printable_idx] + + insert + + full_string[first_printable_idx:] + ) + + return re.sub(re.compile(pattern_in), _inner_repl, src_str) + + +def get_owned_by_module_path(my_type: type) -> str: + """Retrieve the module Python source path.""" + return importlib.import_module(my_type.__module__).__file__ + + +def get_definitions_in_module(my_type: type) -> list[str]: + """Retrieve the list of declaration issued from the module that host 'my_type'.""" + symbols_by_import = list( + map(lambda e: e[1] if e[1] is not None else e[0], get_import_made_by(my_type)) + ) + symbols = [] + + for symbol, _ in inspect.getmembers(importlib.import_module(my_type.__module__)): + if symbol not in symbols_by_import and symbol.startswith("__") is False: + symbols.append(symbol) + + return symbols + + +def get_root_module_from(my_type: type, level: int = 0) -> str: + """Infer the module MINUS given level depth. Starting from given 'my_type'.""" + tree = my_type.__module__.split(".") + return ".".join(tree[: len(tree) - level]) + + +def get_import_made_by(my_type: type) -> list[tuple[str, str | None]]: + """List every import made in module that host 'my_type'.""" + required_import = [] + + with open(get_owned_by_module_path(my_type), "r") as fp: + file_ast = ast.parse(fp.read()) + + # Walk every node in the tree + for node in ast.walk(file_ast): + # If the node is 'import x', then extract the module names + if isinstance(node, ast.Import): + required_import.extend([(x.name, None) for x in node.names]) + + # If the node is 'from x import y', then extract the module name + # and check level so we can ignore relative imports + if isinstance(node, ast.ImportFrom): + for name in node.names: + if node.module is None: + required_import.append( + ( + ( + node.module + if node.level == 0 + else f"{get_root_module_from(my_type, node.level)}.{name.name}" + ), + None, + ) + ) + else: + required_import.append( + ( + ( + node.module + if node.level == 0 + else f"{get_root_module_from(my_type, node.level)}.{node.module}" + ), + name.name, + ) + ) + + return required_import + + +def parse_constructor_attrs(my_type: type) -> list[tuple[str, str, str | None]]: + """Tries to uncover attributes initialized in constructor of 'my_type'.""" + raw_init = inspect.getsource(my_type.__init__) + + attributes = [] + + for line in raw_init.splitlines(keepends=False): + if "self." in line and " = " in line: + attr_definition, attr_default_assignment = tuple( + line.split(" = ", maxsplit=1) + ) + + attr_definition = attr_definition.replace("self.", "").strip() + + if "." in attr_definition: + continue + + if ":" in attr_definition: + attr_definition, attr_annotation = tuple( + attr_definition.split(": ", maxsplit=1) + ) + else: + attr_annotation = None + + attr_default_assignment = attr_default_assignment.strip() + + attributes.append( + (attr_definition, attr_default_assignment, attr_annotation) + ) + + return attributes + + +def load_annotation(raw_annotation: str) -> type: + """Dumb & straightforward annotation loader. Meant to import a single type.""" + imported_module = importlib.import_module(".".join(raw_annotation.split(".")[:-1])) + return getattr(imported_module, raw_annotation.split(".")[-1]) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Asynchronous Code Generator for PRAW") + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + dest="verbose", + help="Enable DEBUG logging", + ) + parser.add_argument( + "-f", + "--fix", + action="store_true", + default=False, + dest="fix", + help="Applies the newer version of generated modules", + ) + parser.add_argument( + "-s", + "--sync-diff", + action="store_true", + default=False, + dest="main_diff", + help="Render the differences between the synchronous parts and the newly generated asynchronous ones", + ) + parser.add_argument( + "-a", + "--async-diff", + action="store_true", + default=False, + dest="async_diff", + help="Render the differences between the old async and newest async generated modules", + ) + + args = parser.parse_args(sys.argv[1:]) + + if args.verbose: + logger.setLevel(logging.DEBUG) + + should_write_into_package: bool = args.fix is True + should_display_diff_sync_async: bool = args.main_diff is True + should_display_diff_async_async: bool = args.async_diff is True + + try: + logger.info("Scanning the top-level package __init__ to grab public references") + + #: directly exposed API (public) + top_level_types: list[type] = [] + #: not directly exposed in __init__ but required by references in __init__ + indirect_types: list[type] = [] + #: we need to list them so that we can properly remap import in generated async src + package_submodules: list[str] = [] + + # we want to parse the top-level public API + # that what we care about. + # first step is to identify every 'class' that could + # need async convertion. + for name, target_type in inspect.getmembers(prawcore): + if isinstance(target_type, ModuleType): + if name == "_async": + continue + if "prawcore." not in str(target_type): + continue + package_submodules.append(name) + # we need to exclude: + # A) already Async prefixed + # B) things that aren't package related (e.g. not prawcore) + # C) that aren't class definition. + # D) or that are Exception. + if ( + "Async" in name + or "prawcore" not in str(target_type) + or hasattr(target_type, "__base__") is False + or hasattr(target_type, "__cause__") is True + ): + logger.debug(f"Reference '{name}' ignored/excluded") + continue + + if target_type not in top_level_types: + logger.debug(f"Root reference added '{name}'") + top_level_types.append(target_type) + + if target_type.__bases__: + for parent_type in target_type.__bases__: + if parent_type is object: + continue + if parent_type not in top_level_types: + logger.debug( + f"Parent reference added '{parent_type.__name__}' from '{name}'" + ) + top_level_types.append(parent_type) + for sub_class_of_parent in parent_type.__subclasses__(): + if sub_class_of_parent is target_type: + continue + if sub_class_of_parent not in top_level_types: + logger.debug( + f"Child reference added '{sub_class_of_parent.__name__}' from '{parent_type.__name__}'" + ) + top_level_types.append(sub_class_of_parent) + + logger.info(f"Found references: {', '.join(str(r) for r in top_level_types)}") + + #: class that should have async counterpart + require_async_class: list[type] = [] + #: attrs that are compatible with async but are sync here + require_async_attribute: dict[type, list[tuple[str, type]]] = {} + #: methods that use those attrs + require_async_method: dict[type, list[str]] = {} + #: callable/callback within a class + require_async_callable: dict[type, list[str]] = {} + #: callable raw annotations for later convertion + require_async_callable_annotations: dict[type, list[str]] = {} + + logger.info("Starting in-depth analysis of references (class w/ attributes)") + + # then we'll need to identify across multiple iteration + # what actually is eligible to async convertion + # e.g. class uses niquests.Session + iter_count = 0 + + while True: + ref_count = 0 + logger.debug(f"Analysis iteration n°{iter_count + 1}") + + for target_type in top_level_types + indirect_types: + logger.debug(f"Inspecting {target_type}") + # we want to look at the constructor first. + have_constructor = False + try: + annotations_constructor_for_type = ( + target_type.__init__.__annotations__ + ) + have_constructor = True + except ( + AttributeError + ): # class does not necessarily define a constructor (e.g. ABC class) + annotations_constructor_for_type = {} + + import_clauses_in_module = get_import_made_by(target_type) + + # detect class dependant (but not exposed directly) + for dep_mod, dep_reference in import_clauses_in_module: + if ( + not dep_mod.startswith("prawcore.") + or dep_reference is None + or dep_reference[0].islower() + ): + continue + + dep_type = load_annotation(f"{dep_mod}.{dep_reference}") + + if ( + hasattr(dep_type, "__base__") is False + or hasattr(dep_type, "__cause__") is True + ): + continue + + if dep_type in top_level_types: + continue + if dep_type not in indirect_types: + ref_count += 1 + indirect_types.append(dep_type) + logger.debug( + f"Adding reference '{dep_type}' as indirect (not top level) dependency of '{target_type}'" + ) + + definitions_in_module = get_definitions_in_module(target_type) + + for local_definition in definitions_in_module: + if local_definition.islower(): + continue + + dep_type = load_annotation( + f"{get_root_module_from(target_type)}.{local_definition}" + ) + + if ( + hasattr(dep_type, "__base__") is False + or hasattr(dep_type, "__cause__") is True + ): + continue + if dep_type in top_level_types: + continue + if dep_type not in indirect_types: + ref_count += 1 + indirect_types.append(dep_type) + logger.debug( + f"Adding reference '{dep_type}' as indirect (local) dependency of '{target_type}'" + ) + + if have_constructor: + attrs_in_module = parse_constructor_attrs(target_type) + else: + attrs_in_module = [] + + interpreted_constructor_args = [] + + for arg_name, arg_type in annotations_constructor_for_type.items(): + arg_type = arg_type.replace("| None", "").strip() + + if arg_type.islower(): # built-in types (str, float, ...) + continue + + # this removes subscript on things like typing.Callable[...] + # we want typing.Callable only. + arg_type_no_sub = arg_type.split("[", maxsplit=1)[0] + + full_type = None + + for module, element in import_clauses_in_module: + if element == arg_type_no_sub: + full_type = f"{module}.{element}" + break + + if full_type is None and arg_type in definitions_in_module: + full_type = ( + f"{get_root_module_from(target_type)}.{arg_type_no_sub}" + ) + + interpreted_constructor_args.append((arg_name, full_type, arg_type)) + + attrs_in_module_rebuilt = [] + + for attr, assignment, annotation in attrs_in_module: + if annotation is None: + for x, y, z in interpreted_constructor_args: + if x in assignment: + annotation = ".".join(y.split(".")[:-1]) + "." + z + break + attrs_in_module_rebuilt.append((attr, assignment, annotation)) + + for attr, assignment, annotation in attrs_in_module_rebuilt: + if annotation is None: + continue + + annotation_light = annotation.replace("| None", "").strip() + + # remove subscr / args in annotation in order to import it + if "[" in annotation_light: + annotation_light = annotation_light[ + : annotation_light.index("[") + ] + + if annotation_light.islower(): # built-in types (str, float, ...) + continue + + loaded_type = load_annotation(annotation_light) + + if ( + loaded_type in require_async_class + or loaded_type == niquests.Session + ): + if target_type not in require_async_class: + ref_count += 1 + require_async_class.append(target_type) + if target_type.__subclasses__(): + for child_target in target_type.__subclasses__(): + if child_target not in require_async_class: + require_async_class.append(child_target) + if target_type not in require_async_attribute: + require_async_attribute[target_type] = [] + + if (attr, loaded_type) not in require_async_attribute[ + target_type + ]: + require_async_attribute[target_type].append( + (attr, loaded_type) + ) + elif loaded_type is typing.Callable: + # don't mark callable without args or return as "convertible to async" + # we assume they don't block I/O + if "[" in annotation: + if target_type not in require_async_method: + require_async_method[target_type] = [] + if target_type not in require_async_callable_annotations: + require_async_callable_annotations[target_type] = [] + + if ( + annotation + not in require_async_callable_annotations[target_type] + ): + require_async_callable_annotations[target_type].append( + annotation + ) + + if attr not in require_async_method[target_type]: + require_async_method[target_type].append(attr) + + iter_count += 1 + + if ref_count == 0: + logger.debug( + f"Exiting in-depth analysis at iteration n°{iter_count + 1}" + ) + break + + logger.info("Starting simple base/child inheritance logic") + + inherent_convert_async = [] + + # if a "base" class need to be converted to async + # then, we automatically grab every child for convertion. + for tbc_async in require_async_class: + for tbc_async_subclass in tbc_async.__subclasses__(): + if ( + tbc_async_subclass not in require_async_class + and tbc_async_subclass not in inherent_convert_async + ): + logger.debug( + f"Reference {tbc_async_subclass} is added because {tbc_async} requires async convertion (as parent)" + ) + inherent_convert_async.append(tbc_async_subclass) + + require_async_class.extend(inherent_convert_async) + + logger.info("Starting methods inspection in references") + + method_only_convert_async = [] + + # we now need to scan which methods need to be converted to async + while True: + # this var is needed to know when to stop the loop + # basically, ref_count means how many finding (e.g. to async required) + # has been found. + # we iterate because previous loop contain new reference that can be + # referenced later by other methods/class. + ref_count = 0 + + for tbc_async in require_async_class + indirect_types: + for potential_method_name in dir(tbc_async): + if potential_method_name.startswith("__"): + continue + + potential_method = getattr(tbc_async, potential_method_name) + + if callable(potential_method): + if "self" not in inspect.signature(potential_method).parameters: + continue + + any_callable_in_parameters = False + + for p, t in inspect.signature( + potential_method + ).parameters.items(): + if ( + isinstance(t.annotation, str) + and "Callable" in t.annotation + ): + any_callable_in_parameters = True + if tbc_async not in require_async_callable: + require_async_callable[tbc_async] = [] + if tbc_async not in require_async_callable_annotations: + require_async_callable_annotations[tbc_async] = [] + require_async_callable[tbc_async].append(p) + require_async_callable_annotations[tbc_async].append( + t.annotation + ) + + raw_method_source = inspect.getsource(potential_method) + + try: + async_attrs = require_async_attribute[tbc_async] + except KeyError: + try: + async_attrs = require_async_attribute[ + tbc_async.__base__ + ] + except KeyError: + async_attrs = [] + + for attr_name, attr_type in async_attrs: + if re.findall( + rf"self\.{attr_name}\.(.*)\(", raw_method_source + ): + if tbc_async not in require_async_method: + require_async_method[tbc_async] = [] + + if ( + potential_method_name + not in require_async_method[tbc_async] + ): + require_async_method[tbc_async].append( + potential_method_name + ) + ref_count += 1 + logger.debug( + f"Method '{potential_method_name}' in reference {tbc_async} need to be converted to async. Reason: usage of attribute '{attr_name}' bound to '{attr_type}'" + ) + break + + for _ in require_async_method: + for require_async_method_future in require_async_method[_]: + if ( + f".{require_async_method_future}(" + in raw_method_source + ): + if ( + tbc_async in require_async_method + and potential_method_name + not in require_async_method[tbc_async] + ): + require_async_method[tbc_async].append( + potential_method_name + ) + ref_count += 1 + logger.debug( + f"Method '{potential_method_name}' in reference {tbc_async} need to be converted to async. Reason: usage of method {require_async_method_future}" + ) + break + + if any_callable_in_parameters: + if tbc_async not in require_async_method: + require_async_method[tbc_async] = [] + + if ( + potential_method_name + not in require_async_method[tbc_async] + ): + require_async_method[tbc_async].append( + potential_method_name + ) + ref_count += 1 + logger.debug( + f"Method '{potential_method_name}' in reference {tbc_async} need to be converted to async. Reason: usage of callback" + ) + if ( + tbc_async not in require_async_class + and tbc_async not in method_only_convert_async + ): + method_only_convert_async.append(tbc_async) + + if "sleep(" in raw_method_source: + if tbc_async not in require_async_method: + require_async_method[tbc_async] = [] + + if ( + potential_method_name + not in require_async_method[tbc_async] + ): + require_async_method[tbc_async].append( + potential_method_name + ) + ref_count += 1 + logger.debug( + f"Method '{potential_method_name}' in reference {tbc_async} need to be converted to async. Reason: usage of function sleep" + ) + if ( + tbc_async not in require_async_class + and tbc_async not in method_only_convert_async + ): + method_only_convert_async.append(tbc_async) + + if ref_count == 0: + break + + # some class need to be async due to a method + # using I/O blocking functions like sleep. + # it is detected later in the process because + # the constructor wasn't specified or none of the attrs + # were async "compatible". + require_async_class.extend(method_only_convert_async) + + logger.info("Analysis ended.") + + # output primarily analysis results, mapping. + for tbc_async in require_async_class: + logger.info(f"Reference '{tbc_async}' need async counterpart") + + for tbc_async in require_async_attribute: + logger.info( + f"Reference '{tbc_async}' attributes: {', '.join(str(e) for e in require_async_attribute[tbc_async])}" + ) + + for tbc_async in require_async_method: + logger.info( + f"Reference '{tbc_async}' methods: {', '.join(require_async_method[tbc_async])}" + ) + + file_to_copy = [] + tbd_async_module = [] + + # determine what file need to be duplicated + for tbc_async in require_async_method: + file = get_owned_by_module_path(tbc_async) + + if file not in file_to_copy: + file_to_copy.append(file) + tbd_async_module.append(get_root_module_from(tbc_async)) + + logger.info(f"Async module required: {tbd_async_module}") + logger.info( + f"The following files(modules) need to be duplicated: {file_to_copy}" + ) + + need_remap_submodule_level_import = [] + + for pkg_submodule in package_submodules: + if f"prawcore.{pkg_submodule}" not in tbd_async_module: + need_remap_submodule_level_import.append(pkg_submodule) + + logger.info( + f"The following submodules are untouched: {need_remap_submodule_level_import}" + ) + + # we'll start to duplicate and patch the required files + # with acquired knowledge about the package structure + # and internal relationships. + with TemporaryDirectory(prefix="prawcore_async", delete=False) as tmp_dir: + logger.info(f"Provisioning temporary directory at '{tmp_dir}'") + + logger.info( + f"Copying project configuration (pyproject.toml) to '{tmp_dir}'" + ) + + # we need this, because we want our project configuration + # to be applied as is. not the default one upon post-fmt + # (e.g. tools like black, ruff, etc...) + tmp_module_path = os.path.join(tmp_dir, "pyproject.toml") + shutil.copy( + os.path.join(Path(file_to_copy[0]).parent.parent, "pyproject.toml"), + tmp_module_path, + ) + + package_tmp_rootdir = os.path.join(tmp_dir, "prawcore") + + os.mkdir(package_tmp_rootdir) + + shutil.copy( + os.path.join(Path(file_to_copy[0]).parent, "__init__.py"), + package_tmp_rootdir, + ) + + subpackage_tmp_rootdir = os.path.join(package_tmp_rootdir, "_async") + + os.mkdir(subpackage_tmp_rootdir) + + with open(os.path.join(subpackage_tmp_rootdir, "__init__.py"), "wb") as fp: + pass + + for file_path in file_to_copy: + logger.info(f"Copy and patch module '{file_path}'") + + tmp_module_path = os.path.join( + subpackage_tmp_rootdir, Path(file_path).name + ) + shutil.copy(file_path, tmp_module_path) + + with open(tmp_module_path, "r", encoding="utf-8") as fp: + tmp_src_raw = fp.read() + + # we don't want absolute import as it is messing + # with following "relative" import adjustments. + if "from prawcore." in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace("from prawcore.", "from .") + + # if we don't duplicate submodule 'X' then we'll need to import + # it from a level higher as generated code will lie in a subpackage. + for untouched_submodule in need_remap_submodule_level_import: + clause_a = f"from .{untouched_submodule} import" + patch_a = clause_a.replace(".", "..") + + clause_b = f"from . import {untouched_submodule}" + patch_b = clause_b.replace(".", "..") + + if clause_a in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace(clause_a, patch_a) + + if clause_b in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace(clause_b, patch_b) + + # usually it's about importing __version__ or alike root module attribute + if "from . import __" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + "from . import __", "from .. import __" + ) + + for tbd_async in require_async_class: + # class rename + if f"class {tbd_async.__name__}" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f"class {tbd_async.__name__}", + f"class Async{tbd_async.__name__}", + ) + + # typing rename + if f" {tbd_async.__name__}" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f" {tbd_async.__name__}", f" Async{tbd_async.__name__}" + ) + + # class inheritance rename + if f"({tbd_async.__name__}):" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f"({tbd_async.__name__}):", f"(Async{tbd_async.__name__}):" + ) + + # special edge case rename + if f"({tbd_async.__name__}" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f"({tbd_async.__name__}", f"(Async{tbd_async.__name__}" + ) + + # typing subscription arg rename + if f"[{tbd_async.__name__}" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f"[{tbd_async.__name__}", f"[Async{tbd_async.__name__}" + ) + + # do the same as above for kept "Callable" annotations + # to ensure we can properly patch it (for Awaitable) + if tbd_async in require_async_callable_annotations: + require_async_callable_annotations_patched = [] + + for callable_annot in require_async_callable_annotations[ + tbd_async + ]: + require_async_callable_annotations_patched.append( + callable_annot.replace( + f"[{tbd_async.__name__}", + f"[Async{tbd_async.__name__}", + ) + ) + + require_async_callable_annotations[tbd_async] = ( + require_async_callable_annotations_patched + ) + + # docstr replace + if f":class:`.{tbd_async.__name__}`" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + f":class:`.{tbd_async.__name__}`", + f":class:`.Async{tbd_async.__name__}`", + ) + + # automatically convert contextmgr magic methods to async counterpart + if "def __enter__(" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + "def __enter__(", "async def __aenter__(" + ) + tmp_src_raw = tmp_src_raw.replace( + "def __exit__(", "async def __aexit__(" + ) + + # time.sleep is blocking, convert it to asyncio and make it await(ed) + if "time.sleep(" in tmp_src_raw: + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"time\.sleep\(", + insert="await", + before="time.", + ) + tmp_src_raw = tmp_src_raw.replace("time.sleep", "asyncio.sleep") + + # obviously niquests.Session + if "niquests.Session(" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + "niquests.Session(", "niquests.AsyncSession(" + ) + if " Session(" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace(" Session(", " AsyncSession(") + if ": niquests.Session" in tmp_src_raw: + tmp_src_raw = tmp_src_raw.replace( + ": niquests.Session", ": niquests.AsyncSession" + ) + + # we assume every method "sleep" is blocking I/O + # and already made awaitable. + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"(.*)\.sleep\(\)", + insert="await", + ) + + # we should add asyncio import if any function invoked in it. + if ( + "asyncio." in tmp_src_raw + and "import asyncio" not in tmp_src_raw + ): + first_import_idx = tmp_src_raw.index("\nimport") + tmp_src_raw = ( + tmp_src_raw[0:first_import_idx] + + "import asyncio\n" + + tmp_src_raw[first_import_idx:] + ) + + if tbd_async in require_async_attribute: + if tbd_async in require_async_attribute: + for attr_name, attr_type in require_async_attribute[ + tbd_async + ]: + # not every method from given (eligible async) attr can be awaited. + # act with caution. We'll probe our internal map "class -> awaitable methods" + # to know if await is required. + def _inner_check_method_awaitable( + detect_match: list[str], + ) -> bool: + _inner_method_call = detect_match[0] + if attr_type not in require_async_method: + return False + if ( + _inner_method_call + in require_async_method[attr_type] + ): + return True + return False + + tmp_src_raw = preg_replace( + tmp_src_raw, + rf" self.{attr_name}\.(.*)\((.*)\)$", + insert="await", + before="self.", + callback_ignore_if=_inner_check_method_awaitable, + ) + + if tbd_async in require_async_callable: + for callable_name in require_async_callable[tbd_async]: + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"{callable_name}\((?!self)", + insert="await", + before=callable_name, + ) + + if tbd_async in require_async_callable_annotations: + # every Callable annot need to be converted to returning Awaitable + # two fmt: Callable (e.g. without args) --> Callable[[], Awaitable[None]] + # or : Callable[[...], ...] need to become Callable[[.....], Awaitable[...]] + for callable_annot in require_async_callable_annotations[ + tbd_async + ]: + if callable_annot.endswith("]"): + callable_annot = callable_annot.replace("typing.", "") + tmp_src_raw = tmp_src_raw.replace( + callable_annot, + preg_replace( + callable_annot, + r"Callable\[\[(.*)\], (.*)\]", + insert="Awaitable[", + after=", ", + extraneous_space=False, + ) + + "]", + ) + + if ( + "Awaitable[" in tmp_src_raw + and "import Awaitable" not in tmp_src_raw + ): + first_import_idx = tmp_src_raw.index("\nimport") + tmp_src_raw = ( + tmp_src_raw[0:first_import_idx] + + "\nfrom typing import Awaitable\n" + + tmp_src_raw[first_import_idx:] + ) + + if tbd_async in require_async_method: + for method_name in require_async_method[tbd_async]: + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"def {method_name}\(", + insert="async", + before="def", + ) + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"self.(.*)\.{method_name}\(", + insert="await", + before="self.", + ) + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"self\.{method_name}\(", + insert="await", + before="self.", + ) + tmp_src_raw = preg_replace( + tmp_src_raw, + rf"super\(\)\.{method_name}\(", + insert="await", + before="super(", + ) + + # finally write the resulting patched source to tmp module file + with open(tmp_module_path, "w", encoding="utf-8") as fp: + fp.write(tmp_src_raw) + + # We need to reformat / lint the output + # So that the git diff will actually be + # useful. + logger.info("Run black fmt over newly generated source") + + try: + reformatted_modules = black_reformat(tmp_dir) + except IOError as e: + logger.error(str(e)) + shutil.rmtree(tmp_dir) + exit(2) + except CalledProcessError as e: + logger.critical(str(e)) + shutil.rmtree(tmp_dir) + exit(3) + + if not reformatted_modules: + logger.info("Black report nothing to be reformatted") + else: + for refmt in reformatted_modules: + logger.info(refmt) + + logger.info("Run Ruff fmt/lint over newly generated source") + + try: + reformatted_modules = ruff_reformat(tmp_dir) + except IOError as e: + logger.error(str(e)) + shutil.rmtree(tmp_dir) + exit(2) + except CalledProcessError as e: + logger.critical(str(e)) + shutil.rmtree(tmp_dir) + exit(3) + + if not reformatted_modules: + logger.info("Ruff report nothing to be reformatted") + else: + for refmt in reformatted_modules: + logger.info(refmt) + + logger.info("Run docstrfmt fmt over newly generated source") + + try: + reformatted_modules = docstrfmt_reformat(tmp_dir) + except IOError as e: + logger.error(str(e)) + shutil.rmtree(tmp_dir) + exit(2) + except CalledProcessError as e: + logger.critical(str(e)) + shutil.rmtree(tmp_dir) + exit(3) + + if not reformatted_modules: + logger.info("docstrfmt report nothing to be reformatted") + else: + for refmt in reformatted_modules: + logger.info(refmt) + + any_diff_async_async: bool = False + + # now we check the diff between (sync) main and new generation + for file_path in file_to_copy: + logger.info(f"Scan difference with '{file_path}'") + + root_package_dir = Path(file_path).parent + expected_package_async_loc = os.path.join( + root_package_dir, f"_async/{Path(file_path).name}" + ) + + logger.info(f"Expected destination '{expected_package_async_loc}'") + + tmp_module_path = os.path.join(subpackage_tmp_rootdir, Path(file_path).name) + + if should_display_diff_sync_async: + try: + patch = git_diff(file_path, tmp_module_path) + except CalledProcessError as e: + logger.critical(str(e)) + shutil.rmtree(tmp_dir) + exit(3) + + for line in patch: + logger.info(line) + + if os.path.exists(expected_package_async_loc): + try: + patch = git_diff(expected_package_async_loc, tmp_module_path) + except CalledProcessError as e: + logger.critical(str(e)) + shutil.rmtree(tmp_dir) + exit(3) + + if should_display_diff_async_async: + for line in patch: + logger.info(line) + + if len(patch) > 1: + logger.warning( + f"Async source for '{expected_package_async_loc}' is outdated" + ) + any_diff_async_async = True + + if should_write_into_package: + if not os.path.exists(str(root_package_dir) + "/_async"): + os.mkdir(str(root_package_dir) + "/_async") + shutil.copy(tmp_module_path, expected_package_async_loc) + else: + any_diff_async_async = True + logger.warning( + f"non-existent async module '{expected_package_async_loc}'. need git add on it." + ) + + if any_diff_async_async: + if should_write_into_package is False: + logger.warning( + "Async code is outdated. Run the script with --fix to update the async parts" + ) + else: + logger.info("Async code successfully updated to latest (sync) changes") + shutil.rmtree(tmp_dir) + exit(1) + else: + logger.info("Async code is already up-to-date!") + + shutil.rmtree(tmp_dir) + except Exception as e: + logger.critical(str(e)) + exit(4)