Skip to content

Commit

Permalink
Add possibility to inject httpx session
Browse files Browse the repository at this point in the history
  • Loading branch information
hahn-th committed Dec 22, 2024
1 parent c649bfb commit c757705
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions src/homematicip/connection/rest_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ class RestConnection:
_headers: dict[str, str] = None
_verify = None
_log_status_exceptions = True
_httpx_client_session: httpx.AsyncClient = None

def __init__(self, context: ConnectionContext, log_status_exceptions: bool = True):
def __init__(self, context: ConnectionContext, httpx_client_session: httpx.AsyncClient = None, log_status_exceptions: bool = True):
"""Initialize the RestConnection object.
@param context: The connection context
@param httpx_client_session: The httpx client session if you want to use a custom one
@param log_status_exceptions: If status exceptions should be logged
"""
LOGGER.debug("Initialize new RestConnection")
self.update_connection_context(context)
self._log_status_exceptions = log_status_exceptions
self._httpx_client_session = httpx_client_session

def update_connection_context(self, context: ConnectionContext) -> None:
self._context = context
Expand All @@ -67,6 +70,12 @@ def get_header(self) -> dict[str, str]:
"""If headers must be manipulated use this method to get the current headers."""
return self._headers

def _get_client_session(self) -> httpx.AsyncClient:
if self._httpx_client_session is not None:
return self._httpx_client_session

return httpx.AsyncClient(verify=self._verify)

async def async_post(self, url: str, data: json = None, custom_header: dict = None) -> RestResult:
"""Send an async post request to cloud with json data. Returns a json result.
@param url: The path of the url to send the request to
Expand All @@ -75,39 +84,39 @@ async def async_post(self, url: str, data: json = None, custom_header: dict = No
@return: The result as a RestResult object
"""
full_url = self._build_url(self._context.rest_url, url)
async with httpx.AsyncClient(verify=self._verify) as client:
client = self._get_client_session()
try:
header = self._headers
if custom_header is not None:
header = custom_header

LOGGER.debug(f"Sending post request to url {full_url}. Data is: {data}")
r = await client.post(full_url, json=data, headers=header)
LOGGER.debug(f"Got response {r.status_code}.")

if r.status_code == THROTTLE_STATUS_CODE:
LOGGER.error("Got error 429 (Throttling active)")
raise HmipThrottlingError

r.raise_for_status()

result = RestResult(status=r.status_code)
try:
header = self._headers
if custom_header is not None:
header = custom_header

LOGGER.debug(f"Sending post request to url {full_url}. Data is: {data}")
r = await client.post(full_url, json=data, headers=header)
LOGGER.debug(f"Got response {r.status_code}.")

if r.status_code == THROTTLE_STATUS_CODE:
LOGGER.error("Got error 429 (Throttling active)")
raise HmipThrottlingError

r.raise_for_status()

result = RestResult(status=r.status_code)
try:
result.json = r.json()
except json.JSONDecodeError:
pass

return result
except httpx.RequestError as exc:
LOGGER.error(f"An error occurred while requesting {exc.request.url!r}.")
return RestResult(status=-1, exception=exc)
except httpx.HTTPStatusError as exc:
if self._log_status_exceptions:
LOGGER.error(
f"Error response {exc.response.status_code} while requesting {exc.request.url!r} with data {data if data is not None else "<no-data>"}."
)
LOGGER.error(f"Response: {repr(exc.response)}")
return RestResult(status=exc.response.status_code, exception=exc, text=exc.response.text)
result.json = r.json()
except json.JSONDecodeError:
pass

return result
except httpx.RequestError as exc:
LOGGER.error(f"An error occurred while requesting {exc.request.url!r}.")
return RestResult(status=-1, exception=exc)
except httpx.HTTPStatusError as exc:
if self._log_status_exceptions:
LOGGER.error(
f"Error response {exc.response.status_code} while requesting {exc.request.url!r} with data {data if data is not None else "<no-data>"}."
)
LOGGER.error(f"Response: {repr(exc.response)}")
return RestResult(status=exc.response.status_code, exception=exc, text=exc.response.text)

@staticmethod
def _build_url(base_url: str, path: str) -> str:
Expand Down

0 comments on commit c757705

Please sign in to comment.