From 896bcd4ce5ebc0e80c09b7169c3d8129a8dd2e4c Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 18 Mar 2024 18:08:39 +0200 Subject: [PATCH] implement connecting/disconnecting from voice channel --- config.example.py | 1 + yepcord/gateway/events.py | 30 ++++++------- yepcord/gateway/gateway.py | 61 +++++++++++++++++++++------ yepcord/voice_gateway/gateway.py | 52 +++++++++-------------- yepcord/voice_gateway/go_rpc.py | 30 +++++++++++++ yepcord/voice_gateway/main.py | 16 ++++--- yepcord/yepcord/config.py | 1 + yepcord/yepcord/gateway_dispatcher.py | 6 ++- yepcord/yepcord/models/__init__.py | 1 + yepcord/yepcord/models/voice_state.py | 37 ++++++++++++++++ 10 files changed, 164 insertions(+), 71 deletions(-) create mode 100644 yepcord/voice_gateway/go_rpc.py create mode 100644 yepcord/yepcord/models/voice_state.py diff --git a/config.example.py b/config.example.py index a27eb6d..dbf76fc 100644 --- a/config.example.py +++ b/config.example.py @@ -13,6 +13,7 @@ PUBLIC_HOST = "127.0.0.1:8080" GATEWAY_HOST = "127.0.0.1:8000/gateway" +VOICE_GATEWAY_HOST = "127.0.0.1:8000/voice" CDN_HOST = "127.0.0.1:8000/media" STORAGE = { diff --git a/yepcord/gateway/events.py b/yepcord/gateway/events.py index f8d5716..2c2b7b1 100644 --- a/yepcord/gateway/events.py +++ b/yepcord/gateway/events.py @@ -29,7 +29,7 @@ from ..yepcord.snowflake import Snowflake if TYPE_CHECKING: # pragma: no cover - from ..yepcord.models import Channel, Invite, GuildMember, UserData, User, UserSettings, Guild + from ..yepcord.models import Channel, Invite, GuildMember, UserData, User, UserSettings, Guild, VoiceState from ..yepcord.core import Core from .gateway import GatewayClient from .presences import Presence @@ -173,16 +173,16 @@ def __init__(self, friends_presences: list[dict], guilds_ids: list[int]): self.guilds_ids = guilds_ids async def json(self) -> dict: - g = [{"voice_states": [], "id": str(i), "embedded_activities": []} for i in self.guilds_ids] # TODO + g = [{"voice_states": [], "id": str(i), "embedded_activities": []} for i in self.guilds_ids] # TODO: add voice states return { "t": self.NAME, "op": self.OP, "d": { "merged_presences": { - "guilds": [[]], # TODO + "guilds": [[]], "friends": self.friends_presences }, - "merged_members": [[]], # TODO + "merged_members": [[]], "guilds": g } } @@ -1059,7 +1059,7 @@ async def json(self) -> dict: class VoiceStateUpdate(DispatchEvent): NAME = "VOICE_STATE_UPDATE" - def __init__(self, user_id: int, session_id: str, channel: Channel, guild: Optional[Guild], + def __init__(self, user_id: int, session_id: str, channel: Optional[Channel], guild: Optional[Guild], member: Optional[GuildMember], **kwargs): self.user_id = user_id self.session_id = session_id @@ -1074,7 +1074,7 @@ async def json(self) -> dict: "op": self.OP, "d": { "user_id": str(self.user_id), - "channel_id": str(self.channel.id), + "channel_id": str(self.channel.id) if self.channel is not None else None, "deaf": False, "mute": False, "session_id": self.session_id, @@ -1090,22 +1090,22 @@ async def json(self) -> dict: class VoiceServerUpdate(DispatchEvent): NAME = "VOICE_SERVER_UPDATE" - def __init__(self, channel: Channel, guild: Optional[Guild]): - self.channel = channel - self.guild = guild + def __init__(self, voice_state: VoiceState): + self.state = voice_state async def json(self) -> dict: data = { "t": self.NAME, "op": self.OP, "d": { - "token": "idk_token", - "endpoint": "127.0.0.1:8000/voice" + "token": f"{self.state.id}.{self.state.token}", + "endpoint": Config.VOICE_GATEWAY_HOST } } - if self.guild: - data["d"]["guild_id"] = str(self.guild.id) - if self.channel: - data["d"]["channel_id"] = str(self.channel.id) + if self.state.guild: + data["d"]["guild_id"] = str(self.state.guild.id) + if self.state.channel: + data["d"]["channel_id"] = str(self.state.channel.id) + print(data) return data diff --git a/yepcord/gateway/gateway.py b/yepcord/gateway/gateway.py index b2f74ae..37341dc 100644 --- a/yepcord/gateway/gateway.py +++ b/yepcord/gateway/gateway.py @@ -31,8 +31,10 @@ from ..yepcord.classes.fakeredis import FakeRedis from ..yepcord.core import Core from ..yepcord.ctx import getCore -from ..yepcord.enums import GatewayOp +from ..yepcord.enums import GatewayOp, GuildPermissions +from ..yepcord.gateway_dispatcher import GatewayDispatcher from ..yepcord.models import Session, User, UserSettings, Bot, GuildMember +from ..yepcord.models.voice_state import VoiceState from ..yepcord.mq_broker import getBroker @@ -180,22 +182,53 @@ async def handle_VOICE_STATE(self, data: dict) -> None: self_mute = bool(data.get("self_mute")) self_deaf = bool(data.get("self_deaf")) - if not (channel := await getCore().getChannel(data.get("channel_id"))): return - if not await getCore().getUserByChannel(channel, self.user_id): return - - guild = None - member = None - if guild_id := data.get("guild_id"): - if (guild := await getCore().getGuild(guild_id)) is None \ - or (member := await getCore().getGuildMember(guild, self.user_id)) is None: + voice_state = await VoiceState.get_or_none(user__id=self.user_id).select_related("channel", "guild") + if voice_state is not None: + if voice_state.session_id != self.sid and data["channel_id"] is None: return + if str(voice_state.guild.id) != data["guild_id"] and voice_state.session_id == self.sid: + member = await getCore().getGuildMember(voice_state.guild, self.user_id) + voice_event = VoiceStateUpdate( + self.id, self.sid, None, voice_state.guild, member, self_mute=self_mute, self_deaf=self_deaf + ) + await self.gateway.broker.publish(channel="yepcord_events", message={ + "data": await voice_event.json(), + "event": voice_event.NAME, + **(await GatewayDispatcher.getChannelFilter(voice_state.channel, GuildPermissions.VIEW_CHANNEL)) + }) + if data["guild_id"] is None: + return await voice_state.delete() - print(f"Connecting to voice with session_id={self.sid}") + if data["channel_id"] is None or data["guild_id"] is None: + return + if not (channel := await getCore().getChannel(data["channel_id"])): return + if not await getCore().getUserByChannel(channel, self.user_id): return + if (guild := await getCore().getGuild(data["guild_id"])) is None or channel.guild != guild or \ + (member := await getCore().getGuildMember(guild, self.user_id)) is None: + return - await self.esend(VoiceStateUpdate( - self.id, self.sid, channel, guild, member, self_mute=self_mute, self_deaf=self_deaf - )) - await self.esend(VoiceServerUpdate(channel, guild)) + if voice_state is not None: + await voice_state.update(guild=guild, channel=channel, session_id=self.sid) + else: + voice_state = await VoiceState.create(guild=guild, channel=channel, user=member.user, session_id=self.sid) + + if member is None: + member = await getCore().getGuildMember(voice_state.guild, self.user_id) + + voice_event = VoiceStateUpdate( + self.id, self.sid, voice_state.channel, voice_state.guild, member, self_mute=self_mute, self_deaf=self_deaf + ) + await self.gateway.mcl_yepcordEventsCallback({ + "data": await voice_event.json(), + "event": voice_event.NAME, + "user_ids": None, + "guild_id": None, + "role_ids": None, + "session_id": None, + "exclude": [], + } | await GatewayDispatcher.getChannelFilter(voice_state.channel, GuildPermissions.VIEW_CHANNEL)) + await self.esend(VoiceServerUpdate(voice_state)) + print("should connect now") class GatewayEvents: diff --git a/yepcord/voice_gateway/gateway.py b/yepcord/voice_gateway/gateway.py index 6296c53..990d856 100644 --- a/yepcord/voice_gateway/gateway.py +++ b/yepcord/voice_gateway/gateway.py @@ -10,37 +10,11 @@ from yepcord.yepcord.enums import VoiceGatewayOp from .default_sdp import DEFAULT_SDP from .events import Event, ReadyEvent, SpeakingEvent, UdpSessionDescriptionEvent, RtcSessionDescriptionEvent +from .go_rpc import GoRpc from .schemas import SelectProtocol from ..gateway.utils import require_auth from ..yepcord.config import Config - - -class GoRpc: - def __init__(self, rpc_addr: str): - self._address = f"http://{rpc_addr}/rpc" - - async def create_endpoint(self, channel_id: int) -> Optional[int]: - async with AsyncClient() as cl: - resp = await cl.post(self._address, json={ - # TODO: auto port allocation (on golang side) - "id": 0, "method": "Rpc.CreateApi", "params": [{"channel_id": str(channel_id), "port": 3791}] - }) - j = resp.json() - if j["error"] is None: - return j["result"] - print(j["error"]) - - async def create_peer_connection(self, channel_id: int, session_id: int, offer: str) -> Optional[str]: - async with AsyncClient() as cl: - resp = await cl.post(self._address, json={ - "id": 0, "method": "Rpc.NewPeerConnection", "params": [ - {"channel_id": str(channel_id), "session_id": str(session_id), "offer": offer} - ] - }) - j = resp.json() - if j["error"] is None: - return j["result"] - print(j["error"]) +from ..yepcord.models import VoiceState class GatewayClient: @@ -49,7 +23,7 @@ def __init__(self, ws: Websocket, gw: Gateway): self.user_id = None self.session_id = None self.guild_id = None - self.token = None + self.channel_id = None self.ssrc = 0 self.video_ssrc = 0 self.rtx_ssrc = 0 @@ -67,13 +41,25 @@ async def esend(self, event: Event): async def handle_IDENTIFY(self, data: dict): print(f"Connected to voice with session_id={data['session_id']}") - if data["token"] != "idk_token": + try: + token = data["token"].split(".") + if len(token) != 2: + raise ValueError() + state_id, token = token + state_id = int(state_id) + + state = await VoiceState.get_or_none(id=state_id, token=token).select_related("user", "guild", "channel") + if state is None: + raise ValueError + except ValueError: return await self.ws.close(4004) self.user_id = int(data["user_id"]) self.session_id = data["session_id"] self.guild_id = int(data["server_id"]) - self.token = data["token"] + self.channel_id = state.channel.id + if self.user_id != state.user.id or self.guild_id != state.guild.id: + return await self.ws.close(4004) self.ssrc = self._gw.ssrc self._gw.ssrc += 1 @@ -86,7 +72,7 @@ async def handle_IDENTIFY(self, data: dict): port = 0 rpc = self._gw.rpc(self.guild_id) if rpc is not None: - port = await rpc.create_endpoint(self.guild_id) + port = await rpc.create_endpoint(self.channel_id) await self.esend(ReadyEvent(self.ssrc, self.video_ssrc, self.rtx_ssrc, ip, port)) @@ -114,7 +100,7 @@ async def handle_SELECT_PROTOCOL(self, data: dict): sdp = "v=0\r\n" + str(self.sdp) + "\r\n" - answer = await rpc.create_peer_connection(self.guild_id, self.session_id, sdp) + answer = await rpc.create_peer_connection(self.channel_id, self.session_id, sdp) sdp = SDPInfo.parse(answer) c = sdp.candidates[0] diff --git a/yepcord/voice_gateway/go_rpc.py b/yepcord/voice_gateway/go_rpc.py new file mode 100644 index 0000000..88a09da --- /dev/null +++ b/yepcord/voice_gateway/go_rpc.py @@ -0,0 +1,30 @@ +from typing import Optional + +from httpx import AsyncClient + + +class GoRpc: + def __init__(self, rpc_addr: str): + self._address = f"http://{rpc_addr}/rpc" + + async def create_endpoint(self, channel_id: int) -> Optional[int]: + async with AsyncClient() as cl: + resp = await cl.post(self._address, json={ + "id": 0, "method": "Rpc.CreateApi", "params": [{"channel_id": str(channel_id)}] + }) + j = resp.json() + if j["error"] is None: + return j["result"] + print(j["error"]) + + async def create_peer_connection(self, channel_id: int, session_id: int, offer: str) -> Optional[str]: + async with AsyncClient() as cl: + resp = await cl.post(self._address, json={ + "id": 0, "method": "Rpc.NewPeerConnection", "params": [ + {"channel_id": str(channel_id), "session_id": str(session_id), "offer": offer} + ] + }) + j = resp.json() + if j["error"] is None: + return j["result"] + print(j["error"]) diff --git a/yepcord/voice_gateway/main.py b/yepcord/voice_gateway/main.py index 68535c1..07e560d 100644 --- a/yepcord/voice_gateway/main.py +++ b/yepcord/voice_gateway/main.py @@ -1,8 +1,10 @@ from asyncio import CancelledError from quart import Quart, websocket, Websocket +from tortoise.contrib.quart import register_tortoise from .gateway import Gateway +from ..yepcord.config import Config class YEPcord(Quart): @@ -35,10 +37,10 @@ async def ws_gateway_voice(): except CancelledError: raise -# ? -# register_tortoise( -# app, -# db_url=Config.DB_CONNECT_STRING, -# modules={"models": ["yepcord.yepcord.models"]}, -# generate_schemas=False, -# ) + +register_tortoise( + app, + db_url=Config.DB_CONNECT_STRING, + modules={"models": ["yepcord.yepcord.models"]}, + generate_schemas=False, +) diff --git a/yepcord/yepcord/config.py b/yepcord/yepcord/config.py index 5335aa6..38e48fa 100644 --- a/yepcord/yepcord/config.py +++ b/yepcord/yepcord/config.py @@ -123,6 +123,7 @@ class ConfigModel(BaseModel): MIGRATIONS_DIR: str = "./migrations" KEY: str = "XUJHVU0nUn51TifQuy9H1j0gId0JqhQ+PUz16a2WOXE=" PUBLIC_HOST: str = "127.0.0.1:8080" + VOICE_GATEWAY_HOST: str = "127.0.0.1:8000/voice" GATEWAY_HOST: str = "127.0.0.1:8080/gateway" CDN_HOST: str = "127.0.0.1:8080/media" STORAGE: ConfigStorage = Field(default_factory=ConfigStorage) diff --git a/yepcord/yepcord/gateway_dispatcher.py b/yepcord/yepcord/gateway_dispatcher.py index 8ecc6f3..55b5cd7 100644 --- a/yepcord/yepcord/gateway_dispatcher.py +++ b/yepcord/yepcord/gateway_dispatcher.py @@ -120,7 +120,8 @@ async def sendStickersUpdateEvent(self, guild: Guild) -> None: stickers = [await sticker.ds_json() for sticker in stickers] await self.dispatch(StickersUpdateEvent(guild.id, stickers), guild_id=guild.id) - async def getChannelFilter(self, channel: Channel, permissions: int = 0) -> dict: + @staticmethod + async def getChannelFilter(channel: Channel, permissions: int = 0) -> dict: if channel.type in {ChannelType.DM, ChannelType.GROUP_DM}: return {"user_ids": await channel.recipients.all().values_list("id", flat=True)} @@ -157,7 +158,8 @@ async def getChannelFilter(self, channel: Channel, permissions: int = 0) -> dict return {"role_ids": result_roles, "user_ids": list(user_ids), "exclude": list(excluded_user_ids)} - async def getRolesByPermissions(self, guild_id: int, permissions: int = 0) -> list[int]: + @staticmethod + async def getRolesByPermissions(guild_id: int, permissions: int = 0) -> list[int]: return await Role.filter(guild__id=guild_id).annotate(perms=RawSQL(f"permissions & {permissions}"))\ .filter(perms=permissions).values_list("id", flat=True) diff --git a/yepcord/yepcord/models/__init__.py b/yepcord/yepcord/models/__init__.py index 7fd8bfb..67259e8 100644 --- a/yepcord/yepcord/models/__init__.py +++ b/yepcord/yepcord/models/__init__.py @@ -27,6 +27,7 @@ from .guild_event import GuildEvent from .guild_template import GuildTemplate from .role import Role +from .voice_state import VoiceState from .message import Message from .attachment import Attachment diff --git a/yepcord/yepcord/models/voice_state.py b/yepcord/yepcord/models/voice_state.py new file mode 100644 index 0000000..572f971 --- /dev/null +++ b/yepcord/yepcord/models/voice_state.py @@ -0,0 +1,37 @@ +""" + YEPCord: Free open source selfhostable fully discord-compatible chat + Copyright (C) 2022-2024 RuslanUC + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +""" +from os import urandom +from typing import Optional + +from tortoise import fields + +import yepcord.yepcord.models as models +from ._utils import SnowflakeField, Model + + +def gen_token(): + return urandom(32).hex() + + +class VoiceState(Model): + id: int = SnowflakeField(pk=True) + guild: models.Guild = fields.ForeignKeyField("models.Guild", default=None, null=True) + channel: models.Channel = fields.ForeignKeyField("models.Channel") + user: models.User = fields.ForeignKeyField("models.User") + session_id: str = fields.CharField(max_length=64) + token: Optional[str] = fields.CharField(max_length=128, default=gen_token)