Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Preserve the event loop in a Queue when connecting #184

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion saq/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,15 @@ def __init__(
self._dump = dump or json.dumps
self._load = load or json.loads
self._before_enqueues: dict[int, BeforeEnqueueType] = {}
self._loop: asyncio.AbstractEventLoop | None = None

def job_id(self, job_key: str) -> str:
return job_key

@property
def loop(self) -> asyncio.AbstractEventLoop:
return self._loop or asyncio.get_running_loop()

@abstractmethod
async def disconnect(self) -> None:
pass
Expand Down Expand Up @@ -167,7 +172,7 @@ def from_url(url: str, **kwargs: t.Any) -> Queue:
return HttpQueue.from_url(url, **kwargs)

async def connect(self) -> None:
pass
self._loop = asyncio.get_running_loop()

def serialize(self, job: Job) -> bytes | str:
return self._dump(job.to_dict())
Expand Down
1 change: 1 addition & 0 deletions saq/queue/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
async def connect(self) -> None:
if not self.session:
self.session = ClientSession(**self.session_kwargs)
await super().connect()

async def disconnect(self) -> None:
if self.session:
Expand Down
2 changes: 2 additions & 0 deletions saq/queue/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ async def connect(self) -> None:
await self.pool.resize(min_size=self.min_size, max_size=self.max_size)
await self.init_db()

await super().connect()

def serialize(self, job: Job) -> bytes | str:
"""Ensure serialized job is in bytes because the job column is of type BYTEA."""
serialized = self._dump(job.to_dict())
Expand Down
Loading