Skip to content

Commit

Permalink
adding polls
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanUC committed Nov 7, 2024
1 parent 2c8566e commit 5160236
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 8 deletions.
33 changes: 33 additions & 0 deletions yepcord/gateway/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,3 +1158,36 @@ async def json(self) -> dict:
"token_data": None,
}
}


class MessagePollVoteAddEvent(DispatchEvent):
NAME = "MESSAGE_POLL_VOTE_ADD"

__slots__ = ("user_id", "message_id", "channel_id", "answer_id", "guild_id",)

def __init__(self, user_id: int, message_id: int, channel_id: int, answer_id: int, guild_id: int | None):
self.user_id = user_id
self.message_id = message_id
self.channel_id = channel_id
self.answer_id = answer_id
self.guild_id = guild_id

async def json(self) -> dict:
data = {
"t": self.NAME,
"op": self.OP,
"d": {
"user_id": str(self.user_id),
"channel_id": str(self.channel_id),
"message_id": str(self.message_id),
"answer_id": self.answer_id,
}
}
if self.guild_id is not None:
data["d"]["guild_id"] = str(self.guild_id)

return data


class MessagePollVoteRemoveEvent(MessagePollVoteAddEvent):
NAME = "MESSAGE_POLL_VOTE_REMOVE"
31 changes: 29 additions & 2 deletions yepcord/rest_api/models/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,38 @@ def model_dump(self, *args, **kwargs):
return super().model_dump(*args, **kwargs)


# noinspection PyMethodParameters
class MessagePollContentModel(BaseModel):
text: str


class MessagePollQuestionContentModel(MessagePollContentModel):
text: str = Field(max_length=300)


class MessagePollAnswerContentModel(MessagePollContentModel):
text: str = Field(max_length=55)


class MessagePollAnswerModel(BaseModel):
poll_media: MessagePollAnswerContentModel


class MessagePollModel(BaseModel):
question: MessagePollQuestionContentModel
answers: list[MessagePollAnswerModel] = Field(max_length=10) # TODO: is it 10?
allow_multiselect: bool = False
duration: int = 24 # In hours, TODO: add validation for max and min
layout_type: int = 1


class MessageCreate(BaseModel):
content: Optional[str] = None
nonce: Optional[str] = None
embeds: list[EmbedModel] = Field(default_factory=list)
sticker_ids: list[int] = Field(default_factory=list)
message_reference: Optional[MessageReferenceModel] = None
flags: Optional[int] = None
poll: Optional[MessagePollModel] = None

@field_validator("content")
def validate_content(cls, value: Optional[str]):
Expand Down Expand Up @@ -392,7 +416,6 @@ def to_json(self) -> dict:
return data


# noinspection PyMethodParameters
class MessageUpdate(BaseModel):
content: Optional[str] = None
embeds: list[EmbedModel] = Field(default_factory=list)
Expand Down Expand Up @@ -503,3 +526,7 @@ def validate_limit(cls, value: int) -> int:
if value > 10:
value = 10
return value


class AnswerPoll(BaseModel):
answer_ids: list[int]
61 changes: 57 additions & 4 deletions yepcord/rest_api/routes/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,23 @@

from ..dependencies import DepUser, DepChannel, DepMessage
from ..models.channels import ChannelUpdate, MessageCreate, MessageUpdate, InviteCreate, PermissionOverwriteModel, \
WebhookCreate, GetReactionsQuery, MessageAck, CreateThread, CommandsSearchQS, SearchQuery, GetMessagesQuery
WebhookCreate, GetReactionsQuery, MessageAck, CreateThread, CommandsSearchQS, SearchQuery, GetMessagesQuery, \
AnswerPoll
from ..utils import _getMessage, processMessage
from ..y_blueprint import YBlueprint
from ...gateway.events import MessageCreateEvent, TypingEvent, MessageDeleteEvent, MessageUpdateEvent, \
DMChannelCreateEvent, DMChannelUpdateEvent, ChannelRecipientAddEvent, ChannelRecipientRemoveEvent, \
DMChannelDeleteEvent, MessageReactionAddEvent, MessageReactionRemoveEvent, ChannelUpdateEvent, ChannelDeleteEvent, \
WebhooksUpdateEvent, ThreadCreateEvent, ThreadMemberUpdateEvent, MessageAckEvent, GuildAuditLogEntryCreateEvent
WebhooksUpdateEvent, ThreadCreateEvent, ThreadMemberUpdateEvent, MessageAckEvent, GuildAuditLogEntryCreateEvent, \
MessagePollVoteAddEvent, MessagePollVoteRemoveEvent
from ...yepcord.ctx import getGw
from ...yepcord.enums import GuildPermissions, MessageType, ChannelType, WebhookType, GUILD_CHANNELS, MessageFlags
from ...yepcord.errors import UnknownMessage, UnknownUser, UnknownEmoji, UnknownInteraction, MaxPinsReached, \
MissingPermissions, CannotSendToThisUser, CannotExecuteOnDM, CannotEditAnotherUserMessage, MissingAccess
MissingPermissions, CannotSendToThisUser, CannotExecuteOnDM, CannotEditAnotherUserMessage, MissingAccess, Errors, \
InvalidDataErr, UnknowPoll
from ...yepcord.models import User, Channel, Message, ReadState, Emoji, PermissionOverwrite, Webhook, ThreadMember, \
ThreadMetadata, AuditLogEntry, Relationship, ApplicationCommand, Integration, Bot, Role, HiddenDmChannel, Invite, \
Reaction
Reaction, PollVote, Poll, PollAnswer
from ...yepcord.snowflake import Snowflake
from ...yepcord.storage import getStorage
from ...yepcord.utils import getImage, b64encode, b64decode
Expand Down Expand Up @@ -675,3 +678,53 @@ async def get_message_interaction(user: User = DepUser, channel: Channel = DepCh
"application_command": None,
"options": (message.interaction.data or {}).get("options", [])
}


@channels.put("/<int:channel_id>/polls/<int:message>/answers/@me", body_cls=AnswerPoll)
async def set_poll_answers(
data: AnswerPoll, user: User = DepUser, channel: Channel = DepChannel, message: Message = DepMessage
):
if channel.guild:
member = await channel.guild.get_member(user.id)
await member.checkPermission(GuildPermissions.READ_MESSAGE_HISTORY, GuildPermissions.VIEW_CHANNEL, channel=channel)

if (poll := await Poll.get_or_none(message=message)) is None:
raise UnknowPoll

guild_id = channel.guild.id if channel.guild else None

answer_ids = set(data.answer_ids)
if len(answer_ids) > 1 and not poll.multiselect:
raise InvalidDataErr(400, Errors.make(50035, {"answer_ids": {
"code": "CANNOT_ADD_MULTIPLE_POLL_ANSWERS", "message": "Multiple votes are not allowed for this poll."
}}))

existing_votes = {
vote.answer.local_id: vote
for vote in await PollVote.filter(answer__poll=poll, user=user).select_related("answer")
}
remove_votes = {answer_id: vote for answer_id, vote in existing_votes.items() if answer_id not in answer_ids}
add_votes = answer_ids - remove_votes.keys()
add_answers = await PollAnswer.filter(poll=poll, local_id__in=add_votes)

for to_remove in remove_votes.values():
await to_remove.delete()
await getGw().dispatch(
MessagePollVoteRemoveEvent(user.id, message.id, channel.id, to_remove.answer.local_id, guild_id),
channel=channel,
permissions=GuildPermissions.VIEW_CHANNEL,
)

await PollVote.bulk_create([
PollVote(answer=answer, user=user)
for answer in add_answers
])

for answer in add_answers:
await getGw().dispatch(
MessagePollVoteAddEvent(user.id, message.id, channel.id, answer.local_id, guild_id),
channel=channel,
permissions=GuildPermissions.VIEW_CHANNEL,
)
return "", 204

19 changes: 17 additions & 2 deletions yepcord/rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
"""
from __future__ import annotations

from datetime import datetime, timedelta
from functools import wraps
from json import loads
from typing import Optional, Union, TYPE_CHECKING

from PIL import Image
from async_timeout import timeout
from magic import from_buffer
from pytz import UTC
from quart import request, current_app, g

import yepcord.yepcord.models as models
Expand Down Expand Up @@ -139,7 +141,8 @@ async def processMessageData(data: Optional[dict], channel: Channel) -> tuple[di
if not data.get("content") and \
not data.get("embeds") and \
not data.get("attachments") and \
not data.get("sticker_ids"):
not data.get("sticker_ids") and \
not data.get("poll"):
raise CannotSendEmptyMessage
return data, attachments

Expand Down Expand Up @@ -228,7 +231,7 @@ async def processMessage(data: dict, channel: Channel, author: Optional[User], v

message_type = await validate_reply(data, channel)
stickers_data = await process_stickers(data.sticker_ids)
if not data.content and not data.embeds and not attachments and not stickers_data["stickers"]:
if not data.content and not data.embeds and not attachments and not stickers_data["stickers"] and not data.poll:
raise CannotSendEmptyMessage

data_json = data.to_json()
Expand All @@ -240,6 +243,18 @@ async def processMessage(data: dict, channel: Channel, author: Optional[User], v
)
await models.ReadState.update_from_message(message)

if data.poll:
poll = await models.Poll.create(
message=message,
question=data.poll.question.text,
expires_at=datetime.now(UTC) + timedelta(hours=data.poll.duration),
multiselect=data.poll.allow_multiselect,
)
await models.PollAnswer.bulk_create([
models.PollAnswer(poll=poll, local_id=idx+1, text=answer.poll_media.text)
for idx, answer in enumerate(data.poll.answers)
])

message.nonce = data_json.get("nonce")

for attachment in attachments:
Expand Down
2 changes: 2 additions & 0 deletions yepcord/yepcord/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class _Errors:
err_10060 = "Unknown Sticker"
err_10062 = "Unknown Interaction"
err_10070 = "Unknown Guild Scheduled Event"
err_10102 = "Unknown poll"

err_30003 = "Maximum number of pins reached (50)"
err_30008 = "Maximum number of emojis reached"
Expand Down Expand Up @@ -173,6 +174,7 @@ def __init__(self, error_code: int):
UnknownSticker = UnknownX(10060)
UnknownInteraction = UnknownX(10062)
UnknownGuildEvent = UnknownX(10070)
UnknowPoll = UnknownX(10102)

MaxPinsReached = BadRequest(30003)
MaxEmojisReached = BadRequest(30008)
Expand Down
3 changes: 3 additions & 0 deletions yepcord/yepcord/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from .message import Message
from .attachment import Attachment
from .reaction import Reaction
from .poll import Poll
from .poll_answer import PollAnswer
from .poll_vote import PollVote

from .application import Application, gen_secret_key
from .bot import Bot, gen_token_secret
Expand Down
4 changes: 4 additions & 0 deletions yepcord/yepcord/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ async def ds_json(self, user_id: int = None, search: bool = False) -> dict:
"id": str(self.interaction.id),
"user": userdata,
}

if (poll := await models.Poll.get_or_none(message=self)) is not None:
data["poll"] = await poll.ds_json(user_id)

return data

async def _get_reactions_json(self, user_id: int) -> list:
Expand Down
68 changes: 68 additions & 0 deletions yepcord/yepcord/models/poll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
YEPCord: Free open source selfhostable fully discord-compatible chat
Copyright (C) 2022-2024 RuslanUC
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from datetime import datetime
from typing import Optional

from pytz import UTC
from tortoise import fields

import yepcord.yepcord.models as models
from ._utils import SnowflakeField, Model


class Poll(Model):
id: int = SnowflakeField(pk=True)
message: models.Message = fields.ForeignKeyField("models.Message", unique=True)
question: str = fields.CharField(max_length=300)
expires_at: datetime = fields.DatetimeField()
multiselect: bool = fields.BooleanField(default=False)

def is_expired(self) -> bool:
return datetime.now(UTC) > self.expires_at

async def ds_json(self, user_id: Optional[int] = None) -> dict:
answers = await models.PollAnswer.filter(poll=self).order_by("local_id")
counts = []
for answer in answers:
count = await models.PollVote.filter(answer=answer).count()
if not count:
continue
me = await models.PollVote.exists(answer=answer, user__id=user_id) \
if user_id is not None else 0
counts.append({
"id": answer.local_id,
"count": count,
"me_voted": me,
})

return {
"question": {
"text": self.question,
},
"answers": [
answer.ds_json()
for answer in answers
],
"expiry": self.expires_at.strftime("%Y-%m-%dT%H:%M:%S.000000+00:00"),
"allow_multiselect": self.multiselect,
"layout_type": 1,
"results": {
"answer_counts": counts,
"is_finalized": self.is_expired(),
}
}
39 changes: 39 additions & 0 deletions yepcord/yepcord/models/poll_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
YEPCord: Free open source selfhostable fully discord-compatible chat
Copyright (C) 2022-2024 RuslanUC
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from datetime import datetime

from pytz import UTC
from tortoise import fields

import yepcord.yepcord.models as models
from ._utils import SnowflakeField, Model


class PollAnswer(Model):
id: int = SnowflakeField(pk=True)
poll: models.Poll = fields.ForeignKeyField("models.Poll")
local_id: int = fields.SmallIntField()
text: str = fields.CharField(max_length=55)

def ds_json(self) -> dict:
return {
"answer_id": self.local_id,
"poll_media": {
"text": self.text,
},
}
Loading

0 comments on commit 5160236

Please sign in to comment.