From 39a6c5ed8d2a4466da46ea640c1f767aab31d4e5 Mon Sep 17 00:00:00 2001 From: morvanzhou Date: Wed, 18 Sep 2024 22:27:45 +0800 Subject: [PATCH] feat(app): rate limit --- src/retk/const/user_behavior_types.py | 1 + src/retk/core/scheduler/tasks/email.py | 5 +- src/retk/core/utils/ratelimiter.py | 71 ++++++++++++++++++++++++++ src/retk/routes/account.py | 10 +++- src/retk/routes/app_captcha.py | 3 ++ src/retk/routes/browser_extension.py | 2 + src/retk/routes/files.py | 6 +++ src/retk/routes/node.py | 8 +++ src/retk/routes/recent.py | 3 ++ src/retk/routes/utils.py | 20 +++++++- tests/test_api.py | 3 +- tests/test_core_remote.py | 17 ++++++ 12 files changed, 144 insertions(+), 5 deletions(-) diff --git a/src/retk/const/user_behavior_types.py b/src/retk/const/user_behavior_types.py index 5141f18..55263ec 100644 --- a/src/retk/const/user_behavior_types.py +++ b/src/retk/const/user_behavior_types.py @@ -40,6 +40,7 @@ class UserBehaviorTypeEnum(IntEnum): LLM_KNOWLEDGE_RESPONSE = 24 # backend NODE_PAGE_VIEW = 25 # backend + RATE_LIMIT_EXCEEDED = 26 # backend USER_BEHAVIOR_TYPE_MAP = { diff --git a/src/retk/core/scheduler/tasks/email.py b/src/retk/core/scheduler/tasks/email.py index 68a39af..21752a8 100644 --- a/src/retk/core/scheduler/tasks/email.py +++ b/src/retk/core/scheduler/tasks/email.py @@ -6,6 +6,7 @@ import httpx from retk.core.utils.tencent import get_auth +from retk.logger import logger def send_verification_code( @@ -78,5 +79,5 @@ async def _send_verification_code( }, content=payload_bytes, ) - print(response.status_code) - print(response.text) + if response.status_code != 200: + logger.error(f"send email failed: {response.text}") diff --git a/src/retk/core/utils/ratelimiter.py b/src/retk/core/utils/ratelimiter.py index 3295c04..e0030f4 100644 --- a/src/retk/core/utils/ratelimiter.py +++ b/src/retk/core/utils/ratelimiter.py @@ -1,8 +1,19 @@ import asyncio +import json import time +from collections import OrderedDict from datetime import timedelta +from functools 
def req_limit(requests: int, in_seconds: int):
    """Build a decorator that rate-limits an async callable per user or per IP.

    A sliding window is used: for each key, at most ``requests`` calls are
    admitted within any ``in_seconds``-second span. The key is ``au.u.id``
    when a truthy ``au`` keyword argument is present, otherwise the ``ip``
    keyword argument. Both are read from keyword arguments only, so the
    wrapped callable must receive them by name.

    The bookkeeping lives in this process' memory and is private to each
    decorated function. When ``is_local_db()`` is truthy the limiter is
    bypassed entirely.

    Args:
        requests: Maximum number of admitted calls per key inside one window.
        in_seconds: Length of the sliding window, in seconds.

    Returns:
        A decorator for ``async`` callables. The wrapper it produces raises
        ``HTTPException`` with status 400 when neither ``au`` nor ``ip`` is
        supplied, and with status 429 when the key has used up its quota.
    """
    # key -> ascending timestamps of the calls admitted in the current window.
    # Keys are re-inserted on every access, so the dict is ordered by last access.
    visits_by_key = OrderedDict()

    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if is_local_db():
                return await func(*args, **kwargs)

            au: AuthedUser = kwargs.get('au')
            ip: str = kwargs.get('ip')
            if au:
                key = au.u.id
            elif ip:
                key = ip
            else:
                raise HTTPException(status_code=400, detail="Bad Request: no au and ip")

            # Monotonic clock: interval arithmetic must not be affected by
            # wall-clock adjustments.
            now = time.monotonic()
            window_start = now - in_seconds

            # Take the key out and keep only the timestamps still inside the
            # window. The limit check below must use THIS filtered list;
            # checking the unfiltered one rejects the first call made after
            # the window has already expired.
            visits = [t for t in visits_by_key.pop(key, ()) if t > window_start]

            # Evict idle keys from the least-recently-accessed end. Stop at
            # the first key that is still active to keep this cheap.
            while visits_by_key:
                oldest_key = next(iter(visits_by_key))
                stamps = visits_by_key[oldest_key]
                if stamps and stamps[-1] > window_start:
                    break
                del visits_by_key[oldest_key]

            if len(visits) >= requests:
                # Put the key back before awaiting anything so concurrent
                # calls keep seeing the exhausted quota. Empty lists are
                # never stored, so they cannot stall the eviction loop.
                if visits:
                    visits_by_key[key] = visits
                await add_user_behavior(
                    uid=key if au else "",
                    type_=const.UserBehaviorTypeEnum.RATE_LIMIT_EXCEEDED,
                    remark=json.dumps({"ip": ip if ip else "", "au": au.u.id if au else "", "func": func.__name__}),
                )
                raise HTTPException(
                    status_code=429,
                    detail=f"Rate limit exceeded: max {requests} requests in {in_seconds} seconds"
                )

            # Record the call BEFORE awaiting the wrapped coroutine: calls
            # that are still in flight (or that end up raising) count against
            # the quota, and no dict lookup is needed after the await, when
            # the key may already have been evicted by another request.
            visits.append(now)
            visits_by_key[key] = visits
            return await func(*args, **kwargs)

        return wrapper

    return decorator
b/src/retk/routes/account.py index bec7527..649df74 100644 --- a/src/retk/routes/account.py +++ b/src/retk/routes/account.py @@ -5,6 +5,7 @@ from retk import const from retk.controllers import schemas, account +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -20,10 +21,12 @@ response_model=schemas.user.UserInfoResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=60) async def signup( req_id: utils.ANNOTATED_REQUEST_ID, req: schemas.account.SignupRequest, referer: Optional[str] = utils.DEPENDS_REFERER, + ip: Optional[str] = utils.DEPENDS_IP, ) -> JSONResponse: return await account.signup(req_id=req_id, req=req) @@ -38,7 +41,7 @@ async def auto_login( token: str = Cookie(alias=const.settings.COOKIE_ACCESS_TOKEN, default=""), request_id: str = Header( default="", alias="RequestId", max_length=const.settings.MD_MAX_LENGTH - ) + ), ) -> schemas.user.UserInfoResponse: return await account.auto_login( token=token, @@ -52,10 +55,12 @@ async def auto_login( response_model=schemas.user.UserInfoResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=60) async def login( req_id: utils.ANNOTATED_REQUEST_ID, req: schemas.account.LoginRequest, referer: Optional[str] = utils.DEPENDS_REFERER, + ip: Optional[str] = utils.DEPENDS_IP, ) -> JSONResponse: return await account.login(req_id=req_id, req=req) @@ -80,10 +85,12 @@ async def forget_password( response_model=schemas.account.TokenResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=60) async def email_verification( req_id: utils.ANNOTATED_REQUEST_ID, req: schemas.account.EmailVerificationRequest, referer: Optional[str] = utils.DEPENDS_REFERER, + ip: Optional[str] = utils.DEPENDS_IP, ) -> schemas.account.TokenResponse: return await account.email_send_code(req_id=req_id, req=req) @@ -94,6 +101,7 @@ async def email_verification( response_model=schemas.RequestIdResponse, ) @utils.measure_time_spend +@req_limit(requests=5, 
in_seconds=60) async def refresh_token( au: utils.ANNOTATED_REFRESH_TOKEN, # check refresh token expiration referer: Optional[str] = utils.DEPENDS_REFERER, diff --git a/src/retk/routes/app_captcha.py b/src/retk/routes/app_captcha.py index 9c056f1..a3725b5 100644 --- a/src/retk/routes/app_captcha.py +++ b/src/retk/routes/app_captcha.py @@ -4,6 +4,7 @@ from fastapi.responses import StreamingResponse from retk.controllers import account +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -18,7 +19,9 @@ status_code=200, ) @utils.measure_time_spend +@req_limit(requests=6, in_seconds=30) async def get_captcha_img( referer: Optional[str] = utils.DEPENDS_REFERER, + ip: Optional[str] = utils.DEPENDS_IP, ) -> StreamingResponse: return account.get_captcha_img() diff --git a/src/retk/routes/browser_extension.py b/src/retk/routes/browser_extension.py index b85d492..4a12db6 100644 --- a/src/retk/routes/browser_extension.py +++ b/src/retk/routes/browser_extension.py @@ -4,6 +4,7 @@ from typing_extensions import Annotated from retk.controllers import schemas, browser_extension +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -44,6 +45,7 @@ async def browser_extension_refresh_token( response_model=schemas.node.NodeResponse, ) @utils.measure_time_spend +@req_limit(requests=4, in_seconds=1) async def post_node_from_browser_extension( au: utils.ANNOTATED_AUTHED_USER_BROWSER_EXTENSION, url: Annotated[str, Form(...)], diff --git a/src/retk/routes/files.py b/src/retk/routes/files.py index d77c443..1d4a113 100644 --- a/src/retk/routes/files.py +++ b/src/retk/routes/files.py @@ -4,6 +4,7 @@ from retk.controllers import schemas from retk.controllers.files import upload_files +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -19,6 +20,7 @@ response_model=schemas.RequestIdResponse, ) @utils.measure_time_spend +@req_limit(requests=5, 
in_seconds=30) async def upload_obsidian_files( au: utils.ANNOTATED_AUTHED_USER, files: List[UploadFile], @@ -36,6 +38,7 @@ async def upload_obsidian_files( response_model=schemas.RequestIdResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=30) async def upload_text_files( au: utils.ANNOTATED_AUTHED_USER, files: List[UploadFile], @@ -53,6 +56,7 @@ async def upload_text_files( response_model=schemas.files.VditorFilesResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def vditor_upload( au: utils.ANNOTATED_AUTHED_USER, req: Request, @@ -72,6 +76,7 @@ async def vditor_upload( response_model=schemas.files.VditorImagesResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def vditor_fetch_image( au: utils.ANNOTATED_AUTHED_USER, req: schemas.files.ImageVditorFetchRequest, @@ -89,6 +94,7 @@ async def vditor_fetch_image( response_model=schemas.files.FileUploadProcessResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=30) async def get_upload_process( au: utils.ANNOTATED_AUTHED_USER, referer: Optional[str] = utils.DEPENDS_REFERER, diff --git a/src/retk/routes/node.py b/src/retk/routes/node.py index eb56212..f84f1bc 100644 --- a/src/retk/routes/node.py +++ b/src/retk/routes/node.py @@ -8,6 +8,7 @@ from retk import const from retk.controllers import schemas from retk.controllers.node import node_ops, search +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -23,6 +24,7 @@ response_model=schemas.node.NodeResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=1) async def post_node( au: utils.ANNOTATED_AUTHED_USER, req: schemas.node.CreateRequest, @@ -40,6 +42,7 @@ async def post_node( response_model=schemas.node.NodeResponse, ) @utils.measure_time_spend +@req_limit(requests=10, in_seconds=1) async def post_quick_node( au: utils.ANNOTATED_AUTHED_USER, req: schemas.node.CreateRequest, @@ -57,6 +60,7 @@ async 
def post_quick_node( response_model=schemas.node.NodesSearchResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def get_search_nodes( au: utils.ANNOTATED_AUTHED_USER, q: str = Query(max_length=const.settings.SEARCH_QUERY_MAX_LENGTH), @@ -141,6 +145,7 @@ async def get_node( response_model=schemas.node.NodesSearchResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def get_at_search( au: utils.ANNOTATED_AUTHED_USER, nid: str = utils.ANNOTATED_NID, @@ -164,6 +169,7 @@ async def get_at_search( response_model=schemas.node.NodesSearchResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def get_recommend_nodes( au: utils.ANNOTATED_AUTHED_USER, nid: str = utils.ANNOTATED_NID, @@ -241,6 +247,7 @@ async def put_node_md( response_model=schemas.RequestIdResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def put_node_favorite( au: utils.ANNOTATED_AUTHED_USER, nid: str = utils.ANNOTATED_NID, @@ -276,6 +283,7 @@ async def delete_node_favorite( status_code=200, ) @utils.measure_time_spend +@req_limit(requests=1, in_seconds=1) async def stream_export_md_node( au: utils.ANNOTATED_AUTHED_USER, nid: str = utils.ANNOTATED_NID, diff --git a/src/retk/routes/recent.py b/src/retk/routes/recent.py index 19b93d9..ce720d7 100644 --- a/src/retk/routes/recent.py +++ b/src/retk/routes/recent.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from retk.controllers import schemas, recent +from retk.core.utils.ratelimiter import req_limit from retk.routes import utils router = APIRouter( @@ -16,6 +17,7 @@ response_model=schemas.RequestIdResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def add_recent_at_node( au: utils.ANNOTATED_AUTHED_USER, req: schemas.recent.AtNodeRequest, @@ -32,6 +34,7 @@ async def add_recent_at_node( response_model=schemas.recent.GetRecentSearchResponse, ) @utils.measure_time_spend +@req_limit(requests=5, in_seconds=1) async def 
def get_ip(request: Request) -> str:
    """Return the best-effort client IP address for ``request``.

    Lookup order:
      1. the first non-empty entry of the ``X-Forwarded-For`` header,
      2. the ``X-Real-IP`` header,
      3. the host of the directly connected peer (``request.client``).

    NOTE(security): both headers are supplied by the client unless a
    reverse proxy in front of the app overwrites them. Whether such a proxy
    exists cannot be told from this function - confirm the deployment before
    relying on the returned value for access control or rate limiting.

    Args:
        request: The incoming request.

    Returns:
        The resolved address, or an empty string when no header is set and
        the server did not provide peer information.
    """
    forwarded_for = request.headers.get('X-Forwarded-For')
    if forwarded_for:
        # The header is a comma separated chain; the left-most entry is the
        # address the first proxy reported for the original client.
        ip = forwarded_for.split(',')[0].strip()
        if ip:
            return ip

    real_ip = request.headers.get('X-Real-IP')
    if real_ip:
        real_ip = real_ip.strip()
        if real_ip:
            return real_ip

    # `request.client` can be None when the server has no peer information;
    # return an empty string instead of failing with AttributeError.
    client = request.client
    if client is None:
        return ""
    return client.host
async def test_req_limit(self, mock_batch_send):
    """Check that ``req_limit`` admits exactly ``requests`` calls per window.

    Five calls keyed by the same ``ip`` must succeed and the sixth one, made
    inside the same one-second window, must be rejected with HTTP 429.

    NOTE(review): this presumably relies on ``is_local_db()`` being falsy in
    this test class, because the limiter is bypassed otherwise - confirm
    against the class set-up.
    """

    @ratelimiter.req_limit(requests=5, in_seconds=1)
    async def limited(ip="123"):
        return True

    try:
        for _ in range(5):
            self.assertTrue(await limited(ip="123"))

        with self.assertRaises(HTTPException) as ctx:
            await limited(ip="123")
        # Pin the status code so an unrelated HTTPException (for example the
        # 400 raised for a missing key) cannot satisfy the test.
        self.assertEqual(429, ctx.exception.status_code)
    finally:
        # Always remove the "analytics" directory next to this file, even if
        # an assertion above failed (presumably created while the rejected
        # call is recorded - confirm).
        shutil.rmtree(os.path.join(os.path.dirname(__file__), "analytics"), ignore_errors=True)