Skip to content

Commit

Permalink
Rename on_lost -> on_reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
levsh committed Dec 12, 2024
1 parent cbc685f commit 2ad7a00
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 43 deletions.
72 changes: 41 additions & 31 deletions rmqaio/rmqaio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from inspect import iscoroutine, iscoroutinefunction
from itertools import chain, repeat
from ssl import SSLContext
from typing import Any, Callable, Coroutine, Iterable, cast
from typing import Any, Callable, Coroutine, Iterable, Literal, cast
from uuid import uuid4

import aiormq
Expand Down Expand Up @@ -229,21 +229,20 @@ def __init__(
if self._key not in self.__shared:
self.__shared[self._key] = {
"refs": 0,
"objs": 0,
"url": urls[0],
"ssl_context": ssl_contexts[0],
"iter": _LoopIter(list(zip(urls, ssl_contexts))),
"conn": None,
"connect_lock": Lock(),
"instances": {},
}

shared: dict = self.__shared[self._key]
shared["objs"] += 1
shared[self] = {
shared["instances"][self] = {
"on_open": {},
"on_lost": {},
"on_reconnect": {},
"on_close": {},
"callback_tasks": {"on_open": {}, "on_lost": {}, "on_close": {}},
"callback_tasks": {"on_open": {}, "on_reconnect": {}, "on_close": {}},
}
self._shared = shared

Expand All @@ -259,11 +258,8 @@ def __del__(self):
if getattr(self, "_key", None):
if self._conn and not self.is_closed:
logger.warning(_("%s unclosed"), self)
shared = self._shared
shared["objs"] -= 1
if self in shared:
shared.pop(self, None)
if shared["objs"] == 0:
self._shared["instances"].pop(self, None)
if not self._shared["instances"]:
self.__shared.pop(self._key, None)

@property
Expand All @@ -290,11 +286,15 @@ def _refs(self) -> int:
def _refs(self, value: int):
self._shared["refs"] = value

async def _execute_callbacks(self, tp: str, reraise: bool | None = None):
async def _execute_callbacks(
self,
tp: Literal["on_open", "on_reconnect", "on_close"],
reraise: bool | None = None,
):
async def fn(name, callback):
logger.debug(_("%s execute callback[tp=%s, name=%s, reraise=%s]"), self, tp, name, reraise)

self._shared[self]["callback_tasks"][tp][name] = current_task()
self._shared["instances"][self]["callback_tasks"][tp][name] = current_task()
try:
if iscoroutinefunction(callback):
await callback()
Expand All @@ -307,20 +307,30 @@ async def fn(name, callback):
if reraise:
raise e
finally:
self._shared[self]["callback_tasks"][tp].pop(name, None)
self._shared["instances"][self]["callback_tasks"][tp].pop(name, None)

for name, callback in tuple(self._shared[self][tp].items()):
for name, callback in tuple(self._shared["instances"][self][tp].items()):
await create_task(fn(name, callback))

def set_callback(self, tp: str, name: Hashable, callback: Callable):
def set_callback(
self,
tp: Literal["on_open", "on_reconnect", "on_close"],
name: Hashable,
callback: Callable,
):
logger.debug(_("%s set callback[tp=%s, name=%s, callback=%s]"), self, tp, name, callback)
if shared := self._shared.get(self):
if shared := self._shared["instances"].get(self):
if tp not in shared:
raise ValueError("invalid callback type")
shared[tp][name] = callback

def remove_callback(self, tp: str, name: Hashable, cancel: bool | None = None):
if shared := self._shared.get(self):
def remove_callback(
self,
tp: Literal["on_open", "on_reconnect", "on_close"],
name: Hashable,
cancel: bool | None = None,
):
if shared := self._shared["instances"].get(self):
if tp not in shared:
raise ValueError("invalid callback type")
if name in shared[tp]:
Expand All @@ -331,16 +341,16 @@ def remove_callback(self, tp: str, name: Hashable, cancel: bool | None = None):
task.cancel()

def remove_callbacks(self, cancel: bool | None = None):
if self in self._shared:
if shared := self._shared["instances"].get(self):
if cancel:
for tp in ("on_open", "on_lost", "on_close"):
for task in self._shared[self]["callback_tasks"][tp].values():
for tp in ("on_open", "on_reconnect", "on_close"):
for task in shared["callback_tasks"][tp].values():
task.cancel()
self._shared[self] = {
self._shared["instances"][self] = {
"on_open": {},
"on_lost": {},
"on_reconnect": {},
"on_close": {},
"callback_tasks": {"on_open": {}, "on_lost": {}, "on_close": {}},
"callback_tasks": {"on_open": {}, "on_reconnect": {}, "on_close": {}},
}

def __str__(self):
Expand Down Expand Up @@ -377,7 +387,7 @@ async def _watcher(self):
await self._channel.close()
self._refs -= 1
self._reconnect_task = create_task(self.open(retry_timeouts=iter(chain((0, 3), repeat(5)))))
await self._execute_callbacks("on_lost")
await self._execute_callbacks("on_reconnect")

async def _connect(self):
self._shared["url"], self._shared["ssl_context"] = next(self._shared["iter"])
Expand Down Expand Up @@ -799,8 +809,8 @@ def __post_init__(self):
if self.conn_factory:
object.__setattr__(self, "conn", self.conn_factory())
self.conn.set_callback(
"on_lost",
f"on_lost_queue_[{self.name}]_cleanup_consumer",
"on_reconnect",
f"on_reconnect_queue_[{self.name}]_cleanup_consumer",
lambda: object.__setattr__(self, "consumer", None),
)
self.conn.set_callback(
Expand Down Expand Up @@ -1019,8 +1029,8 @@ async def consume(
logger.info(_("consume %s"), self)

self.conn.set_callback(
"on_lost",
f"on_lost_queue_[{self.name}]_consume",
"on_reconnect",
f"on_reconnect_queue_[{self.name}]_consume",
partial(
_retry(
retry_timeouts=repeat(retry_timeout),
Expand All @@ -1044,7 +1054,7 @@ async def stop_consume(self, timeout: int | None = None):

logger.info(_("stop consume %s"), self)

self.conn.remove_callback("on_lost", f"on_lost_queue_[{self.name}]_consume", cancel=True)
self.conn.remove_callback("on_reconnect", f"on_reconnect_queue_[{self.name}]_consume", cancel=True)

if self.consumer and not self.consumer.channel.is_closed:
await self.consumer.channel.basic_cancel(self.consumer.consumer_tag, timeout=timeout)
Expand Down
28 changes: 16 additions & 12 deletions tests/test_rmqaio.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ async def test__init(self):
assert conn.name == "abc"
assert conn._key is not None
assert conn._key in conn._Connection__shared
assert conn._shared["objs"] == 1
assert conn in conn._shared
assert conn._shared[conn] == {
assert len(conn._shared["instances"]) == 1
assert conn in conn._shared["instances"]
assert conn._shared["instances"][conn] == {
"on_open": {},
"on_lost": {},
"on_reconnect": {},
"on_close": {},
"callback_tasks": {"on_close": {}, "on_lost": {}, "on_open": {}},
"callback_tasks": {"on_close": {}, "on_reconnect": {}, "on_open": {}},
}
assert conn._conn is None
finally:
Expand Down Expand Up @@ -182,26 +182,30 @@ def cb_with_exception():
conn = rmqaio.Connection("amqp://admin@example.com")
try:
conn.set_callback("on_open", "test", on_open_cb)
assert conn._shared[conn]["on_open"]["test"] == on_open_cb
assert conn._shared["instances"][conn]["on_open"]["test"] == on_open_cb
conn.set_callback("on_close", "test", on_close_cb)
assert conn._shared[conn]["on_close"]["test"] == on_close_cb
assert conn._shared["instances"][conn]["on_close"]["test"] == on_close_cb

await conn._execute_callbacks("on_open")
assert on_open_cb_flag is True
assert on_close_cb_flag is False

conn.remove_callback("on_open", "test")
assert conn._shared[conn]["on_open"] == {}
assert conn._shared[conn]["on_close"]["test"] == on_close_cb
assert conn._shared["instances"][conn]["on_open"] == {}
assert conn._shared["instances"][conn]["on_close"]["test"] == on_close_cb

await conn._execute_callbacks("on_close")
assert on_open_cb_flag is True
assert on_close_cb_flag is True

conn.remove_callbacks(cancel=True)
assert conn._shared[conn]["on_open"] == {}
assert conn._shared[conn]["on_close"] == {}
assert conn._shared[conn]["callback_tasks"] == {"on_close": {}, "on_lost": {}, "on_open": {}}
assert conn._shared["instances"][conn]["on_open"] == {}
assert conn._shared["instances"][conn]["on_close"] == {}
assert conn._shared["instances"][conn]["callback_tasks"] == {
"on_close": {},
"on_reconnect": {},
"on_open": {},
}

conn.set_callback("on_open", "excpetion", cb_with_exception)
await conn._execute_callbacks("on_open")
Expand Down

0 comments on commit 2ad7a00

Please sign in to comment.