Skip to content

Commit

Permalink
Change definition of reset_time to floating point
Browse files Browse the repository at this point in the history
Truncating reset_time (or start of next window) by
casting to an int was incorrectly providing the wrong
reset_time as there was precision loss in the conversion.
  • Loading branch information
alisaifee committed Dec 21, 2024
1 parent 0671723 commit f2056bc
Show file tree
Hide file tree
Showing 18 changed files with 105 additions and 91 deletions.
4 changes: 2 additions & 2 deletions limits/aio/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def get(self, key: str) -> int:
raise NotImplementedError

@abstractmethod
async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
Expand Down Expand Up @@ -169,7 +169,7 @@ async def acquire_entry(
@abstractmethod
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> Tuple[int, int]:
) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand Down
6 changes: 3 additions & 3 deletions limits/aio/storage/etcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ async def get(self, key: str) -> int:
return int(amount)
return 0

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
cur = await self.storage.get(self.prefixed_key(key))
if cur:
window_end = float(cur.value.split(b":")[1])
return int(window_end)
return int(time.time())
return window_end
return time.time()

async def check(self) -> bool:
try:
Expand Down
4 changes: 2 additions & 2 deletions limits/aio/storage/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ async def incr(

return amount

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
storage = await self.get_storage()
item = await storage.get(f"{key}/expires".encode())

return int(item and float(item.value) or time.time())
return item and float(item.value) or time.time()

async def check(self) -> bool:
"""
Expand Down
10 changes: 5 additions & 5 deletions limits/aio/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ async def acquire_entry(

return True

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""

return int(self.expirations.get(key, time.time()))
return self.expirations.get(key, time.time())

async def get_num_acquired(self, key: str, expiry: int) -> int:
"""
Expand All @@ -153,7 +153,7 @@ async def get_num_acquired(self, key: str, expiry: int) -> int:
# FIXME: arg limit is not used
async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> Tuple[int, int]:
) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand All @@ -167,9 +167,9 @@ async def get_moving_window(

for item in self.events.get(key, [])[::-1]:
if item.atime >= timestamp - expiry:
return int(item.atime), acquired
return item.atime, acquired

return int(timestamp), acquired
return timestamp, acquired

async def check(self) -> bool:
"""
Expand Down
26 changes: 10 additions & 16 deletions limits/aio/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import calendar
import datetime
import time
from typing import Any, cast
Expand Down Expand Up @@ -135,21 +134,19 @@ async def clear(self, key: str) -> None:
),
)

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = await self.database[self.__collection_mapping["counters"]].find_one(
{"_id": key}
)
expiry = (
counter["expireAt"]
if counter
else datetime.datetime.now(datetime.timezone.utc)
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)

return calendar.timegm(expiry.timetuple())

async def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
Expand Down Expand Up @@ -227,7 +224,7 @@ async def check(self) -> bool:

async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> Tuple[int, int]:
) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand All @@ -237,7 +234,7 @@ async def get_moving_window(
:return: (start of window, number of acquired entries)
"""
timestamp = time.time()
result = (
if result := (
await self.database[self.__collection_mapping["windows"]]
.aggregate(
[
Expand All @@ -264,12 +261,9 @@ async def get_moving_window(
]
)
.to_list(length=1)
)

if result:
return (int(result[0]["min"]), result[0]["count"])

return (int(timestamp), 0)
):
return result[0]["min"], result[0]["count"]
return timestamp, 0

async def acquire_entry(
self, key: str, limit: int, expiry: int, amount: int = 1
Expand Down
16 changes: 8 additions & 8 deletions limits/aio/storage/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def _clear(self, key: str, connection: AsyncRedisClient) -> None:

async def get_moving_window(
self, key: str, limit: int, expiry: int
) -> Tuple[int, int]:
) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand All @@ -88,12 +88,12 @@ async def get_moving_window(
:return: (start of window, number of acquired entries)
"""
key = self.prefixed_key(key)
timestamp = int(time.time())
timestamp = time.time()
window = await self.lua_moving_window.execute(
[key], [int(timestamp - expiry), limit]
[key], [timestamp - expiry, limit]
)
if window:
return tuple(window) # type: ignore
return float(window[0]), window[1] # type: ignore
return timestamp, 0

async def _acquire_entry(
Expand All @@ -118,14 +118,14 @@ async def _acquire_entry(

return bool(acquired)

async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int:
async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> float:
"""
:param key: the key to get the expiry for
:param connection: Redis connection
"""

key = self.prefixed_key(key)
return int(max(await connection.ttl(key), 0) + time.time())
return max(await connection.ttl(key), 0) + time.time()

async def _check(self, connection: AsyncRedisClient) -> bool:
"""
Expand Down Expand Up @@ -261,7 +261,7 @@ async def acquire_entry(

return await super()._acquire_entry(key, limit, expiry, self.storage, amount)

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
Expand Down Expand Up @@ -450,7 +450,7 @@ async def get(self, key: str) -> int:
key, self.storage_replica if self.use_replicas else self.storage
)

async def get_expiry(self, key: str) -> int:
async def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
Expand Down
4 changes: 3 additions & 1 deletion limits/resources/redis/lua_scripts/moving_window.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ for idx=1,#items do
end
end

return {oldest, a}
if oldest then
return {tostring(oldest), a}
end
4 changes: 2 additions & 2 deletions limits/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get(self, key: str) -> int:
raise NotImplementedError

@abstractmethod
def get_expiry(self, key: str) -> int:
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
Expand Down Expand Up @@ -161,7 +161,7 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b
raise NotImplementedError

@abstractmethod
def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand Down
7 changes: 3 additions & 4 deletions limits/storage/etcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,11 @@ def get(self, key: str) -> int:
return int(amount)
return 0

def get_expiry(self, key: str) -> int:
def get_expiry(self, key: str) -> float:
value, _ = self.storage.get(self.prefixed_key(key))
if value:
window_end = float(value.split(b":")[1])
return int(window_end)
return int(time.time())
return float(value.split(b":")[1])
return time.time()

def check(self) -> bool:
try:
Expand Down
4 changes: 2 additions & 2 deletions limits/storage/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,12 @@ def incr(

return amount

def get_expiry(self, key: str) -> int:
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""

return int(float(self.storage.get(key + "/expires") or time.time()))
return float(self.storage.get(key + "/expires") or time.time())

def check(self) -> bool:
"""
Expand Down
10 changes: 5 additions & 5 deletions limits/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> b
self.events[key][:0] = [LockableEntry(expiry) for _ in range(amount)]
return True

def get_expiry(self, key: str) -> int:
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""

return int(self.expirations.get(key, time.time()))
return self.expirations.get(key, time.time())

def get_num_acquired(self, key: str, expiry: int) -> int:
"""
Expand All @@ -143,7 +143,7 @@ def get_num_acquired(self, key: str, expiry: int) -> int:
else 0
)

def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand All @@ -157,9 +157,9 @@ def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int

for item in self.events.get(key, [])[::-1]:
if item.atime >= timestamp - expiry:
return int(item.atime), acquired
return item.atime, acquired

return int(timestamp), acquired
return timestamp, acquired

def check(self) -> bool:
"""
Expand Down
19 changes: 8 additions & 11 deletions limits/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import calendar
import datetime
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -122,19 +121,17 @@ def clear(self, key: str) -> None:
self.counters.find_one_and_delete({"_id": key})
self.windows.find_one_and_delete({"_id": key})

def get_expiry(self, key: str) -> int:
def get_expiry(self, key: str) -> float:
"""
:param key: the key to get the expiry for
"""
counter = self.counters.find_one({"_id": key})
expiry = (
counter["expireAt"]
if counter
else datetime.datetime.now(datetime.timezone.utc)
return (
(counter["expireAt"] if counter else datetime.datetime.now())
.replace(tzinfo=datetime.timezone.utc)
.timestamp()
)

return calendar.timegm(expiry.timetuple())

def get(self, key: str) -> int:
"""
:param key: the key to get the counter value for
Expand Down Expand Up @@ -205,7 +202,7 @@ def check(self) -> bool:
except: # noqa: E722
return False

def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int]:
def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[float, int]:
"""
returns the starting point and the number of entries in the moving
window
Expand Down Expand Up @@ -243,9 +240,9 @@ def get_moving_window(self, key: str, limit: int, expiry: int) -> Tuple[int, int
)

if result:
return int(result[0]["min"]), result[0]["count"]
return result[0]["min"], result[0]["count"]

return int(timestamp), 0
return timestamp, 0

def acquire_entry(self, key: str, limit: int, expiry: int, amount: int = 1) -> bool:
"""
Expand Down
Loading

0 comments on commit f2056bc

Please sign in to comment.