Skip to content

Commit

Permalink
fix: issue #186
Browse files Browse the repository at this point in the history
  • Loading branch information
vastsa committed Jul 28, 2024
1 parent 8136c85 commit fafc56b
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
8 changes: 4 additions & 4 deletions apps/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def get_random_code(style='num'):
return code


# 错误IP限制器
error_ip_limit = IPRateLimit(count=settings.errorCount, minutes=settings.errorMinute)
# 上传文件限制器
upload_ip_limit = IPRateLimit(count=settings.uploadCount, minutes=settings.errorMinute)
ip_limit = {
'error': IPRateLimit(count=settings.uploadCount, minutes=settings.errorMinute),
'upload': IPRateLimit(count=settings.errorCount, minutes=settings.errorMinute)
}
23 changes: 12 additions & 11 deletions apps/base/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from apps.admin.depends import admin_required
from apps.base.models import FileCodes
from apps.base.pydantics import SelectFileModel
from apps.base.utils import get_expire_info, get_file_path_name, error_ip_limit, upload_ip_limit
from apps.base.utils import get_expire_info, get_file_path_name, ip_limit
from core.response import APIResponse
from core.settings import settings
from core.storage import storages, FileStorageInterface
Expand All @@ -22,7 +22,7 @@

# 分享文本的API
@share_api.post('/text/', dependencies=[Depends(admin_required)])
async def share_text(text: str = Form(...), expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), ip: str = Depends(upload_ip_limit)):
async def share_text(text: str = Form(...), expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), ip: str = Depends(ip_limit['upload'])):
# 获取过期信息
expired_at, expired_count, used_count, code = await get_expire_info(expire_value, expire_style)
# 创建一个新的FileCodes实例
Expand All @@ -36,7 +36,7 @@ async def share_text(text: str = Form(...), expire_value: int = Form(default=1,
prefix='文本分享'
)
# 添加IP到限制列表
upload_ip_limit.add_ip(ip)
ip_limit['upload'].add_ip(ip)
# 返回API响应
return APIResponse(detail={
'code': code,
Expand All @@ -45,7 +45,8 @@ async def share_text(text: str = Form(...), expire_value: int = Form(default=1,

# 分享文件的API
@share_api.post('/file/', dependencies=[Depends(admin_required)])
async def share_file(expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), file: UploadFile = File(...), ip: str = Depends(upload_ip_limit)):
async def share_file(expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), file: UploadFile = File(...),
ip: str = Depends(ip_limit['upload'])):
# 检查文件大小是否超过限制
if file.size > int(settings.uploadSize):
raise HTTPException(status_code=403, detail=f'文件大小超过限制,最大为{settings.uploadSize}字节')
Expand All @@ -71,7 +72,7 @@ async def share_file(expire_value: int = Form(default=1, gt=0), expire_style: st
used_count=used_count,
)
# 添加IP到限制列表
upload_ip_limit.add_ip(ip)
ip_limit['upload'].add_ip(ip)
# 返回API响应
return APIResponse(detail={
'code': code,
Expand All @@ -94,14 +95,14 @@ async def get_code_file_by_code(code, check=True):

# 获取文件的API
@share_api.get('/select/')
async def get_code_file(code: str, ip: str = Depends(error_ip_limit)):
async def get_code_file(code: str, ip: str = Depends(ip_limit['error'])):
file_storage: FileStorageInterface = storages[settings.file_storage]()
# 获取文件
has, file_code = await get_code_file_by_code(code)
# 检查文件是否存在
if not has:
# 添加IP到限制列表
error_ip_limit.add_ip(ip)
ip_limit['error'].add_ip(ip)
# 返回API响应
return APIResponse(code=404, detail=file_code)
# 更新文件的使用次数和过期次数
Expand All @@ -116,14 +117,14 @@ async def get_code_file(code: str, ip: str = Depends(error_ip_limit)):

# 选择文件的API
@share_api.post('/select/')
async def select_file(data: SelectFileModel, ip: str = Depends(error_ip_limit)):
async def select_file(data: SelectFileModel, ip: str = Depends(ip_limit['error'])):
file_storage: FileStorageInterface = storages[settings.file_storage]()
# 获取文件
has, file_code = await get_code_file_by_code(data.code)
# 检查文件是否存在
if not has:
# 添加IP到限制列表
error_ip_limit.add_ip(ip)
ip_limit['error'].add_ip(ip)
# 返回API响应
return APIResponse(code=404, detail=file_code)
# 更新文件的使用次数和过期次数
Expand All @@ -143,13 +144,13 @@ async def select_file(data: SelectFileModel, ip: str = Depends(error_ip_limit)):

# 下载文件的API
@share_api.get('/download')
async def download_file(key: str, code: str, ip: str = Depends(error_ip_limit)):
async def download_file(key: str, code: str, ip: str = Depends(ip_limit['error'])):
file_storage: FileStorageInterface = storages[settings.file_storage]()
# 检查token是否有效
is_valid = await get_select_token(code) == key
if not is_valid:
# 添加IP到限制列表
error_ip_limit.add_ip(ip)
ip_limit['error'].add_ip(ip)
# 获取文件
has, file_code = await get_code_file_by_code(code, False)
# 检查文件是否存在
Expand Down
6 changes: 3 additions & 3 deletions core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tortoise.expressions import Q

from apps.base.models import FileCodes
from apps.base.utils import error_ip_limit, upload_ip_limit
from apps.base.utils import ip_limit
from core.settings import settings
from core.storage import FileStorageInterface, storages
from core.utils import get_now
Expand All @@ -17,8 +17,8 @@ async def delete_expire_files():
file_storage: FileStorageInterface = storages[settings.file_storage]()
while True:
try:
await error_ip_limit.remove_expired_ip()
await upload_ip_limit.remove_expired_ip()
await ip_limit['error'].remove_expired_ip()
await ip_limit['upload'].remove_expired_ip()
expire_data = await FileCodes.filter(Q(expired_at__lt=await get_now()) | Q(expired_count=0)).all()
for exp in expire_data:
await file_storage.delete_file(exp)
Expand Down
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from fastapi.staticfiles import StaticFiles
from tortoise.contrib.fastapi import register_tortoise

from apps.base.depends import IPRateLimit
from apps.base.models import KeyValue
from apps.base.utils import ip_limit
from apps.base.views import share_api
from apps.admin.views import admin_api
from core.response import APIResponse
Expand Down Expand Up @@ -59,6 +61,10 @@ async def startup_event():
# 读取用户配置
user_config, created = await KeyValue.get_or_create(key='settings', defaults={'value': DEFAULT_CONFIG})
settings.user_config = user_config.value
ip_limit['error'].minutes = settings.errorMinute
ip_limit['error'].count = settings.errorCount
ip_limit['upload'].minutes = settings.uploadMinute
ip_limit['upload'].count = settings.uploadCount


@app.get('/')
Expand Down

0 comments on commit fafc56b

Please sign in to comment.