-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🔀 Merge Pull Request: Remove Redis and socket connection bugs (#18)
- Loading branch information
Showing
12 changed files
with
152 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,92 @@ | ||
import asyncio | ||
import json | ||
|
||
import redis | ||
from fastapi import WebSocket | ||
from redis import asyncio as aioredis | ||
from starlette.websockets import WebSocketState | ||
|
||
from message_broker.message_broker import MessageBroker | ||
from settings import CHANNEL_ID | ||
|
||
|
||
class WebSocketBroker: | ||
def __init__(self): | ||
self.channel_id = None | ||
self.sockets: list = [] | ||
def __init__(self, channel_id: str): | ||
self.channel_id = channel_id | ||
self.sockets: list[WebSocket] = [] | ||
self.pubsub_client = MessageBroker() | ||
|
||
async def add_user_to_channel(self, channel_id: str, websocket: WebSocket) -> None: | ||
async def accept(self) -> None: | ||
""" | ||
Adds a user's WebSocket connection to a channel. | ||
:param channel_id: Channel ID to add user to. | ||
:param websocket: WebSocket connection object. | ||
Connects to Redis server and establish channel. | ||
""" | ||
await websocket.accept() | ||
self.sockets.append(websocket) | ||
|
||
if self.channel_id is None: | ||
self.channel_id = channel_id | ||
if not self.sockets: | ||
await self.pubsub_client.connect() | ||
ps_subscriber = await self.pubsub_client.subscribe(channel_id) | ||
ps_subscriber = await self.pubsub_client.subscribe(self.channel_id) | ||
asyncio.create_task(self._pubsub_data_reader(ps_subscriber)) | ||
|
||
async def add_client_to_channel(self, websocket: WebSocket) -> None: | ||
""" | ||
Adds a client's WebSocket connection to a channel. | ||
:param websocket: WebSocket connection object. | ||
""" | ||
self.sockets.append(websocket) | ||
|
||
async def broadcast_to_channel(self, channel_id: str, message: str) -> None: | ||
""" | ||
Broadcasts a message to all connected WebSockets in a channel. | ||
:param channel_id: Channel ID to publish to. | ||
:param message: Message to be broadcast. | ||
""" | ||
await self.pubsub_client.publish(channel_id, message) | ||
if self.sockets: | ||
await self.pubsub_client.publish(channel_id, message) | ||
|
||
async def _pubsub_data_reader(self, ps_subscriber: aioredis.Redis): | ||
""" | ||
Reads and broadcasts messages received from Redis PubSub. | ||
:param ps_subscriber: PubSub object for the subscribed channel. | ||
""" | ||
while True: | ||
message = await ps_subscriber.get_message(ignore_subscribe_messages=True) | ||
message = None | ||
try: | ||
message = await ps_subscriber.get_message( | ||
ignore_subscribe_messages=True | ||
) | ||
except redis.exceptions.ConnectionError: | ||
# TODO: Add logging. | ||
# TODO: Replace return handle when Redis is closed when server is running with a better | ||
# approach, perhaps a back-off algorithm could be added. | ||
return | ||
except Exception as e: | ||
# TODO: Implement handle of other Exceptions. | ||
pass | ||
|
||
if message: | ||
for socket in self.sockets: | ||
data = message["data"].decode("utf-8") | ||
await socket.send_json(json.loads(data)) | ||
if ( | ||
socket.application_state == WebSocketState.CONNECTED | ||
and socket.client_state == WebSocketState.CONNECTED | ||
): | ||
data = message["data"].decode("utf-8") | ||
await socket.send_json(json.loads(data)) | ||
|
||
async def remove_client_from_channel(self, websocket: WebSocket) -> None: | ||
""" | ||
Removes a client's WebSocket connection from a channel. | ||
:param websocket: WebSocket connection object. | ||
""" | ||
self.sockets.remove(websocket) | ||
|
||
async def close_sockets(self): | ||
""" | ||
Closes client sockets. | ||
""" | ||
for socket in self.sockets: | ||
if ( | ||
socket.application_state == WebSocketState.CONNECTED | ||
and socket.client_state == WebSocketState.CONNECTED | ||
): | ||
await socket.close() | ||
|
||
|
||
socket_broker = WebSocketBroker(CHANNEL_ID) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import signal | ||
from contextlib import asynccontextmanager | ||
|
||
from fastapi import FastAPI | ||
|
||
from logger import log_error, log_info_async | ||
from message_broker.websocket_broker import socket_broker | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
await log_info_async("Server starting.") | ||
yield | ||
await log_info_async("Server shutting down.") | ||
await socket_broker.close_sockets() | ||
|
||
|
||
def shutdown(): | ||
""" | ||
Used to shut down the ASGI server. | ||
""" | ||
log_error(f"Server shut down forcefully.") | ||
os.kill(os.getpid(), signal.SIGTERM) |
Oops, something went wrong.