Skip to content

Commit

Permalink
add ability to inject fast-depends when function is called instead of…
Browse files Browse the repository at this point in the history
… created
  • Loading branch information
RuslanUC committed Sep 2, 2024
1 parent 9805f9f commit e5cb9bc
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 20 deletions.
4 changes: 4 additions & 0 deletions config.example.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@
"client_secret": None,
},
}

# Use fast_depends.inject() when route function is called instead of when it is created. Speeds up Yepcord launch.
# May slow down first request to every route by ~50ms.
LAZY_INJECT = False
65 changes: 45 additions & 20 deletions yepcord/rest_api/y_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,65 +24,90 @@
from quart import Blueprint, g
from quart_schema import validate_request, validate_querystring

from yepcord.yepcord.config import Config

validate_funcs = {"body": validate_request, "qs": validate_querystring}


def apply_validator(func: T_route, type_: str, cls: Optional[type], source=None) -> T_route:
applied = getattr(func, "_patches", set())
def apply_validator(src_func: T_route, type_: str, cls: Optional[type], source=None) -> T_route:
applied = getattr(src_func, "_patches", set())

if cls is None or f"validate_{type_}" in applied or type_ not in validate_funcs:
return func
return src_func

kw = {} if source is None else {"source": source}
func = validate_funcs[type_](cls, **kw)(func)
func = validate_funcs[type_](cls, **kw)(src_func)

applied.add(f"validate_{type_}")
setattr(func, "_patches", applied)
if len(applied) > 1:
delattr(src_func, "_patches")

return func


def apply_inject(func: T_route) -> T_route:
applied = getattr(func, "_patches", set())
def apply_inject(src_func: T_route) -> T_route:
applied = getattr(src_func, "_patches", set())

if "fastdepends_inject" in applied:
return func
return src_func

if Config.LAZY_INJECT:
injected_func = None

@wraps(src_func)
async def func(*args, **kwargs):
nonlocal injected_func

if injected_func is None:
injected_func = inject(src_func)

return await injected_func(*args, **kwargs)
else:
func = inject(src_func)

func = inject(func)
applied.add("fastdepends_inject")
setattr(func, "_patches", applied)
if len(applied) > 1:
delattr(src_func, "_patches")

return func


def apply_allow_bots(func: T_route) -> T_route:
applied = getattr(func, "_patches", set())
def apply_allow_bots(src_func: T_route) -> T_route:
applied = getattr(src_func, "_patches", set())
if "allow_bots" in applied:
return func
return src_func

@wraps(func)
@wraps(src_func)
async def wrapped(*args, **kwargs):
g.bots_allowed = True
return await func(*args, **kwargs)
return await src_func(*args, **kwargs)

applied.add("allow_bots")
setattr(func, "_patches", applied)
setattr(wrapped, "_patches", applied)
if len(applied) > 1:
delattr(src_func, "_patches")

return wrapped


def apply_oauth(func: T_route, scopes: list[str]) -> T_route:
applied = getattr(func, "_patches", set())
def apply_oauth(src_func: T_route, scopes: list[str]) -> T_route:
applied = getattr(src_func, "_patches", set())
if "oauth" in applied:
return func
return src_func

@wraps(func)
@wraps(src_func)
async def wrapped(*args, **kwargs):
g.oauth_allowed = True
g.oauth_scopes = set(scopes)
return await func(*args, **kwargs)
return await src_func(*args, **kwargs)

applied.add("oauth")
setattr(func, "_patches", applied)
setattr(wrapped, "_patches", applied)
if len(applied) > 1:
delattr(src_func, "_patches")

return wrapped


Expand Down
18 changes: 18 additions & 0 deletions yepcord/yepcord/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class ConfigModel(BaseModel):
BCRYPT_ROUNDS: int = 15
CAPTCHA: ConfigCaptcha = Field(default_factory=ConfigCaptcha)
CONNECTIONS: ConfigConnections = Field(default_factory=ConfigConnections)
LAZY_INJECT: bool = False

@field_validator("KEY")
def validate_key(cls, value: str) -> str:
Expand Down Expand Up @@ -167,6 +168,23 @@ class Config:


class _Config(Singleton):
DB_CONNECT_STRING: str
MAIL_CONNECT_STRING: str
MIGRATIONS_DIR: str
KEY: str
PUBLIC_HOST: str
GATEWAY_HOST: str
CDN_HOST: str
STORAGE: dict
TENOR_KEY: Optional[str]
MESSAGE_BROKER: dict
REDIS_URL: Optional[str]
GATEWAY_KEEP_ALIVE_DELAY: int
BCRYPT_ROUNDS: int
CAPTCHA: dict
CONNECTIONS: dict
LAZY_INJECT: bool

def update(self, variables: dict) -> _Config:
self.__dict__.update(variables)
return self
Expand Down

0 comments on commit e5cb9bc

Please sign in to comment.