Skip to content

Commit

Permalink
Merge pull request #336 from Olegt0rr/feature/aiogram-inject-once
Browse files Browse the repository at this point in the history
Inject aiogram handlers once
  • Loading branch information
Tishka17 authored Jan 10, 2025
2 parents e74b5b6 + 4a02fa9 commit bb5335f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
58 changes: 46 additions & 12 deletions src/dishka/integrations/aiogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
"CONTAINER_NAME",
"FromDishka",
"inject",
"inject_handler",
"inject_router",
"setup_dishka",
]

import warnings
from collections.abc import Awaitable, Callable, Container
from functools import partial
from inspect import Parameter, signature
from typing import Any, Final, ParamSpec, TypeVar, cast

Expand Down Expand Up @@ -61,6 +65,14 @@ async def __call__(


class AutoInjectMiddleware(BaseMiddleware):
def __init__(self):
warnings.warn(
f"{self.__class__.__name__} is slow, "
"use `setup_dishka` instead if you care about performance",
UserWarning,
stacklevel=2,
)

async def __call__(
self,
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
Expand All @@ -71,15 +83,7 @@ async def __call__(
if is_dishka_injected(old_handler.callback):
return await handler(event, data)

new_handler = HandlerObject(
callback=inject(old_handler.callback),
filters=old_handler.filters,
flags=old_handler.flags,
)
old_handler.callback = new_handler.callback
old_handler.params = new_handler.params
old_handler.varkw = new_handler.varkw
old_handler.awaitable = new_handler.awaitable
inject_handler(old_handler)
return await handler(event, data)


Expand All @@ -90,9 +94,39 @@ def setup_dishka(
auto_inject: bool = False,
) -> None:
middleware = ContainerMiddleware(container)
auto_inject_middleware = AutoInjectMiddleware()

for observer in router.observers.values():
observer.outer_middleware(middleware)
if auto_inject and observer.event_name != "update":
observer.middleware(auto_inject_middleware)

if auto_inject:
callback = partial(inject_router, router=router)
router.startup.register(callback)


def inject_router(router: Router) -> None:
"""Inject dishka to the router handlers."""
for observer in router.observers.values():
if observer.event_name == "update":
continue

for handler in observer.handlers:
if not is_dishka_injected(handler.callback):
inject_handler(handler)


def inject_handler(handler: HandlerObject) -> HandlerObject:
"""Inject dishka for callback in aiogram's handler."""
# temp_handler is used to apply original __post_init__ processing
# for callback object wrapped by injector
temp_handler = HandlerObject(
callback=inject(handler.callback),
filters=handler.filters,
flags=handler.flags,
)

# since injector modified callback and params,
# we should update them in the original handler
handler.callback = temp_handler.callback
handler.params = temp_handler.params

return handler
28 changes: 28 additions & 0 deletions tests/integrations/aiogram/test_aiogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ async def dishka_app(handler, provider):
await container.close()


@asynccontextmanager
async def dishka_early_inject_app(handler, provider):
dp = Dispatcher()

# first apply auto_inject
container = make_async_container(provider)
setup_dishka(container, router=dp, auto_inject=True)

# then register raw handler
dp.message.register(handler)

await dp.emit_startup()
yield dp
await dp.emit_shutdown()
await container.close()


@asynccontextmanager
async def dishka_auto_app(handler, provider):
dp = Dispatcher()
Expand Down Expand Up @@ -120,3 +137,14 @@ async def handle_get_async_container(
async def test_get_async_container(bot, app_provider: AppProvider):
async with dishka_app(handle_get_async_container, app_provider) as dp:
await send_message(bot, dp)


@pytest.mark.asyncio
async def test_early_autoinject(bot, app_provider: AppProvider):
async with dishka_early_inject_app(
handler=handle_with_request,
provider=app_provider,
) as dp:
await send_message(bot, dp)
app_provider.mock.assert_called_with(REQUEST_DEP_VALUE)
app_provider.request_released.assert_called_once()

0 comments on commit bb5335f

Please sign in to comment.