Skip to content

Commit

Permalink
Implement Asset, avatar, assume utf-8 when decoding text for better perf
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryptex-github committed Dec 17, 2021
1 parent e7faa49 commit 023ead6
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 11 deletions.
10 changes: 10 additions & 0 deletions docs/ferriswheel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ Client
:members:
:inherited-members:

Asset
-----

.. autoclass:: Asset
:members:

Base
----

Expand Down Expand Up @@ -283,6 +289,10 @@ User
:members:
:inherited-members:

.. autoclass:: ClientUser
:members:
:inherited-members:

Utility Functions
-----------------

Expand Down
1 change: 1 addition & 0 deletions ferris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import urllib.request

from .asset import *
from .base import *
from .channel import *
from .client import *
Expand Down
74 changes: 74 additions & 0 deletions ferris/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

# Some of the code here are shamelessly robbed from dpy.

from os import PathLike
from typing import AnyStr, Union, TYPE_CHECKING, Any
from io import BufferedIOBase


__all__ = ('Asset',)

if TYPE_CHECKING:
from .connection import Connection

class Asset:
def __init__(self, connection: Connection, url: str) -> None:
self._connection = connection

self._url: str = url

async def read(self) -> bytes:
"""|coro|
Retrives the content of this asset as a :class:`bytes` object.
Raises
------
HTTPException
An HTTP error occurred while fetching the content.
Returns
-------
:class:`bytes`
"""
return await self._connection._http.get_asset(self._url)

async def save(self, fp: Union[AnyStr, PathLike, BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro|
Saves the asset to a file-like object.
Parameters
----------
fp: Union[AnyStr, PathLike, BufferedIOBase]
The file-like object to save the asset to.
seek_begin: bool
Whether to seek to the beginning of the file after saving.
Defaults to ``True``.
"""
data = await self.read()

if isinstance(fp, BufferedIOBase):
written = await self._connection.to_thread(fp.write, data)
if seek_begin:
fp.seek(0)

return written

with open(fp, 'wb+') as f:
return await self._connection.to_thread(f.write, data)

def __str__(self) -> str:
return self._url

def __repr__(self) -> str:
return f'<Asset url={self._url!r}>'

def __eq__(self, other: Any) -> bool:
return isinstance(other, Asset) and self._url == other._url

def __len__(self) -> int:
return len(self._url)

def __hash__(self) -> int:
return hash(self._url)
7 changes: 6 additions & 1 deletion ferris/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import asyncio
from asyncio import AbstractEventLoop
from collections import deque
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Awaitable, Coroutine, Dict, Optional, Union

import functools

from ferris.types.base import Snowflake
from ferris.user import ClientUser
Expand Down Expand Up @@ -65,6 +67,9 @@ async def _initialize_http_with_email(self, email: str, password: str, /) -> Non
email, password
)
self._store_token(self._http.token)

def to_thread(self, func, /, *args, **kwargs) -> Awaitable:
return self.loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

def clear_store(self, /) -> None:
self._users: Dict[Snowflake, User] = {}
Expand Down
24 changes: 22 additions & 2 deletions ferris/guild.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Dict, List, Optional

from ferris.role import Role

from .base import BaseObject
from .asset import Asset
from .bitflags import GuildFlags
from .channel import Channel
from .invite import Invite
from .role import Role
Expand All @@ -23,7 +25,7 @@
class Guild(BaseObject):
"""Represents a FerrisChat guild."""

__slots__ = ('_connection', '_owner_id', '_name', '_channels', '_members', '_roles')
__slots__ = ('_connection', '_owner_id', '_name', '_channels', '_members', '_roles', '_avatar', '_flags')

def __init__(self, connection: Connection, data: Optional[GuildPayload], /) -> None:
self._connection: Connection = connection
Expand All @@ -38,11 +40,19 @@ def _process_data(self, data: Optional[GuildPayload], /) -> None:
self._store_snowflake(data.get('id'))

self._owner_id: Optional[Snowflake] = data.get('owner_id')

if avatar := data.get('avatar'):
self._avatar: Optional[Asset] = Asset(self._connection, avatar)
else:
self._avatar: Optional[Asset] = None

self._name: Optional[str] = data.get('name')

self._channels: Dict[Snowflake, Channel] = {}
self._roles: Dict[Snowflake, Role] = {}

self._flags: GuildFlags = GuildFlags(data.get('flags') or 0)

for c in data.get('channels') or []:
if channel := self._connection.get_channel(c.get('id')):
channel._process_data(c)
Expand Down Expand Up @@ -293,11 +303,21 @@ def get_member(self, id: Id) -> Optional[Member]:
"""
id = sanitize_id(id)
return self._members.get(id)

@property
def flags(self) -> GuildFlags:
"""GuildFlags: The flags of this guild."""
return GuildFlags(self._flags)

@property
def owner(self) -> Optional[Member]:
"""Member: The owner of this guild."""
return self.get_member(self._owner_id)

@property
def avatar(self) -> Optional[Asset]:
"""Asset: The avatar of this guild."""
return self._avatar

@property
def owner_id(self, /) -> Optional[Snowflake]:
Expand Down
11 changes: 9 additions & 2 deletions ferris/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def api(self) -> APIRouter:
@property
def session(self) -> aiohttp.ClientSession:
return self.__session

async def get_asset(self, url: str) -> bytes:
async with self.__session.get() as resp:
if 400 > resp.status >= 200:
return await resp.read()

raise HTTPException(resp, 'Failed to get asset')

@classmethod
async def from_email_and_password(cls, email: str, password: str) -> HTTPClient:
Expand All @@ -115,7 +122,7 @@ async def from_email_and_password(cls, email: str, password: str) -> HTTPClient:
json={'email': email, 'password': password},
connector=aiohttp.TCPConnector(ssl=cls.USE_SSL),
) as response:
content = await response.text()
content = await response.text('utf-8')

if 400 > response.status >= 200:
token = from_json(content)['token']
Expand Down Expand Up @@ -178,7 +185,7 @@ async def request(self, url: str, method: str, /, **kwargs) -> Optional[Data]:
async with self.__session.request(
method, url, headers=headers, **kwargs
) as response:
content = await response.text()
content = await response.text('utf-8')

log.debug(f'{method} {url} Returned {response.status} with {content}')

Expand Down
21 changes: 15 additions & 6 deletions ferris/user.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations
from ferris.bitflags import UserFlags
from ferris.types.base import Snowflake

from typing import TYPE_CHECKING, Dict, List, Optional

from .base import BaseObject
from .asset import Asset
from .guild import Guild

if TYPE_CHECKING:
Expand Down Expand Up @@ -47,7 +49,7 @@ class User(BaseObject):
Represents a FerrisChat user.
"""

__slots__ = ('_connection', '_name', '_avatar')
__slots__ = ('_connection', '_name', '_avatar', '_flags')

def __init__(self, connection: Connection, data: UserPayload, /) -> None:
self._connection: Connection = connection
Expand All @@ -61,20 +63,27 @@ def _process_data(self, data: Optional[UserPayload], /) -> None:

self._name: Optional[str] = data.get('name')

self._avatar: Optional[str] = data.get('avatar')
if avatar := data.get('avatar'):
self._avatar: Optional[Asset] = Asset(self._connection, avatar)
else:
self._avatar: Optional[Asset] = None

# self._flags = data.get('flags')
# UserFlag after ferrischat implemented it
self._flags: UserFlags = UserFlags(data.get('flags') or 0)

@property
def name(self, /) -> Optional[str]:
"""str: The username of this user."""
return self._name

@property
def avatar(self, /) -> Optional[str]:
"""str: The avatar url of this user."""
def avatar(self, /) -> Optional[Asset]:
"""Asset: The avatar url of this user."""
return self._avatar

@property
def flags(self) -> UserFlags:
"""UserFlags: The flags of this user."""
return self._flags

def __del__(self, /) -> None:
if not hasattr(self, '_connection'):
Expand Down

0 comments on commit 023ead6

Please sign in to comment.