From bb84f2bac72e638883f9186acbc6b0bd15014b57 Mon Sep 17 00:00:00 2001 From: hawang-wish <130547790+hawang-wish@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:38:22 +0800 Subject: [PATCH 1/2] Added on_send_error middleware hook --- taskiq/abc/middleware.py | 21 +++++++- taskiq/kicker.py | 22 +++++++- tests/middlewares/test_hooks.py | 96 +++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 tests/middlewares/test_hooks.py diff --git a/taskiq/abc/middleware.py b/taskiq/abc/middleware.py index 5edcf80..bc57ae2 100644 --- a/taskiq/abc/middleware.py +++ b/taskiq/abc/middleware.py @@ -2,7 +2,7 @@ if TYPE_CHECKING: # pragma: no cover # pragma: no cover from taskiq.abc.broker import AsyncBroker - from taskiq.message import TaskiqMessage + from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.result import TaskiqResult @@ -126,3 +126,22 @@ def on_error( :param result: returned value. :param exception: found exception. """ + + def on_send_error( + self, + message: "TaskiqMessage", + broker_message: "BrokerMessage", + exception: BaseException, + ) -> "Union[Union[bool, None], Coroutine[Any, Any, Union[bool, None]]]": + """ + This function is called when exception is raised while sending a message. + + In most cases, it would be a connection issue from the broker. + + Any exceptions occurred by broker's formatter will not trigger this. + + :param message: the sending TaskiqMessage (not BrokerMessage) + :param broker_message: the sending BrokerMessage (not TaskiqMessage) + :param exception: exception, not yet wrapped with SendTaskError + :return: True if the error should be omitted, False or None otherwise. + """ diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 889dee0..4979d9f 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -134,10 +134,28 @@ async def kiq( for middleware in self.broker.middlewares: if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: message = await maybe_awaitable(middleware.pre_send(message)) + + broker_message = self.broker.formatter.dumps(message) try: - await self.broker.kick(self.broker.formatter.dumps(message)) + await self.broker.kick(broker_message) except Exception as exc: - raise SendTaskError from exc + omitting = False + for middleware in reversed(self.broker.middlewares): + if middleware.__class__.on_send_error != TaskiqMiddleware.on_send_error: + omitting = ( + bool( + await maybe_awaitable( + middleware.on_send_error( + message, + broker_message, + exc, + ), + ), + ) + or omitting + ) + if not omitting: + raise SendTaskError from exc for middleware in self.broker.middlewares: if middleware.__class__.post_send != TaskiqMiddleware.post_send: diff --git a/tests/middlewares/test_hooks.py b/tests/middlewares/test_hooks.py new file mode 100644 index 0000000..1ff2002 --- /dev/null +++ b/tests/middlewares/test_hooks.py @@ -0,0 +1,96 @@ +import asyncio + +import pytest + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.exceptions import SendTaskError +from taskiq.message import BrokerMessage, TaskiqMessage + + +@pytest.mark.anyio +async def test_on_send_error() -> None: + caught = [] + + class _TestMiddleware(TaskiqMiddleware): + def on_send_error( + self, + message: "TaskiqMessage", + broker_message: "BrokerMessage", + exception: BaseException, + ) -> bool: + caught.append(1) + return True + + broker = InMemoryBroker().with_middlewares(_TestMiddleware()) + + broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore + + await broker.startup() + await broker.task(lambda: None).kiq() + await broker.shutdown() + + assert caught == [1] + + +@pytest.mark.anyio +async def test_on_send_error_raise() -> None: + caught = [] + + class _TestMiddleware(TaskiqMiddleware): + def on_send_error( + self, + message: "TaskiqMessage", + broker_message: "BrokerMessage", + exception: BaseException, + ) -> None: + caught.append(0) + + broker = InMemoryBroker().with_middlewares(_TestMiddleware()) + + broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore + + await broker.startup() + + with pytest.raises(SendTaskError): + await broker.task(lambda: None).kiq() + + await broker.shutdown() + + assert caught == [0] + + +@pytest.mark.anyio +async def test_on_send_error_inverted() -> None: + caught = [] + + class _TestMiddleware1(TaskiqMiddleware): + def on_send_error( + self, + message: "TaskiqMessage", + broker_message: "BrokerMessage", + exception: BaseException, + ) -> bool: + caught.append(1) + return True + + class _TestMiddleware2(TaskiqMiddleware): + async def on_send_error( + self, + message: "TaskiqMessage", + broker_message: "BrokerMessage", + exception: BaseException, + ) -> bool: + await asyncio.sleep(0) + caught.append(2) + return True + + broker = InMemoryBroker().with_middlewares(_TestMiddleware1(), _TestMiddleware2()) + + broker.kick = lambda *args, **kwargs: (_ for _ in ()).throw(Exception("test")) # type: ignore + + await broker.startup() + await broker.task(lambda: None).kiq() + await broker.shutdown() + + assert caught == [2, 1] From 881c4d6766e4423ca1a086efb0971c0c4aeafbb0 Mon Sep 17 00:00:00 2001 From: Harry Wang Date: Tue, 3 Sep 2024 13:05:19 +0800 Subject: [PATCH 2/2] Fixed that the exception param typing should be Exception --- taskiq/abc/middleware.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/taskiq/abc/middleware.py b/taskiq/abc/middleware.py index bc57ae2..d594cb8 100644 --- a/taskiq/abc/middleware.py +++ b/taskiq/abc/middleware.py @@ -131,7 +131,7 @@ def on_send_error( self, message: "TaskiqMessage", broker_message: "BrokerMessage", - exception: BaseException, + exception: Exception, ) -> "Union[Union[bool, None], Coroutine[Any, Any, Union[bool, None]]]": """ This function is called when exception is raised while sending a message. @@ -140,6 +140,9 @@ def on_send_error( Any exceptions occurred by broker's formatter will not trigger this. + SystemExit, KeyboardInterrupt as well as other BaseExceptions will not + be caught here as it would be essentially meaningless to catch them. + :param message: the sending TaskiqMessage (not BrokerMessage) :param broker_message: the sending BrokerMessage (not TaskiqMessage) :param exception: exception, not yet wrapped with SendTaskError