Skip to content

Commit

Permalink
feat(app): rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
MorvanZhou committed Sep 18, 2024
1 parent 307b7c1 commit 39a6c5e
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/retk/const/user_behavior_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class UserBehaviorTypeEnum(IntEnum):

LLM_KNOWLEDGE_RESPONSE = 24 # backend
NODE_PAGE_VIEW = 25 # backend
RATE_LIMIT_EXCEEDED = 26 # backend


USER_BEHAVIOR_TYPE_MAP = {
Expand Down
5 changes: 3 additions & 2 deletions src/retk/core/scheduler/tasks/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import httpx

from retk.core.utils.tencent import get_auth
from retk.logger import logger


def send_verification_code(
Expand Down Expand Up @@ -78,5 +79,5 @@ async def _send_verification_code(
},
content=payload_bytes,
)
print(response.status_code)
print(response.text)
if response.status_code != 200:
logger.error(f"send email failed: {response.text}")
71 changes: 71 additions & 0 deletions src/retk/core/utils/ratelimiter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import asyncio
import json
import time
from collections import OrderedDict
from datetime import timedelta
from functools import wraps
from typing import Callable
from typing import Union

from starlette.exceptions import HTTPException

from retk import const
from retk.config import is_local_db
from retk.core.statistic import add_user_behavior
from retk.models.tps import AuthedUser


class RateLimiter:
def __init__(self, requests: int, period: Union[int, float, timedelta]):
Expand Down Expand Up @@ -40,3 +51,63 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.semaphore.release()
return False


def req_limit(requests: int, in_seconds: int):
    """Build a decorator that rate-limits an async callable per user or per IP.

    The limiter keeps an in-process sliding window of timestamps for each
    key. The key is ``au.u.id`` when the wrapped callable is invoked with a
    truthy ``au`` keyword argument, otherwise the ``ip`` keyword argument.
    When ``is_local_db()`` is true the limiter is bypassed entirely.

    Only calls that complete without raising are recorded, so calls that are
    rejected or that fail do not consume quota.

    Args:
        requests: maximum number of recorded calls allowed per key inside
            the window.
        in_seconds: length of the sliding window, in seconds.

    Returns:
        A decorator for ``async`` callables.

    Raises (from the wrapped callable):
        HTTPException 400: neither ``au`` nor ``ip`` was supplied as a
            keyword argument.
        HTTPException 429: the key already has ``requests`` recorded calls
            inside the current window.
    """
    # key -> timestamps of recorded calls. A key is moved to the end on every
    # access, so the front of the dict holds the least recently seen keys.
    rate_limiter = OrderedDict()

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if is_local_db():
                return await func(*args, **kwargs)

            au: AuthedUser = kwargs.get('au')
            ip: str = kwargs.get('ip')
            if not au:
                if not ip:
                    raise HTTPException(status_code=400, detail="Bad Request: no au and ip")
                key = ip
            else:
                key = au.u.id

            current_time = time.time()
            window_start = current_time - in_seconds

            # Keep only the timestamps still inside the window and re-insert
            # the key so it becomes the most recently seen one. The filtered
            # list is what the limit check below must look at; counting
            # expired timestamps would reject requests after the window passed.
            user_visit = [
                timestamp for timestamp in rate_limiter.pop(key, ())
                if timestamp > window_start
            ]
            rate_limiter[key] = user_visit

            # Evict stale keys from the front without copying the key list.
            # The current key is always present (and last), so the loop ends.
            while rate_limiter:
                oldest = next(iter(rate_limiter))
                if oldest == key:
                    break
                visits = rate_limiter[oldest]
                if visits and visits[-1] > window_start:
                    break
                # Empty or fully expired entry: nothing left worth keeping.
                del rate_limiter[oldest]

            # check if rate limit exceeded
            if len(user_visit) >= requests:
                await add_user_behavior(
                    uid=key if au else "",
                    type_=const.UserBehaviorTypeEnum.RATE_LIMIT_EXCEEDED,
                    remark=json.dumps({"ip": ip if ip else "", "au": au.u.id if au else "", "func": func.__name__}),
                )
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded: max {requests} requests in {in_seconds} seconds"
                )

            res = await func(*args, **kwargs)

            # Another request may have evicted or replaced this key's list
            # while ``func`` was awaited; setdefault avoids a KeyError and
            # appends to whichever list is currently stored.
            rate_limiter.setdefault(key, user_visit).append(current_time)
            return res

        return wrapper

    return decorator
10 changes: 9 additions & 1 deletion src/retk/routes/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from retk import const
from retk.controllers import schemas, account
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand All @@ -20,10 +21,12 @@
response_model=schemas.user.UserInfoResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=60)
async def signup(
req_id: utils.ANNOTATED_REQUEST_ID,
req: schemas.account.SignupRequest,
referer: Optional[str] = utils.DEPENDS_REFERER,
ip: Optional[str] = utils.DEPENDS_IP,
) -> JSONResponse:
return await account.signup(req_id=req_id, req=req)

Expand All @@ -38,7 +41,7 @@ async def auto_login(
token: str = Cookie(alias=const.settings.COOKIE_ACCESS_TOKEN, default=""),
request_id: str = Header(
default="", alias="RequestId", max_length=const.settings.MD_MAX_LENGTH
)
),
) -> schemas.user.UserInfoResponse:
return await account.auto_login(
token=token,
Expand All @@ -52,10 +55,12 @@ async def auto_login(
response_model=schemas.user.UserInfoResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=60)
async def login(
req_id: utils.ANNOTATED_REQUEST_ID,
req: schemas.account.LoginRequest,
referer: Optional[str] = utils.DEPENDS_REFERER,
ip: Optional[str] = utils.DEPENDS_IP,
) -> JSONResponse:
return await account.login(req_id=req_id, req=req)

Expand All @@ -80,10 +85,12 @@ async def forget_password(
response_model=schemas.account.TokenResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=60)
async def email_verification(
req_id: utils.ANNOTATED_REQUEST_ID,
req: schemas.account.EmailVerificationRequest,
referer: Optional[str] = utils.DEPENDS_REFERER,
ip: Optional[str] = utils.DEPENDS_IP,
) -> schemas.account.TokenResponse:
return await account.email_send_code(req_id=req_id, req=req)

Expand All @@ -94,6 +101,7 @@ async def email_verification(
response_model=schemas.RequestIdResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=60)
async def refresh_token(
au: utils.ANNOTATED_REFRESH_TOKEN, # check refresh token expiration
referer: Optional[str] = utils.DEPENDS_REFERER,
Expand Down
3 changes: 3 additions & 0 deletions src/retk/routes/app_captcha.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi.responses import StreamingResponse

from retk.controllers import account
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand All @@ -18,7 +19,9 @@
status_code=200,
)
@utils.measure_time_spend
@req_limit(requests=6, in_seconds=30)
async def get_captcha_img(
referer: Optional[str] = utils.DEPENDS_REFERER,
ip: Optional[str] = utils.DEPENDS_IP,
) -> StreamingResponse:
return account.get_captcha_img()
2 changes: 2 additions & 0 deletions src/retk/routes/browser_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing_extensions import Annotated

from retk.controllers import schemas, browser_extension
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand Down Expand Up @@ -44,6 +45,7 @@ async def browser_extension_refresh_token(
response_model=schemas.node.NodeResponse,
)
@utils.measure_time_spend
@req_limit(requests=4, in_seconds=1)
async def post_node_from_browser_extension(
au: utils.ANNOTATED_AUTHED_USER_BROWSER_EXTENSION,
url: Annotated[str, Form(...)],
Expand Down
6 changes: 6 additions & 0 deletions src/retk/routes/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from retk.controllers import schemas
from retk.controllers.files import upload_files
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand All @@ -19,6 +20,7 @@
response_model=schemas.RequestIdResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=30)
async def upload_obsidian_files(
au: utils.ANNOTATED_AUTHED_USER,
files: List[UploadFile],
Expand All @@ -36,6 +38,7 @@ async def upload_obsidian_files(
response_model=schemas.RequestIdResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=30)
async def upload_text_files(
au: utils.ANNOTATED_AUTHED_USER,
files: List[UploadFile],
Expand All @@ -53,6 +56,7 @@ async def upload_text_files(
response_model=schemas.files.VditorFilesResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def vditor_upload(
au: utils.ANNOTATED_AUTHED_USER,
req: Request,
Expand All @@ -72,6 +76,7 @@ async def vditor_upload(
response_model=schemas.files.VditorImagesResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def vditor_fetch_image(
au: utils.ANNOTATED_AUTHED_USER,
req: schemas.files.ImageVditorFetchRequest,
Expand All @@ -89,6 +94,7 @@ async def vditor_fetch_image(
response_model=schemas.files.FileUploadProcessResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=30)
async def get_upload_process(
au: utils.ANNOTATED_AUTHED_USER,
referer: Optional[str] = utils.DEPENDS_REFERER,
Expand Down
8 changes: 8 additions & 0 deletions src/retk/routes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from retk import const
from retk.controllers import schemas
from retk.controllers.node import node_ops, search
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand All @@ -23,6 +24,7 @@
response_model=schemas.node.NodeResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=1)
async def post_node(
au: utils.ANNOTATED_AUTHED_USER,
req: schemas.node.CreateRequest,
Expand All @@ -40,6 +42,7 @@ async def post_node(
response_model=schemas.node.NodeResponse,
)
@utils.measure_time_spend
@req_limit(requests=10, in_seconds=1)
async def post_quick_node(
au: utils.ANNOTATED_AUTHED_USER,
req: schemas.node.CreateRequest,
Expand All @@ -57,6 +60,7 @@ async def post_quick_node(
response_model=schemas.node.NodesSearchResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def get_search_nodes(
au: utils.ANNOTATED_AUTHED_USER,
q: str = Query(max_length=const.settings.SEARCH_QUERY_MAX_LENGTH),
Expand Down Expand Up @@ -141,6 +145,7 @@ async def get_node(
response_model=schemas.node.NodesSearchResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def get_at_search(
au: utils.ANNOTATED_AUTHED_USER,
nid: str = utils.ANNOTATED_NID,
Expand All @@ -164,6 +169,7 @@ async def get_at_search(
response_model=schemas.node.NodesSearchResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def get_recommend_nodes(
au: utils.ANNOTATED_AUTHED_USER,
nid: str = utils.ANNOTATED_NID,
Expand Down Expand Up @@ -241,6 +247,7 @@ async def put_node_md(
response_model=schemas.RequestIdResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def put_node_favorite(
au: utils.ANNOTATED_AUTHED_USER,
nid: str = utils.ANNOTATED_NID,
Expand Down Expand Up @@ -276,6 +283,7 @@ async def delete_node_favorite(
status_code=200,
)
@utils.measure_time_spend
@req_limit(requests=1, in_seconds=1)
async def stream_export_md_node(
au: utils.ANNOTATED_AUTHED_USER,
nid: str = utils.ANNOTATED_NID,
Expand Down
3 changes: 3 additions & 0 deletions src/retk/routes/recent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import APIRouter

from retk.controllers import schemas, recent
from retk.core.utils.ratelimiter import req_limit
from retk.routes import utils

router = APIRouter(
Expand All @@ -16,6 +17,7 @@
response_model=schemas.RequestIdResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def add_recent_at_node(
au: utils.ANNOTATED_AUTHED_USER,
req: schemas.recent.AtNodeRequest,
Expand All @@ -32,6 +34,7 @@ async def add_recent_at_node(
response_model=schemas.recent.GetRecentSearchResponse,
)
@utils.measure_time_spend
@req_limit(requests=5, in_seconds=1)
async def get_recent_searched(
au: utils.ANNOTATED_AUTHED_USER,
) -> schemas.recent.GetRecentSearchResponse:
Expand Down
20 changes: 19 additions & 1 deletion src/retk/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

import jwt
from fastapi import HTTPException, Header, Cookie, Depends
from fastapi import HTTPException, Header, Cookie, Depends, Request
from fastapi.params import Path
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated
Expand Down Expand Up @@ -77,6 +77,23 @@ def verify_referer(referer: Optional[str] = Header(None)):
return referer


def get_ip(
        request: Request
) -> str:
    """Return the best-effort client IP address for ``request``.

    Lookup order:
      1. the first entry of the ``X-Forwarded-For`` header,
      2. the ``X-Real-IP`` header,
      3. the host of ``request.client``.

    Returns an empty string when none of these yields a value, instead of
    raising when ``request.client`` is ``None``.

    NOTE(review): both headers are taken from the incoming request as-is;
    confirm they are set or overwritten by a trusted proxy before relying on
    this value for security decisions.
    """
    forwarded_for = request.headers.get('X-Forwarded-For')
    if forwarded_for:
        # The header may hold a comma-separated chain; use the first entry.
        ip = forwarded_for.split(',')[0].strip()
        if ip:
            return ip

    real_ip = request.headers.get('X-Real-IP')
    if real_ip:
        real_ip = real_ip.strip()
        if real_ip:
            return real_ip

    # ``request.client`` can be None, so guard before reading ``host``.
    client = request.client
    if client is None:
        return ""
    return client.host


async def on_startup():
if not config.is_local_db():
add_rotating_file_handler(
Expand Down Expand Up @@ -291,3 +308,4 @@ async def process_browser_extension_refresh_token_headers(
ANNOTATED_FID = Annotated[str, Path(title="The ID of file", max_length=const.settings.FID_MAX_LENGTH)]

DEPENDS_REFERER = Depends(verify_referer)
DEPENDS_IP = Depends(get_ip)
3 changes: 2 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_email_verification(self, mock_send):
},
headers={"RequestId": "xxx"}
)
self.assertEqual(200, resp.status_code)
self.assertEqual(200, resp.status_code, msg=resp.json())
rj = resp.json()
self.assertNotEqual("", rj["token"])
self.assertEqual("xxx", rj["requestId"])
Expand All @@ -106,6 +106,7 @@ def setUpClass(cls) -> None:
utils.set_env(".env.test.local")

async def asyncSetUp(self) -> None:
os.makedirs(Path(__file__).parent / "temp", exist_ok=True)
scheduler.start()
await client.init()
self.client = TestClient(app)
Expand Down
Loading

0 comments on commit 39a6c5e

Please sign in to comment.