Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance anonymous ratelimiting with dynamic thresholds and max size configuration #2086

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2066.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement `anonymous-ratelimit`
1 change: 1 addition & 0 deletions changes/2067.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Move `anonymous_ratelimit` to shared_config
1 change: 1 addition & 0 deletions changes/2074.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement hit counter for detecting `suspicious_ips`
1 change: 1 addition & 0 deletions changes/2075.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `get_hot_anonymous_clients` to the RateLimit API
1 change: 1 addition & 0 deletions changes/2083.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Truncate `suspicious_ips` when it exceeds the max_size
18 changes: 18 additions & 0 deletions src/ai/backend/client/func/ratelimit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ..request import Request
from .base import BaseFunction, api_function

__all__ = ("RateLimit",)


class RateLimit(BaseFunction):
    """
    Client-side functions for the rate-limiting API.
    """

    @api_function
    @classmethod
    async def get_hot_anonymous_clients(cls):
        """
        Query the hot anonymous clients tracked by the rate limiter.

        Sends ``GET /ratelimit/hot_anonymous_clients`` and returns the
        JSON-decoded response body.
        """
        request = Request("GET", "/ratelimit/hot_anonymous_clients")
        async with request.fetch() as response:
            return await response.json()
3 changes: 3 additions & 0 deletions src/ai/backend/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class BaseSession(metaclass=abc.ABCMeta):
"Service",
"Model",
"QuotaScope",
"RateLimit",
)

aiohttp_session: aiohttp.ClientSession
Expand Down Expand Up @@ -306,6 +307,7 @@ def __init__(
from .func.manager import Manager
from .func.model import Model
from .func.quota_scope import QuotaScope
from .func.ratelimit import RateLimit
from .func.resource import Resource
from .func.scaling_group import ScalingGroup
from .func.server_log import ServerLog
Expand Down Expand Up @@ -343,6 +345,7 @@ def __init__(
self.Service = Service
self.Model = Model
self.QuotaScope = QuotaScope
self.RateLimit = RateLimit

@property
def proxy_mode(self) -> bool:
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/common/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Callable, Mapping, TypeVar

import aiohttp
from aiohttp import web
from async_timeout import timeout as _timeout

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,3 +53,12 @@ def find_free_port(bind_addr: str = "127.0.0.1") -> int:
s.bind((bind_addr, 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def get_client_ip(request: web.Request) -> str | None:
    """
    Determine the client IP address of an incoming request.

    The sources are consulted in this order, and the first non-empty value wins:

    1. the ``X-Forwarded-For`` request header (returned as-is),
    2. the host part of the transport's ``peername``,
    3. ``request.remote``.

    Returns ``None`` when none of the sources yields a value.
    """
    client_ip = request.headers.get("X-Forwarded-For")
    if not client_ip and request.transport:
        # get_extra_info() returns None when the transport has no peer
        # information; guard it so that we fall through to request.remote
        # instead of raising TypeError on the subscript.
        peername = request.transport.get_extra_info("peername")
        if peername:
            client_ip = peername[0]
    if not client_ip:
        client_ip = request.remote
    return client_ip
119 changes: 103 additions & 16 deletions src/ai/backend/manager/api/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from decimal import Decimal
from typing import Final, Iterable, Tuple

import aiohttp_cors
import attrs
from aiohttp import web
from aiotools import apartial

from ai.backend.common import redis_helper
from ai.backend.common.defs import REDIS_RLIM_DB
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.networking import get_client_ip
from ai.backend.common.types import RedisConnectionInfo
from ai.backend.manager.api.auth import superadmin_required
from ai.backend.manager.api.manager import READ_ALLOWED, server_status_required

from .context import RootContext
from .exceptions import RateLimitExceeded
Expand All @@ -27,19 +31,40 @@
# last-minute and first-minute bursts between the intervals.

_rlim_script = """
local access_key = KEYS[1]
local id_type = KEYS[1]
local id_value = KEYS[2]
local namespaced_id = id_type .. ":" .. id_value
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local request_id = tonumber(redis.call('INCR', '__request_id'))
if request_id >= 1e12 then
redis.call('SET', '__request_id', 1)
end
if redis.call('EXISTS', access_key) == 1 then
redis.call('ZREMRANGEBYSCORE', access_key, 0, now - window)
if redis.call('EXISTS', namespaced_id) == 1 then
redis.call('ZREMRANGEBYSCORE', namespaced_id, 0, now - window)
end
redis.call('ZADD', access_key, now, tostring(request_id))
redis.call('EXPIRE', access_key, window)
return redis.call('ZCARD', access_key)
redis.call('ZADD', namespaced_id, now, tostring(request_id))
redis.call('EXPIRE', namespaced_id, window)

local rolling_count = redis.call('ZCARD', namespaced_id)

if id_type == "ip" then
local rate_limit = tonumber(ARGV[3])
local suspicious_ips_maxsize = tonumber(ARGV[4])
local suspicious_ips_threshold_ratio = tonumber(ARGV[5])

-- Add the IP address to "suspicious_ips" only if rolling_count is greater than the threshold
if rolling_count >= rate_limit * suspicious_ips_threshold_ratio then
redis.call('ZADD', 'suspicious_ips', rolling_count, id_value)

local current_size = redis.call('ZCARD', 'suspicious_ips')
if current_size > suspicious_ips_maxsize then
redis.call('ZREMRANGEBYRANK', 'suspicious_ips', 0, 0)
end
end
end

return rolling_count
"""


Expand All @@ -50,17 +75,18 @@ async def rlim_middleware(
handler: WebRequestHandler,
) -> web.StreamResponse:
# This is a global middleware: request.app is the root app.
app_ctx: PrivateContext = app["ratelimit.context"]
app_ctx: RateLimitContext = app["ratelimit.context"]
now = Decimal(time.time()).quantize(_time_prec)
rr = app_ctx.redis_rlim

if request["is_authorized"]:
rate_limit = request["keypair"]["rate_limit"]
access_key = request["keypair"]["access_key"]
ret = await redis_helper.execute_script(
rr,
"ratelimit",
_rlim_script,
[access_key],
["access_key", access_key],
[str(now), str(_rlim_window)],
)
if ret is None:
Expand All @@ -76,23 +102,61 @@ async def rlim_middleware(
response.headers["X-RateLimit-Window"] = str(_rlim_window)
return response
else:
# No checks for rate limiting for non-authorized queries.
response = await handler(request)
response.headers["X-RateLimit-Limit"] = "1000"
response.headers["X-RateLimit-Remaining"] = "1000"
root_ctx: RootContext = app["_root.context"]
anonymous_ratelimiter = root_ctx.shared_config["anonymous_ratelimiter"]

ip_address = get_client_ip(request)

if not ip_address or anonymous_ratelimiter is None:
# No checks for rate limiting.
response = await handler(request)
# Arbitrary number for indicating no rate limiting.
response.headers["X-RateLimit-Limit"] = "1000"
response.headers["X-RateLimit-Remaining"] = "1000"
else:
rate_limit, suspicious_ips_maxsize, suspicious_ips_threshold_ratio = (
anonymous_ratelimiter["rlimit"],
anonymous_ratelimiter["suspicious_ips_maxsize"],
anonymous_ratelimiter["suspicious_ips_threshold_ratio"],
)
ret = await redis_helper.execute_script(
rr,
"ratelimit",
_rlim_script,
["ip", ip_address],
[
str(now),
str(_rlim_window),
str(rate_limit),
str(suspicious_ips_maxsize),
str(suspicious_ips_threshold_ratio),
],
)
if ret is None:
remaining = rate_limit
else:
rolling_count = int(ret)
remaining = rate_limit - rolling_count
if remaining < 0:
raise RateLimitExceeded

response = await handler(request)
response.headers["X-RateLimit-Limit"] = str(rate_limit)
response.headers["X-RateLimit-Remaining"] = str(remaining)

response.headers["X-RateLimit-Window"] = str(_rlim_window)
return response


@attrs.define(slots=True, auto_attribs=True, init=False)
class PrivateContext:
class RateLimitContext:
redis_rlim: RedisConnectionInfo
redis_rlim_script: str


async def init(app: web.Application) -> None:
root_ctx: RootContext = app["_root.context"]
app_ctx: PrivateContext = app["ratelimit.context"]
app_ctx: RateLimitContext = app["ratelimit.context"]
app_ctx.redis_rlim = redis_helper.get_redis_object(
root_ctx.shared_config.data["redis"], name="ratelimit", db=REDIS_RLIM_DB
)
Expand All @@ -101,8 +165,26 @@ async def init(app: web.Application) -> None:
)


@server_status_required(READ_ALLOWED)
@superadmin_required
async def get_hot_anonymous_clients(request: web.Request) -> web.Response:
    """
    Return the anonymous client IP addresses recorded as suspicious.

    Reads the whole ``suspicious_ips`` sorted set from the rate-limit Redis
    database and responds with a JSON object that maps each member (decoded
    to text) to its score.
    """
    log.info("RATELIMIT.GET_HOT_ANONYMOUS_CLIENTS ()")
    ctx: RateLimitContext = request.app["ratelimit.context"]
    entries: list[tuple[bytes, float]] = await redis_helper.execute(
        ctx.redis_rlim,
        lambda r: r.zrange("suspicious_ips", 0, -1, withscores=True),
    )
    # Members arrive as raw bytes; decode them so they can be used as JSON keys.
    scores_by_ip: dict[str, float] = {}
    for member, score in entries:
        scores_by_ip[member.decode()] = score
    return web.json_response(scores_by_ip, status=200)


async def shutdown(app: web.Application) -> None:
app_ctx: PrivateContext = app["ratelimit.context"]
app_ctx: RateLimitContext = app["ratelimit.context"]
await redis_helper.execute(app_ctx.redis_rlim, lambda r: r.flushdb())
await app_ctx.redis_rlim.close()

Expand All @@ -112,7 +194,12 @@ def create_app(
) -> Tuple[web.Application, Iterable[WebMiddleware]]:
app = web.Application()
app["api_versions"] = (1, 2, 3, 4)
app["ratelimit.context"] = PrivateContext()
app["ratelimit.context"] = RateLimitContext()
app["prefix"] = "ratelimit"
cors = aiohttp_cors.setup(app, defaults=default_cors_options)
add_route = app.router.add_route
cors.add(add_route("GET", "/hot_anonymous_clients", get_hot_anonymous_clients))

app.on_startup.append(init)
app.on_shutdown.append(shutdown)
# middleware must be wrapped by web.middleware at the outermost level.
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,12 @@ def container_registry_serialize(v: dict[str, Any]) -> dict[str, str]:
},
).allow_extra("*"),
t.Key("roundrobin_states", default=None): t.Null | tx.RoundRobinStatesJSONString,
t.Key("anonymous_ratelimiter", default=None): t.Null
| t.Dict({
t.Key("rlimit"): t.ToInt(),
t.Key("suspicious_ips_maxsize", default=1000): t.Null | t.ToInt(),
t.Key("suspicious_ips_threshold_ratio", default=0.8): t.Null | t.ToFloat(),
}),
}).allow_extra("*")

_volume_defaults: dict[str, Any] = {
Expand Down
10 changes: 1 addition & 9 deletions src/ai/backend/web/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ai.backend.client.config import APIConfig
from ai.backend.client.session import AsyncSession as APISession
from ai.backend.common.networking import get_client_ip
from ai.backend.common.web.session import get_session

from . import user_agent
Expand Down Expand Up @@ -77,15 +78,6 @@ async def get_anonymous_session(
return APISession(config=api_config, proxy_mode=True)


def get_client_ip(request: web.Request) -> Optional[str]:
    """
    Determine the client IP address of an incoming request.

    The ``X-Forwarded-For`` header is preferred (returned as-is), then the
    host part of the transport's ``peername``, and finally ``request.remote``.

    Returns ``None`` when none of the sources yields a value.
    """
    client_ip = request.headers.get("X-Forwarded-For")
    if not client_ip and request.transport:
        # get_extra_info() returns None when no peer information is available;
        # skip the subscript in that case and fall back to request.remote.
        peername = request.transport.get_extra_info("peername")
        if peername:
            client_ip = peername[0]
    if not client_ip:
        client_ip = request.remote
    return client_ip


def fill_forwarding_hdrs_to_api_session(
request: web.Request,
api_session: APISession,
Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
from ai.backend.client.session import AsyncSession as APISession
from ai.backend.common import config, redis_helper
from ai.backend.common.logging import BraceStyleAdapter, Logger
from ai.backend.common.networking import get_client_ip
from ai.backend.common.types import LogSeverity
from ai.backend.common.web.session import extra_config_headers, get_session
from ai.backend.common.web.session import setup as setup_session
from ai.backend.common.web.session.redis_storage import RedisStorage

from . import __version__, user_agent
from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip
from .auth import fill_forwarding_hdrs_to_api_session
from .config import config_iv
from .proxy import decrypt_payload, web_handler, web_plugin_handler, websocket_handler
from .stats import WebStats, track_active_handlers, view_stats
Expand Down
Loading