From c7577050f287db876169dd37cebb57d509026f97 Mon Sep 17 00:00:00 2001 From: Thomas Hahn Date: Sun, 22 Dec 2024 18:32:04 +0100 Subject: [PATCH] Add possibility to inject httpx session --- src/homematicip/connection/rest_connection.py | 75 +++++++++++-------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/src/homematicip/connection/rest_connection.py b/src/homematicip/connection/rest_connection.py index ff46f3ad..adced1bc 100644 --- a/src/homematicip/connection/rest_connection.py +++ b/src/homematicip/connection/rest_connection.py @@ -35,16 +35,19 @@ class RestConnection: _headers: dict[str, str] = None _verify = None _log_status_exceptions = True + _httpx_client_session: httpx.AsyncClient = None - def __init__(self, context: ConnectionContext, log_status_exceptions: bool = True): + def __init__(self, context: ConnectionContext, httpx_client_session: httpx.AsyncClient = None, log_status_exceptions: bool = True): """Initialize the RestConnection object. @param context: The connection context + @param httpx_client_session: The httpx client session if you want to use a custom one @param log_status_exceptions: If status exceptions should be logged """ LOGGER.debug("Initialize new RestConnection") self.update_connection_context(context) self._log_status_exceptions = log_status_exceptions + self._httpx_client_session = httpx_client_session def update_connection_context(self, context: ConnectionContext) -> None: self._context = context @@ -67,6 +70,12 @@ def get_header(self) -> dict[str, str]: """If headers must be manipulated use this method to get the current headers.""" return self._headers + def _get_client_session(self) -> httpx.AsyncClient: + if self._httpx_client_session is not None: + return self._httpx_client_session + + return httpx.AsyncClient(verify=self._verify) + async def async_post(self, url: str, data: json = None, custom_header: dict = None) -> RestResult: """Send an async post request to cloud with json data. Returns a json result. @param url: The path of the url to send the request to @@ -75,39 +84,39 @@ async def async_post(self, url: str, data: json = None, custom_header: dict = No @return: The result as a RestResult object """ full_url = self._build_url(self._context.rest_url, url) - async with httpx.AsyncClient(verify=self._verify) as client: + client = self._get_client_session() + try: + header = self._headers + if custom_header is not None: + header = custom_header + + LOGGER.debug(f"Sending post request to url {full_url}. Data is: {data}") + r = await client.post(full_url, json=data, headers=header) + LOGGER.debug(f"Got response {r.status_code}.") + + if r.status_code == THROTTLE_STATUS_CODE: + LOGGER.error("Got error 429 (Throttling active)") + raise HmipThrottlingError + + r.raise_for_status() + + result = RestResult(status=r.status_code) try: - header = self._headers - if custom_header is not None: - header = custom_header - - LOGGER.debug(f"Sending post request to url {full_url}. Data is: {data}") - r = await client.post(full_url, json=data, headers=header) - LOGGER.debug(f"Got response {r.status_code}.") - - if r.status_code == THROTTLE_STATUS_CODE: - LOGGER.error("Got error 429 (Throttling active)") - raise HmipThrottlingError - - r.raise_for_status() - - result = RestResult(status=r.status_code) - try: - result.json = r.json() - except json.JSONDecodeError: - pass - - return result - except httpx.RequestError as exc: - LOGGER.error(f"An error occurred while requesting {exc.request.url!r}.") - return RestResult(status=-1, exception=exc) - except httpx.HTTPStatusError as exc: - if self._log_status_exceptions: - LOGGER.error( - f"Error response {exc.response.status_code} while requesting {exc.request.url!r} with data {data if data is not None else ""}." - ) - LOGGER.error(f"Response: {repr(exc.response)}") - return RestResult(status=exc.response.status_code, exception=exc, text=exc.response.text) + result.json = r.json() + except json.JSONDecodeError: + pass + + return result + except httpx.RequestError as exc: + LOGGER.error(f"An error occurred while requesting {exc.request.url!r}.") + return RestResult(status=-1, exception=exc) + except httpx.HTTPStatusError as exc: + if self._log_status_exceptions: + LOGGER.error( + f"Error response {exc.response.status_code} while requesting {exc.request.url!r} with data {data if data is not None else ""}." + ) + LOGGER.error(f"Response: {repr(exc.response)}") + return RestResult(status=exc.response.status_code, exception=exc, text=exc.response.text) @staticmethod def _build_url(base_url: str, path: str) -> str: