Skip to content

Commit

Permalink
Merge pull request #175 from mdeweerd/dev
Browse files Browse the repository at this point in the history
More updates for retry functionnality since zigpy>=0.56.0, HA2023.7.0
  • Loading branch information
mdeweerd authored Jul 10, 2023
2 parents 602363d + 1f75710 commit 1d51c4b
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 56 deletions.
2 changes: 2 additions & 0 deletions STATS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Badges showing number of downloads per version

- ![badge latest](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/latest/total.svg)
- ![badge v0.9.1](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/v0.9.1/total.svg)
- ![badge v0.8.40](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/v0.8.40/total.svg)
- ![badge v0.8.39](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/v0.8.39/total.svg)
- ![badge v0.8.38](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/v0.8.38/total.svg)
- ![badge v0.8.37](https://img.shields.io/github/downloads/mdeweerd/zha-toolkit/v0.8.37/total.svg)
Expand Down
10 changes: 6 additions & 4 deletions custom_components/zha_toolkit/binds.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ async def binds_remove_all(
if (tgt_ieee is None or dst_ieee == tgt_ieee) and (
len(clusters) == 0 or cluster_id in clusters
):
res = await zdo.request(
res = await u.retry_wrapper(
zdo.request,
ZDOCmd.Unbind_req,
src_ieee,
binding["src_ep"],
Expand All @@ -482,7 +483,8 @@ async def binds_remove_all(
if (tgt_ieee is None or dst_ieee == tgt_ieee) and (
len(clusters) == 0 or cluster_id in clusters
):
res = await zdo.request(
res = await u.retry_wrapper(
zdo.request,
ZDOCmd.Unbind_req,
src_ieee,
binding["src_ep"],
Expand Down Expand Up @@ -535,8 +537,8 @@ async def binds_get(

while not done:
# Todo: continue when reply is incomplete (update start index)
reply = await zdo.request(
ZDOCmd.Mgmt_Bind_req, idx, tries=params[p.TRIES]
reply = await u.retry_wrapper(
zdo.request, ZDOCmd.Mgmt_Bind_req, idx, tries=params[p.TRIES]
)
event_data["replies"].append(reply)

Expand Down
19 changes: 12 additions & 7 deletions custom_components/zha_toolkit/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ async def get_groups(
continue
try:
ep_info: dict[str, Any] = {}
res = await ep.groups.read_attributes(
["name_support"] # , tries=params[p.TRIES]
res = await u.retry_wrapper(
ep.groups.read_attributes,
["name_support"],
tries=params[p.TRIES],
)
event_data["result"].append(res)

Expand All @@ -41,9 +43,9 @@ async def get_groups(
name_support,
)

all_groups = await ep.groups.get_membership(
[]
) # , tries=params[p.TRIES]
all_groups = await u.retry_wrapper(
ep.groups.get_membership, [], tries=params[p.TRIES]
)
LOGGER.debug(
"Groups on 0x%04X EP %u : %s", src_dev.nwk, ep_id, all_groups
)
Expand Down Expand Up @@ -74,8 +76,11 @@ async def add_group(
# Skip ZDO or endpoints that are not selected
continue
try:
res = await ep.groups.add(
group_id, f"group {group_id}" # , tries=params[p.TRIES]
res = await u.retry_wrapper(
ep.groups.add,
group_id,
f"group {group_id}",
tries=params[p.TRIES],
)
result.append(res)
LOGGER.debug(
Expand Down
17 changes: 1 addition & 16 deletions custom_components/zha_toolkit/ota.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import json
import logging
import os
Expand All @@ -7,7 +6,6 @@
import aiohttp
from pkg_resources import parse_version
from zigpy import __version__ as zigpy_version
from zigpy.exceptions import ControllerException, DeliveryError

from . import DEFAULT_OTAU
from . import utils as u
Expand All @@ -21,19 +19,6 @@
SONOFF_LIST_URL = "https://zigbee-ota.sonoff.tech/releases/upgrade.json"


@u.retryable(
(
DeliveryError,
ControllerException,
asyncio.CancelledError,
asyncio.TimeoutError,
),
tries=3,
)
async def retry_wrapper(cmd, *args, **kwargs):
return await cmd(*args, **kwargs)


async def download_koenkk_ota(listener, ota_dir):
# Get all FW files that were already downloaded.
# The files usually have the FW version in their name, making them unique.
Expand Down Expand Up @@ -232,7 +217,7 @@ async def ota_notify(
ret = await cluster.image_notify(0, 100)
else:
cmd_args = [0, 100]
ret = await retry_wrapper(
ret = await u.retry_wrapper(
cluster.client_command,
0, # cmd_id
*cmd_args,
Expand Down
13 changes: 3 additions & 10 deletions custom_components/zha_toolkit/scan_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ async def read_attr(cluster, attrs, manufacturer=None):
)


@u.retryable(
(DeliveryError, asyncio.CancelledError, asyncio.TimeoutError), tries=3
)
async def wrapper(cmd, *args, **kwargs):
return await cmd(*args, **kwargs)


async def scan_results(device, endpoints=None, manufacturer=None, tries=3):
"""Construct scan results from information available in device"""
result: dict[str, str | list | None] = {
Expand Down Expand Up @@ -180,7 +173,7 @@ async def discover_attributes_extended(cluster, manufacturer=None, tries=3):

while not done: # Repeat until all attributes are discovered or timeout
try:
done, rsp = await wrapper(
done, rsp = await u.retry_wrapper(
cluster.discover_attributes_extended,
attr_id, # Start attribute identifier
16, # Number of attributes to discover in this request
Expand Down Expand Up @@ -303,7 +296,7 @@ async def discover_commands_received(

while not done:
try:
done, rsp = await wrapper(
done, rsp = await u.retry_wrapper(
cluster.discover_commands_received,
cmd_id, # Start index of commands to discover
16, # Number of commands to discover
Expand Down Expand Up @@ -367,7 +360,7 @@ async def discover_commands_generated(

while not done:
try:
done, rsp = await wrapper(
done, rsp = await u.retry_wrapper(
cluster.discover_commands_generated,
cmd_id, # Start index of commands to discover
16, # Number of commands to discover this run
Expand Down
37 changes: 34 additions & 3 deletions custom_components/zha_toolkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import re
import typing
from enum import Enum
from typing import Any

import packaging
import packaging.version
Expand Down Expand Up @@ -842,27 +841,59 @@ def extractParams( # noqa: C901
#
async def retry(
func: typing.Callable[[], typing.Awaitable[typing.Any]],
retry_exceptions: typing.Iterable[Any], # typing.Iterable[BaseException],
retry_exceptions: typing.Iterable[typing.Any]
| None = None, # typing.Iterable[BaseException],
tries: int = 3,
delay: int | float = 0.1,
) -> typing.Any:
"""Retry a function in case of exception
Only exceptions in `retry_exceptions` will be retried.
"""
if retry_exceptions is None:
# Default list
retry_exceptions = (
(
DeliveryError,
ControllerException,
asyncio.CancelledError,
asyncio.TimeoutError,
),
)

while True:
LOGGER.debug("Tries remaining: %s", tries)
try:
return await func()
# pylint: disable-next=catching-non-exception
except retry_exceptions: # type:ignore[misc]
if tries <= 1:
raise
tries -= 1
await asyncio.sleep(delay)


async def retry_wrapper(
func: typing.Callable,
*args,
retry_exceptions: typing.Iterable[typing.Any]
| None = None, # typing.Iterable[BaseException],
tries: int = 3,
delay: int | float = 0.1,
**kwargs,
) -> typing.Any:
"""Inline callable wrapper for retry"""
return await retry(
functools.partial(func, *args, **kwargs),
retry_exceptions,
tries=tries,
delay=delay,
)


def retryable(
retry_exceptions: typing.Iterable[Any], # typing.Iterable[BaseException]
retry_exceptions: None
| typing.Iterable[typing.Any] = None, # typing.Iterable[BaseException]
tries: int = 1,
delay: float = 0.1,
) -> typing.Callable:
Expand Down
24 changes: 8 additions & 16 deletions custom_components/zha_toolkit/zcl_cmd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import logging
from typing import Any

Expand Down Expand Up @@ -115,21 +114,14 @@ async def zcl_cmd(app, listener, ieee, cmd, data, service, params, event_data):
org_cluster_cmd_defs[cmd_id] = None
cluster.server_commands[cmd_id] = cmd_def

if "tries" in inspect.getfullargspec(cluster.command)[0]:
event_data["cmd_reply"] = await cluster.command(
cmd_id,
*cmd_args,
manufacturer=manf,
expect_reply=expect_reply,
tries=tries,
)
else:
event_data["cmd_reply"] = await cluster.command(
cmd_id,
*cmd_args,
manufacturer=manf,
expect_reply=expect_reply,
)
event_data["cmd_reply"] = await u.retry_wrapper(
cluster.command,
cmd_id,
*cmd_args,
manufacturer=manf,
expect_reply=expect_reply,
tries=tries,
)
else:
if cluster_id not in endpoint.out_clusters:
msg = ERR005_NOT_OUT_CLUSTER.format(
Expand Down

0 comments on commit 1d51c4b

Please sign in to comment.