-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RateLimitedRestConnection with token bucket
Add AbstanceType Enum
- Loading branch information
Showing
8 changed files
with
191 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import asyncio | ||
import time | ||
|
||
|
||
class Buckets: | ||
"""Class to manage the rate limiting of the HomematicIP Cloud API. | ||
The implementation is based on the token bucket algorithm.""" | ||
|
||
def __init__(self, tokens, fill_rate): | ||
"""Initialize the Buckets with a token bucket algorithm. | ||
:param tokens: The number of tokens in the bucket. | ||
:param fill_rate: The fill rate of the bucket in tokens every x seconds.""" | ||
self.capacity = tokens | ||
self._tokens = tokens | ||
self.fill_rate = fill_rate | ||
self.timestamp = time.time() | ||
self.lock = asyncio.Lock() | ||
|
||
async def take(self, tokens=1) -> bool: | ||
"""Get a single token from the bucket. Return True if successful, False otherwise. | ||
:param tokens: The number of tokens to take from the bucket. Default is 1. | ||
:return: True if successful, False otherwise. | ||
""" | ||
async with self.lock: | ||
if tokens <= await self.tokens(): | ||
self._tokens -= tokens | ||
return True | ||
return False | ||
|
||
async def wait_and_take(self, timeout=120, tokens=1) -> bool: | ||
"""Wait until a token is available and then take it. Return True if successful, False otherwise. | ||
:param timeout: The maximum time to wait for a token in seconds. Default is 120 seconds. | ||
:param tokens: The number of tokens to take from the bucket. Default is 1. | ||
:return: True if successful, False otherwise. | ||
""" | ||
start_time = time.time() | ||
while True: | ||
if tokens <= await self.tokens(): | ||
self._tokens -= tokens | ||
return True | ||
|
||
if time.time() - start_time > timeout: | ||
raise asyncio.TimeoutError("Timeout while waiting for token.") | ||
|
||
await asyncio.sleep(1) # Wait for a second before checking again | ||
|
||
async def tokens(self): | ||
"""Get the number of tokens in the bucket. Refill the bucket if necessary.""" | ||
if self._tokens < self.capacity: | ||
now = time.time() | ||
delta = int((now - self.timestamp) / self.fill_rate) | ||
if delta > 0: | ||
self._tokens = min(self.capacity, self._tokens + delta) | ||
self.timestamp = now | ||
return self._tokens |
21 changes: 21 additions & 0 deletions
21
src/homematicip/connection/rate_limited_rest_connection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import json | ||
|
||
from homematicip.connection.buckets import Buckets | ||
from homematicip.connection.rest_connection import RestConnection, ConnectionContext, RestResult | ||
|
||
|
||
class RateLimitedRestConnection(RestConnection): | ||
|
||
def __init__(self, context: ConnectionContext, tokens: int = 10, fill_rate: int = 8): | ||
"""Initialize the RateLimitedRestConnection with a token bucket algorithm. | ||
:param context: The connection context. | ||
:param tokens: The number of tokens in the bucket. Default is 10. | ||
:param fill_rate: The fill rate of the bucket in tokens per second. Default is 8.""" | ||
super().__init__(context) | ||
self._buckets = Buckets(tokens=tokens, fill_rate=fill_rate) | ||
|
||
async def async_post(self, url: str, data: json = None, custom_header: dict = None) -> RestResult: | ||
"""Post data to the HomematicIP Cloud API.""" | ||
await self._buckets.wait_and_take() | ||
return await super().async_post(url, data, custom_header) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import asyncio | ||
|
||
import pytest | ||
|
||
from homematicip.connection.buckets import Buckets | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_bucket(): | ||
"""Testing the get bucket method.""" | ||
bucket = Buckets(2, 10) | ||
|
||
got_1st_token = await bucket.take() | ||
got_2nd_token = await bucket.take() | ||
got_3rd_token = await bucket.take() | ||
|
||
assert got_1st_token is True | ||
assert got_2nd_token is True | ||
assert got_3rd_token is False | ||
|
||
|
||
async def test_get_bucket_with_timeout(): | ||
"""Testing the get bucket method with timeout.""" | ||
bucket = Buckets(1, 100) | ||
|
||
got_1st_token = await bucket.take() | ||
with pytest.raises(asyncio.TimeoutError): | ||
await bucket.wait_and_take(timeout=1) | ||
|
||
|
||
async def test_get_bucket_and_wait_for_new(): | ||
"""Testing the get bucket method and waiting for new tokens.""" | ||
bucket = Buckets(1, 1) | ||
|
||
got_1st_token = await bucket.take() | ||
got_2nd_token = await bucket.wait_and_take() | ||
|
||
assert got_1st_token is True | ||
assert got_2nd_token is True | ||
|
||
def test_initial_tokens(): | ||
"""Testing the initial tokens of the bucket.""" | ||
bucket = Buckets(2, 10) | ||
assert bucket._tokens == 2 | ||
assert bucket.capacity == 2 | ||
assert bucket.fill_rate == 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import httpx | ||
|
||
from homematicip.connection.rate_limited_rest_connection import RateLimitedRestConnection | ||
from homematicip.connection.rest_connection import ConnectionContext | ||
|
||
|
||
async def test_send_single_request(mocker): | ||
response = mocker.Mock(spec=httpx.Response) | ||
response.status_code = 200 | ||
patched = mocker.patch("homematicip.connection.rest_connection.httpx.AsyncClient.post") | ||
patched.return_value = response | ||
|
||
context = ConnectionContext(rest_url="http://asdf") | ||
conn = RateLimitedRestConnection(context, 1, 1) | ||
|
||
result = await conn.async_post("url", {"a": "b"}, {"c": "d"}) | ||
|
||
assert patched.called | ||
assert patched.call_args[0][0] == "http://asdf/hmip/url" | ||
assert patched.call_args[1] == {"json": {"a": "b"}, "headers": {"c": "d"}} | ||
assert result.status == 200 | ||
|
||
|
||
async def test_send_and_wait_requests(mocker): | ||
response = mocker.Mock(spec=httpx.Response) | ||
response.status_code = 200 | ||
patched = mocker.patch("homematicip.connection.rest_connection.httpx.AsyncClient.post") | ||
patched.return_value = response | ||
|
||
context = ConnectionContext(rest_url="http://asdf") | ||
conn = RateLimitedRestConnection(context, 1, 1) | ||
|
||
await conn.async_post("url", {"a": "b"}, {"c": "d"}) | ||
await conn.async_post("url", {"a": "b"}, {"c": "d"}) | ||
await conn.async_post("url", {"a": "b"}, {"c": "d"}) | ||
|
||
assert patched.call_count == 3 |