From f2056bcbeb925c90142788c3a60a357aeba3d473 Mon Sep 17 00:00:00 2001 From: Ali-Akber Saifee Date: Fri, 20 Dec 2024 15:52:19 -0800 Subject: [PATCH] Change definition of reset_time to floating point Truncating reset_time (or start of next window) by casting to an int was incorrectly providing the wrong reset_time as there was precision loss in the conversion. --- limits/aio/storage/base.py | 4 +-- limits/aio/storage/etcd.py | 6 ++-- limits/aio/storage/memcached.py | 4 +-- limits/aio/storage/memory.py | 10 +++--- limits/aio/storage/mongodb.py | 26 ++++++---------- limits/aio/storage/redis.py | 16 +++++----- .../redis/lua_scripts/moving_window.lua | 4 ++- limits/storage/base.py | 4 +-- limits/storage/etcd.py | 7 ++--- limits/storage/memcached.py | 4 +-- limits/storage/memory.py | 10 +++--- limits/storage/mongodb.py | 19 +++++------- limits/storage/redis.py | 13 ++++---- limits/storage/redis_sentinel.py | 2 +- limits/util.py | 2 +- tests/aio/test_strategy.py | 30 ++++++++++++------ tests/test_strategy.py | 31 +++++++++++++------ tests/utils.py | 4 +-- 18 files changed, 105 insertions(+), 91 deletions(-) diff --git a/limits/aio/storage/base.py b/limits/aio/storage/base.py index e1d39f11..774fcedf 100644 --- a/limits/aio/storage/base.py +++ b/limits/aio/storage/base.py @@ -103,7 +103,7 @@ async def get(self, key: str) -> int: raise NotImplementedError @abstractmethod - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ @@ -169,7 +169,7 @@ async def acquire_entry( @abstractmethod async def get_moving_window( self, key: str, limit: int, expiry: int - ) -> Tuple[int, int]: + ) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window diff --git a/limits/aio/storage/etcd.py b/limits/aio/storage/etcd.py index 0a745761..1348b12d 100644 --- a/limits/aio/storage/etcd.py +++ b/limits/aio/storage/etcd.py @@ -116,12 +116,12 @@ async def get(self, key: str) -> int: return int(amount) return 0 - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: cur = await self.storage.get(self.prefixed_key(key)) if cur: window_end = float(cur.value.split(b":")[1]) - return int(window_end) - return int(time.time()) + return window_end + return time.time() async def check(self) -> bool: try: diff --git a/limits/aio/storage/memcached.py b/limits/aio/storage/memcached.py index 3a6e3c7e..bdf13267 100644 --- a/limits/aio/storage/memcached.py +++ b/limits/aio/storage/memcached.py @@ -126,14 +126,14 @@ async def incr( return amount - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ storage = await self.get_storage() item = await storage.get(f"{key}/expires".encode()) - return int(item and float(item.value) or time.time()) + return item and float(item.value) or time.time() async def check(self) -> bool: """ diff --git a/limits/aio/storage/memory.py b/limits/aio/storage/memory.py index 6b465cda..7b575066 100644 --- a/limits/aio/storage/memory.py +++ b/limits/aio/storage/memory.py @@ -128,12 +128,12 @@ async def acquire_entry( return True - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ - return int(self.expirations.get(key, time.time())) + return self.expirations.get(key, time.time()) async def get_num_acquired(self, key: str, expiry: int) -> int: """ @@ -153,7 +153,7 @@ async def get_num_acquired(self, key: str, expiry: int) -> int: # FIXME: arg limit is not used async def get_moving_window( self, key: str, limit: int, expiry: int - ) -> Tuple[int, int]: + ) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -167,9 +167,9 @@ async def get_moving_window( for item in self.events.get(key, [])[::-1]: if item.atime >= timestamp - expiry: - return int(item.atime), acquired + return item.atime, acquired - return int(timestamp), acquired + return timestamp, acquired async def check(self) -> bool: """ diff --git a/limits/aio/storage/mongodb.py b/limits/aio/storage/mongodb.py index 6128e9ec..87773f2e 100644 --- a/limits/aio/storage/mongodb.py +++ b/limits/aio/storage/mongodb.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import calendar import datetime import time from typing import Any, cast @@ -135,21 +134,19 @@ async def clear(self, key: str) -> None: ), ) - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ counter = await self.database[self.__collection_mapping["counters"]].find_one( {"_id": key} ) - expiry = ( - counter["expireAt"] - if counter - else datetime.datetime.now(datetime.timezone.utc) + return ( + (counter["expireAt"] if counter else datetime.datetime.now()) + .replace(tzinfo=datetime.timezone.utc) + .timestamp() ) - return calendar.timegm(expiry.timetuple()) - async def get(self, key: str) -> int: """ :param key: the key to get the counter value for @@ -227,7 +224,7 @@ async def check(self) -> bool: async def get_moving_window( self, key: str, limit: int, expiry: int - ) -> Tuple[int, int]: + ) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -237,7 +234,7 @@ async def get_moving_window( :return: (start of window, number of acquired entries) """ timestamp = time.time() - result = ( + if result := ( await self.database[self.__collection_mapping["windows"]] .aggregate( [ @@ -264,12 +261,9 @@ async def get_moving_window( ] ) .to_list(length=1) - ) - - if result: - return (int(result[0]["min"]), result[0]["count"]) - - return (int(timestamp), 0) + ): + return result[0]["min"], result[0]["count"] + return timestamp, 0 async def acquire_entry( self, key: str, limit: int, expiry: int, amount: int = 1 diff --git a/limits/aio/storage/redis.py b/limits/aio/storage/redis.py index ee063708..35a58edf 100644 --- a/limits/aio/storage/redis.py +++ b/limits/aio/storage/redis.py @@ -78,7 +78,7 @@ async def _clear(self, key: str, connection: AsyncRedisClient) -> None: async def get_moving_window( self, key: str, limit: int, expiry: int - ) -> Tuple[int, int]: + ) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -88,12 +88,12 @@ async def get_moving_window( :return: (start of window, number of acquired entries) """ key = self.prefixed_key(key) - timestamp = int(time.time()) + timestamp = time.time() window = await self.lua_moving_window.execute( - [key], [int(timestamp - expiry), limit] + [key], [timestamp - expiry, limit] ) if window: - return tuple(window) # type: ignore + return float(window[0]), window[1] # type: ignore return timestamp, 0 async def _acquire_entry( @@ -118,14 +118,14 @@ async def _acquire_entry( return bool(acquired) - async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int: + async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> float: """ :param key: the key to get the expiry for :param connection: Redis connection """ key = self.prefixed_key(key) - return int(max(await connection.ttl(key), 0) + time.time()) + return max(await connection.ttl(key), 0) + time.time() async def _check(self, connection: AsyncRedisClient) -> bool: """ @@ -261,7 +261,7 @@ async def acquire_entry( return await super()._acquire_entry(key, limit, expiry, self.storage, amount) - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ @@ -450,7 +450,7 @@ async def get(self, key: str) -> int: key, self.storage_replica if self.use_replicas else self.storage ) - async def get_expiry(self, key: str) -> int: + async def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ diff --git a/limits/resources/redis/lua_scripts/moving_window.lua b/limits/resources/redis/lua_scripts/moving_window.lua index affe6578..b3d38724 100644 --- a/limits/resources/redis/lua_scripts/moving_window.lua +++ b/limits/resources/redis/lua_scripts/moving_window.lua @@ -16,4 +16,6 @@ for idx=1,#items do end end -return {oldest, a} +if oldest then + return {tostring(oldest), a} +end \ No newline at end of file diff --git a/limits/storage/base.py b/limits/storage/base.py index 75cc6467..4b05b11b 100644 --- a/limits/storage/base.py +++ b/limits/storage/base.py @@ -99,7 +99,7 @@ def get(self, key: str) -> int: raise NotImplementedError @abstractmethod - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ @@ -161,7 +161,7 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b raise NotImplementedError @abstractmethod - def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]: + def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window diff --git a/limits/storage/etcd.py b/limits/storage/etcd.py index cea2e50f..f7db4af6 100644 --- a/limits/storage/etcd.py +++ b/limits/storage/etcd.py @@ -110,12 +110,11 @@ def get(self, key: str) -> int: return int(amount) return 0 - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: value, _ = self.storage.get(self.prefixed_key(key)) if value: - window_end = float(value.split(b":")[1]) - return int(window_end) - return int(time.time()) + return float(value.split(b":")[1]) + return time.time() def check(self) -> bool: try: diff --git a/limits/storage/memcached.py b/limits/storage/memcached.py index 7942094c..6439055f 100644 --- a/limits/storage/memcached.py +++ b/limits/storage/memcached.py @@ -192,12 +192,12 @@ def incr( return amount - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ - return int(float(self.storage.get(key + "/expires") or time.time())) + return float(self.storage.get(key + "/expires") or time.time()) def check(self) -> bool: """ diff --git a/limits/storage/memory.py b/limits/storage/memory.py index e5ffa8ef..a0fa7593 100644 --- a/limits/storage/memory.py +++ b/limits/storage/memory.py @@ -121,12 +121,12 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b self.events[key][:0] = [LockableEntry(expiry) for _ in range(amount)] return True - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ - return int(self.expirations.get(key, time.time())) + return self.expirations.get(key, time.time()) def get_num_acquired(self, key: str, expiry: int) -> int: """ @@ -143,7 +143,7 @@ def get_num_acquired(self, key: str, expiry: int) -> int: else 0 ) - def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]: + def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -157,9 +157,9 @@ def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int for item in self.events.get(key, [])[::-1]: if item.atime >= timestamp - expiry: - return int(item.atime), acquired + return item.atime, acquired - return int(timestamp), acquired + return timestamp, acquired def check(self) -> bool: """ diff --git a/limits/storage/mongodb.py b/limits/storage/mongodb.py index 4161d9ee..dea1f97a 100644 --- a/limits/storage/mongodb.py +++ b/limits/storage/mongodb.py @@ -1,6 +1,5 @@ from __future__ import annotations -import calendar import datetime import time from abc import ABC, abstractmethod @@ -122,19 +121,17 @@ def clear(self, key: str) -> None: self.counters.find_one_and_delete({"_id": key}) self.windows.find_one_and_delete({"_id": key}) - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ counter = self.counters.find_one({"_id": key}) - expiry = ( - counter["expireAt"] - if counter - else datetime.datetime.now(datetime.timezone.utc) + return ( + (counter["expireAt"] if counter else datetime.datetime.now()) + .replace(tzinfo=datetime.timezone.utc) + .timestamp() ) - return calendar.timegm(expiry.timetuple()) - def get(self, key: str) -> int: """ :param key: the key to get the counter value for @@ -205,7 +202,7 @@ def check(self) -> bool: except: # noqa: E722 return False - def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]: + def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -243,9 +240,9 @@ def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int ) if result: - return int(result[0]["min"]), result[0]["count"] + return result[0]["min"], result[0]["count"] - return int(timestamp), 0 + return timestamp, 0 def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool: """ diff --git a/limits/storage/redis.py b/limits/storage/redis.py index 1d7ab722..f6d7446f 100644 --- a/limits/storage/redis.py +++ b/limits/storage/redis.py @@ -32,7 +32,7 @@ class RedisInteractor: def prefixed_key(self, key: str) -> str: return f"{self.PREFIX}:{key}" - def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]: + def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]: """ returns the starting point and the number of entries in the moving window @@ -43,9 +43,10 @@ def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int """ key = self.prefixed_key(key) timestamp = time.time() - window = self.lua_moving_window([key], [int(timestamp - expiry), limit]) + if window := self.lua_moving_window([key], [timestamp - expiry, limit]): + return float(window[0]), window[1] - return window or (int(timestamp), 0) + return timestamp, 0 def _incr( self, @@ -109,14 +110,14 @@ def _acquire_entry( return bool(acquired) - def _get_expiry(self, key: str, connection: RedisClient) -> int: + def _get_expiry(self, key: str, connection: RedisClient) -> float: """ :param key: the key to get the expiry for :param connection: Redis connection """ key = self.prefixed_key(key) - return int(max(connection.ttl(key), 0) + time.time()) + return max(connection.ttl(key), 0) + time.time() def _check(self, connection: RedisClient) -> bool: """ @@ -232,7 +233,7 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b return super()._acquire_entry(key, limit, expiry, self.storage, amount) - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ diff --git a/limits/storage/redis_sentinel.py b/limits/storage/redis_sentinel.py index 7d4b03c2..34334440 100644 --- a/limits/storage/redis_sentinel.py +++ b/limits/storage/redis_sentinel.py @@ -101,7 +101,7 @@ def get(self, key: str) -> int: key, self.storage_slave if self.use_replicas else self.storage ) - def get_expiry(self, key: str) -> int: + def get_expiry(self, key: str) -> float: """ :param key: the key to get the expiry for """ diff --git a/limits/util.py b/limits/util.py index a70f8858..e5c85716 100644 --- a/limits/util.py +++ b/limits/util.py @@ -38,7 +38,7 @@ class WindowStats(NamedTuple): """ #: Time as seconds since the Epoch when this window will be reset - reset_time: int + reset_time: float #: Quantity remaining in this window remaining: int diff --git a/tests/aio/test_strategy.py b/tests/aio/test_strategy.py index b21f8084..fa2a264c 100644 --- a/tests/aio/test_strategy.py +++ b/tests/aio/test_strategy.py @@ -34,7 +34,9 @@ async def test_fixed_window(self, uri, args, fixture): assert all([await limiter.hit(limit) for _ in range(0, 10)]) assert not await limiter.hit(limit) assert (await limiter.get_window_stats(limit)).remaining == 0 - assert (await limiter.get_window_stats(limit)).reset_time == start + 2 + assert (await limiter.get_window_stats(limit)).reset_time == pytest.approx( + start + 2, 1e-2 + ) @async_all_storage @fixed_start @@ -43,7 +45,9 @@ async def test_fixed_window_empty_stats(self, uri, args, fixture): limiter = FixedWindowRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) assert (await limiter.get_window_stats(limit)).remaining == 10 - assert (await limiter.get_window_stats(limit)).reset_time == int(time.time()) + assert (await limiter.get_window_stats(limit)).reset_time == pytest.approx( + time.time(), 1e-2 + ) @async_moving_window_storage async def test_moving_window_stats(self, uri, args, fixture): @@ -56,9 +60,9 @@ async def test_moving_window_stats(self, uri, args, fixture): time.sleep(1) assert not await limiter.hit(limit, "key") assert (await limiter.get_window_stats(limit, "key")).remaining == 0 - assert (await limiter.get_window_stats(limit, "key")).reset_time - int( - time.time() - ) == 58 + assert ( + await limiter.get_window_stats(limit, "key") + ).reset_time - time.time() == pytest.approx(58, 1e-2) @async_all_storage @fixed_start @@ -82,12 +86,16 @@ async def test_fixed_window_with_elastic_expiry(self, uri, args, fixture): assert all([await limiter.hit(limit) for _ in range(0, 10)]) assert not await limiter.hit(limit) assert (await limiter.get_window_stats(limit)).remaining == 0 - assert (await limiter.get_window_stats(limit)).reset_time == start + 2 + assert (await limiter.get_window_stats(limit)).reset_time == pytest.approx( + start + 2, 1e-2 + ) async with async_window(3) as (start, end): assert not await limiter.hit(limit) assert await limiter.hit(limit) assert (await limiter.get_window_stats(limit)).remaining == 9 - assert (await limiter.get_window_stats(limit)).reset_time == end + 2 + assert (await limiter.get_window_stats(limit)).reset_time == pytest.approx( + end + 2, 1e-2 + ) @async_all_storage @fixed_start @@ -101,7 +109,9 @@ async def test_fixed_window_with_elastic_expiry_multiple_cost( async with async_window(0) as (_, end): assert await limiter.hit(limit, "k2", cost=5) assert (await limiter.get_window_stats(limit, "k2")).remaining == 5 - assert (await limiter.get_window_stats(limit, "k2")).reset_time == end + 2 + assert ( + await limiter.get_window_stats(limit, "k2") + ).reset_time == pytest.approx(end + 2, 1e-2) assert not await limiter.hit(limit, "k2", cost=6) @async_moving_window_storage @@ -128,8 +138,8 @@ async def test_moving_window_empty_stats(self, uri, args, fixture): limiter = MovingWindowRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) assert (await limiter.get_window_stats(limit)).remaining == 10 - assert (await limiter.get_window_stats(limit)).reset_time == int( - time.time() + 2 + assert (await limiter.get_window_stats(limit)).reset_time == pytest.approx( + time.time() + 2, 1e-2 ) @async_moving_window_storage diff --git a/tests/test_strategy.py b/tests/test_strategy.py index e7dff66f..c8887422 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -1,4 +1,3 @@ -import math import time import pytest @@ -28,7 +27,9 @@ def test_fixed_window(self, uri, args, fixture): assert all([limiter.hit(limit) for _ in range(0, 10)]) assert not limiter.hit(limit) assert limiter.get_window_stats(limit).remaining == 0 - assert limiter.get_window_stats(limit).reset_time == math.floor(start + 2) + assert limiter.get_window_stats(limit).reset_time == pytest.approx( + start + 2, 1e-2 + ) @all_storage @fixed_start @@ -37,7 +38,9 @@ def test_fixed_window_empty_stats(self, uri, args, fixture): limiter = FixedWindowRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) assert limiter.get_window_stats(limit).remaining == 10 - assert limiter.get_window_stats(limit).reset_time == int(time.time()) + assert limiter.get_window_stats(limit).reset_time == pytest.approx( + time.time(), 1e-2 + ) @all_storage @fixed_start @@ -61,12 +64,16 @@ def test_fixed_window_with_elastic_expiry(self, uri, args, fixture): assert all([limiter.hit(limit) for _ in range(0, 10)]) assert not limiter.hit(limit) assert limiter.get_window_stats(limit).remaining == 0 - assert limiter.get_window_stats(limit).reset_time == start + 2 + assert limiter.get_window_stats(limit).reset_time == pytest.approx( + start + 2, 1e-2 + ) with window(3) as (start, end): assert not limiter.hit(limit) assert limiter.hit(limit) assert limiter.get_window_stats(limit).remaining == 9 - assert limiter.get_window_stats(limit).reset_time == end + 2 + assert limiter.get_window_stats(limit).reset_time == pytest.approx( + end + 2, 1e-2 + ) @all_storage @fixed_start @@ -78,7 +85,9 @@ def test_fixed_window_with_elastic_expiry_multiple_cost(self, uri, args, fixture with window(0) as (start, end): assert limiter.hit(limit, "k2", cost=5) assert limiter.get_window_stats(limit, "k2").remaining == 5 - assert limiter.get_window_stats(limit, "k2").reset_time == end + 2 + assert limiter.get_window_stats(limit, "k2").reset_time == pytest.approx( + end + 2, 1e-2 + ) assert not limiter.hit(limit, "k2", cost=6) @moving_window_storage @@ -87,7 +96,9 @@ def test_moving_window_empty_stats(self, uri, args, fixture): limiter = MovingWindowRateLimiter(storage) limit = RateLimitItemPerSecond(10, 2) assert limiter.get_window_stats(limit).remaining == 10 - assert limiter.get_window_stats(limit).reset_time == int(time.time() + 2) + assert limiter.get_window_stats(limit).reset_time == pytest.approx( + time.time() + 2, 1e-2 + ) @moving_window_storage def test_moving_window_stats(self, uri, args, fixture): @@ -100,9 +111,9 @@ def test_moving_window_stats(self, uri, args, fixture): time.sleep(1) assert not limiter.hit(limit, "key") assert limiter.get_window_stats(limit, "key").remaining == 0 - assert ( - limiter.get_window_stats(limit, "key").reset_time - int(time.time()) == 58 - ) + assert limiter.get_window_stats( + limit, "key" + ).reset_time - time.time() == pytest.approx(58, 1e-2) @moving_window_storage def test_moving_window(self, uri, args, fixture): diff --git a/tests/utils.py b/tests/utils.py index 534e36c5..1abc9cb8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -29,7 +29,7 @@ def window(delay_end: float, delay: Optional[float] = None): if delay is not None: while time.time() - start < delay: time.sleep(0.001) - yield (int(start), int(start + delay_end)) + yield (start, start + delay_end) while time.time() - start < delay_end: time.sleep(0.001) @@ -43,7 +43,7 @@ async def async_window(delay_end: float, delay: Optional[float] = None): while time.time() - start < delay: await asyncio.sleep(0.001) - yield (int(start), int(start + delay_end)) + yield (start, start + delay_end) while time.time() - start < delay_end: await asyncio.sleep(0.001)