# Source code for defectio.client

from __future__ import annotations

import asyncio
import logging
import signal
import sys
import traceback
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar, Sequence

import aiohttp
from defectio.models.auth import Auth
from defectio.models.user import ClientUser
from defectio import utils
from .models import Message

from . import __version__
from .gateway import DefectioWebsocket
from .http import DefectioHTTP
from .models import User
from .state import ConnectionState

if TYPE_CHECKING:
    from .models import Channel, Server

__all__ = ("Client",)

Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])

logger = logging.getLogger("defectio")


def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
    tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}

    if not tasks:
        return

    logger.info("Cleaning up after %d tasks.", len(tasks))
    for task in tasks:
        task.cancel()

    loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
    logger.info("All tasks finished cancelling.")

    for task in tasks:
        if task.cancelled():
            continue
        if task.exception() is not None:
            loop.call_exception_handler(
                {
                    "message": "Unhandled exception during Client.run shutdown.",
                    "exception": task.exception(),
                    "task": task,
                }
            )


def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel leftover tasks, finalise async generators and close *loop*.

    The loop is closed in a ``finally`` clause, so it is closed even when
    task cancellation or async-generator shutdown raises.

    Parameters
    ----------
    loop : asyncio.AbstractEventLoop
        The loop to drain and close.
    """
    try:
        _cancel_tasks(loop)
        asyncgen_shutdown = loop.shutdown_asyncgens()
        loop.run_until_complete(asyncgen_shutdown)
    finally:
        logger.info("Closing the event loop.")
        loop.close()


class Client:
    """A client that talks to a Revolt instance.

    The client owns an :class:`aiohttp.ClientSession`, a :class:`DefectioHTTP`
    REST wrapper and a :class:`DefectioWebsocket`. It keeps cached state in a
    :class:`ConnectionState` and dispatches events to ``on_<event>``
    coroutines registered with :meth:`event` (or defined on a subclass).
    """

    def __init__(
        self,
        *,
        api_url: Optional[str] = "https://api.revolt.chat",
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ) -> None:
        """Creates a new client.

        Parameters
        ----------
        api_url : Optional[str], optional
            url to revolt instance, by default "https://api.revolt.chat"
        loop : Optional[asyncio.AbstractEventLoop], optional
            asyncio event loop to use. If omitted, the current event loop is
            used, or a new one is created and installed when none is set.
            By default None.
        **kwargs : Any
            ``session`` may supply an existing :class:`aiohttp.ClientSession`,
            which :meth:`create` reuses while it is open. Every other keyword
            argument is forwarded to :class:`ConnectionState`.
        """
        self.api_url: Optional[str] = api_url
        if loop is None:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # Python 3.14+ raises instead of implicitly creating a loop
                # when none is set for the current thread.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        self.loop: asyncio.AbstractEventLoop = loop
        self.websocket: Optional[DefectioWebsocket] = None
        self.http: Optional[DefectioHTTP] = None
        # Popped so that it is not forwarded to ConnectionState below.
        self.session: Optional[aiohttp.ClientSession] = kwargs.pop("session", None)
        # Populated by create() from the node info endpoint.
        self.api_info: Any = None
        self._handlers: dict[str, Callable] = {"ready": self._handle_ready}
        # Maps an event name to the pending (future, check) pairs that
        # wait_for() registered.
        self._listeners: dict[
            str, list[tuple[asyncio.Future, Callable[..., bool]]]
        ] = {}
        # Strong references to in-flight event tasks. The event loop only
        # keeps weak references, so untracked tasks can be garbage collected
        # while still running.
        self._event_tasks: set[asyncio.Task] = set()
        self._ready = asyncio.Event()
        self._closed = True
        self._auth: Optional[Auth] = None
        self._connection: ConnectionState = self._get_state(**kwargs)

    def _get_state(self, **options: Any) -> ConnectionState:
        """Returns the connection state.

        The state receives callables (``get_http``, ``get_websocket``,
        ``get_auth``) rather than the objects themselves, because those
        objects are only created later, in :meth:`create` and :meth:`login`.

        Returns
        -------
        ConnectionState
            The connection state.
        """
        return ConnectionState(
            dispatch=self.dispatch,
            handlers=self._handlers,
            http=self.get_http,
            websocket=self.get_websocket,
            auth=self.get_auth,
            loop=self.loop,
            **options,
        )

    def _handle_ready(self) -> None:
        """Handles the ready event by releasing :meth:`wait_until_ready`."""
        self._ready.set()

    async def _run_event(
        self,
        coro: Callable[..., Coroutine[Any, Any, Any]],
        event_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Runs an event.

        Cancellation is swallowed. Any other exception is routed to
        :meth:`on_error`.

        Parameters
        ----------
        coro : Callable[..., Coroutine[Any, Any, Any]]
            The coroutine to run.
        event_name : str
            The name of the event to run.
        """
        try:
            await coro(*args, **kwargs)
        except asyncio.CancelledError:
            pass
        except Exception:
            try:
                await self.on_error(event_name, *args, **kwargs)
            except asyncio.CancelledError:
                pass

    def _schedule_event(
        self,
        coro: Callable[..., Coroutine[Any, Any, Any]],
        event_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Task:
        """Schedules an event to be run.

        Parameters
        ----------
        coro : Callable[..., Coroutine[Any, Any, Any]]
            The coroutine to run.
        event_name : str
            The name of the event to run.

        Returns
        -------
        asyncio.Task
            The task that the event was scheduled for.
        """
        wrapped = self._run_event(coro, event_name, *args, **kwargs)
        # Schedules the task
        task = asyncio.create_task(wrapped, name=f"defectio: {event_name}")
        # Keep a strong reference until the task finishes; see __init__.
        self._event_tasks.add(task)
        task.add_done_callback(self._event_tasks.discard)
        return task

    def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
        """Dispatch an event.

        First resolves any :meth:`wait_for` futures registered under *event*
        whose check passes, or fails them if the check raises. Then, if the
        client has an ``on_<event>`` attribute, schedules it as a task.

        Parameters
        ----------
        event : str
            The event to dispatch.
        """
        logger.debug("Dispatching event %s", event)
        method = "on_" + event

        listeners = self._listeners.get(event)
        if listeners:
            removed = []
            for i, (future, condition) in enumerate(listeners):
                if future.cancelled():
                    removed.append(i)
                    continue

                try:
                    result = condition(*args)
                except Exception as exc:
                    future.set_exception(exc)
                    removed.append(i)
                else:
                    if result:
                        # The waiter receives None, the single argument,
                        # or the whole argument tuple.
                        if len(args) == 0:
                            future.set_result(None)
                        elif len(args) == 1:
                            future.set_result(args[0])
                        else:
                            future.set_result(args)
                        removed.append(i)

            if len(removed) == len(listeners):
                self._listeners.pop(event)
            else:
                # Delete from the end so that earlier indices stay valid.
                for idx in reversed(removed):
                    del listeners[idx]

        try:
            coro = getattr(self, method)
        except AttributeError:
            pass
        else:
            self._schedule_event(coro, method, *args, **kwargs)

    async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None:
        """|coro|

        The default error handler provided by the client.

        By default this prints to :data:`sys.stderr`; override it to provide
        a different implementation.
        """
        print(f"Ignoring exception in {event_method}", file=sys.stderr)
        traceback.print_exc()

    async def wait_until_ready(self) -> None:
        """|coro|

        Waits until the client's internal cache is all ready.
        """
        await self._ready.wait()

    def wait_for(
        self,
        event: str,
        *,
        check: Optional[Callable[..., bool]] = None,
        timeout: Optional[float] = None,
    ) -> Any:
        """Waits for a specific event to be dispatched.

        The event name is lower-cased before it is registered.

        Parameters
        ----------
        event : str
            The event to wait for.
        check : Optional[Callable[..., bool]], optional
            A check to run on the event's arguments, by default None (accept
            the first occurrence).
        timeout : Optional[float], optional
            timeout to wait, by default None

        Returns
        -------
        Any
            An awaitable (from :func:`asyncio.wait_for`). It resolves to None
            if the event had no arguments, to the single argument if it had
            one, or to a tuple of all the arguments otherwise.
        """
        future = self.loop.create_future()
        if check is None:

            def _check(*args):
                return True

            check = _check

        ev = event.lower()
        try:
            listeners = self._listeners[ev]
        except KeyError:
            listeners = []
            self._listeners[ev] = listeners

        listeners.append((future, check))
        return asyncio.wait_for(future, timeout)

    # event registration

    def event(self, coro: Coro) -> Coro:
        """A decorator that registers an event to listen to.

        Example
        ---------

        .. code-block:: python3

            @client.event
            async def on_ready():
                print('Ready!')

        Raises
        --------
        TypeError
            The coroutine passed is not actually a coroutine.
        """
        if not asyncio.iscoroutinefunction(coro):
            raise TypeError("event registered must be a coroutine function")

        setattr(self, coro.__name__, coro)
        logger.debug("%s has successfully been registered as an event", coro.__name__)
        return coro

    ################
    ## Properties ##
    ################

    @property
    def user(self) -> Optional[ClientUser]:
        """Optional[:class:`.ClientUser`]: Represents the connected client. ``None`` if not logged in."""
        return self._connection.user

    @property
    def users(self) -> list[User]:
        """Returns a list of all the users stored in the internal cache.

        Returns
        -------
        list[User]
            A list of cached users.
        """
        return list(self._connection._users.values())

    @property
    def cached_messages(self) -> Sequence[Message]:
        """Sequence[:class:`.Message`]: Read-only list of messages the connected client has cached.

        .. versionadded:: 1.1
        """
        return utils.SequenceProxy(self._connection._messages or [])

    @property
    def servers(self) -> list[Server]:
        """Returns a list of all the servers stored in the internal cache.

        Returns
        -------
        list[Server]
            A list of cached servers.
        """
        return list(self._connection._servers.values())

    @property
    def channels(self) -> list[Channel]:
        """Returns a list of all the channels stored in the internal cache.

        Returns
        -------
        list[Channel]
            A list of cached channels.
        """
        return list(self._connection._server_channels.values())

    def get_auth(self) -> Optional[Auth]:
        """Returns the Auth object used for logging in (``None`` before :meth:`login`)."""
        return self._auth

    def get_http(self) -> Optional[DefectioHTTP]:
        """Returns the HTTP client (``None`` before :meth:`create`)."""
        return self.http

    def get_websocket(self) -> Optional[DefectioWebsocket]:
        """Returns the websocket (``None`` before :meth:`create`)."""
        return self.websocket

    #############
    ## Getters ##
    #############

    def get_channel(self, channel_id: str) -> Optional[Channel]:
        """Get a channel with the specified ID from the internal cache.

        Parameters
        ----------
        channel_id : str
            The channel ID to look for.

        Returns
        -------
        Optional[Channel]
            The requested channel. If not found, returns ``None``.
        """
        channel = self._connection.get_channel(channel_id)
        return channel

    def get_server(self, server_id: str) -> Optional[Server]:
        """Get a server with the specified ID from the internal cache.

        Parameters
        ----------
        server_id : str
            The server ID to look for.

        Returns
        -------
        Optional[Server]
            The requested server. If not found, returns ``None``.
        """
        server = self._connection.get_server(server_id)
        return server

    def get_user(self, user_id: str) -> Optional[User]:
        """Get a user with the specified ID from the internal cache.

        Parameters
        ----------
        user_id : str
            The user ID to look for.

        Returns
        -------
        Optional[User]
            The requested user. If not found, returns ``None``.
        """
        user = self._connection.get_user(user_id)
        return user

    async def fetch_channel(self, channel_id: str) -> Optional[Channel]:
        """Fetches a channel from revolt, bypassing the internal cache.

        Use this if you believe the cache may be stale; otherwise
        :meth:`get_channel` is recommended. The fetched channel is added to
        the cache.

        Parameters
        ----------
        channel_id : str
            The channel ID to look for.

        Returns
        -------
        Optional[Channel]
            The requested channel. If not found, returns ``None``.
        """
        channel = await self._connection.http.get_channel(channel_id)
        if channel:
            channel = self._connection._add_channel_from_data(channel)
        return channel

    async def fetch_server(self, server_id: str) -> Optional[Server]:
        """Fetches a server from revolt, bypassing the internal cache.

        Use this if you believe the cache may be stale; otherwise
        :meth:`get_server` is recommended. The fetched server is added to
        the cache.

        Parameters
        ----------
        server_id : str
            The server ID to look for.

        Returns
        -------
        Optional[Server]
            The requested server. If not found, returns ``None``.
        """
        server = await self._connection.http.get_server(server_id)
        if server:
            server = self._connection._add_server_from_data(server)
        return server

    async def fetch_user(self, user_id: str) -> Optional[User]:
        """Fetches a user from revolt, bypassing the internal cache.

        Use this if you believe the cache may be stale; otherwise
        :meth:`get_user` is recommended. The fetched user is added to the
        cache.

        Parameters
        ----------
        user_id : str
            The user ID to look for.

        Returns
        -------
        Optional[User]
            The requested user. If not found, returns ``None``.
        """
        user = await self._connection.http.get_user(user_id)
        if user:
            user = self._connection._add_user_from_data(user)
        return user

    ######################
    ## State Management ##
    ######################

    def is_closed(self) -> bool:
        """Indicates if both the websocket connection and HTTP session are closed.

        A websocket or session that has not been created yet (``None``)
        counts as closed. This makes the call safe before :meth:`create`, or
        after a :meth:`create` that failed partway through.
        """
        websocket_closed = self.websocket is None or self.websocket.closed
        session_closed = self.session is None or self.session.closed
        return websocket_closed and session_closed

    async def close(self) -> None:
        """|coro|

        Closes the connection to revolt.

        The websocket is closed first. The HTTP session is closed even if
        closing the websocket raises. Once the client is closed, further
        calls are no-ops.
        """
        if self._closed:
            return

        self._closed = True

        try:
            if self.websocket is not None:
                await self.websocket.close()
        finally:
            if self.session is not None:
                await self.session.close()

    async def create(self) -> None:
        """|coro|

        Creates the client with the cache, websocket and http client.

        A session passed to the constructor via ``session=`` is reused while
        it is still open. Otherwise a new :class:`aiohttp.ClientSession` is
        created.
        """
        user_agent = "Defectio (https://github.com/Darkflame72/defectio {0}) Python/{1[0]}.{1[1]} aiohttp/{2}".format(
            __version__, sys.version_info, aiohttp.__version__
        )
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
        # From here on the client owns live resources, so close() must
        # release them even if node_info() or login() fails before
        # connect() is reached.
        self._closed = False
        self.http = DefectioHTTP(self.session, self.api_url, user_agent)

        api_info = await self.http.node_info()
        api_info = self._connection.set_api_info(api_info)
        self.api_info = api_info
        self.websocket = DefectioWebsocket(
            self.session, api_info.ws_url, user_agent, self
        )

    async def connect(self) -> None:
        """|coro|

        Marks the client as connected.

        NOTE(review): this only clears the internal closed flag. The
        websocket itself is started by :meth:`login`.
        """
        self._closed = False

    async def login(self, token: str, bot: bool = True) -> None:
        """|coro|

        Logs in using the token provided and starts the websocket.

        :meth:`create` must have been called first, because this uses the
        HTTP client and websocket it sets up.

        Parameters
        ----------
        token : str
            The authentication token.
        bot : bool
            Indicates if the token belongs to a bot account. Defaults to
            ``True``.
        """
        self._auth = self.http.start(token, bot=bot)
        await self.websocket.start(self._auth)

    async def start(
        self,
        *,
        token: Optional[str] = None,
        bot: bool = True,
    ) -> None:
        """|coro|

        Creates a client and logs the user in.

        Equivalent to :meth:`create`, then :meth:`login`, then
        :meth:`connect`.

        Parameters
        ----------
        token : Optional[str]
            The Revolt API token.
        bot : bool
            Indicates if the token belongs to a bot account. Defaults to
            ``True``.
        """
        await self.create()
        await self.login(token, bot=bot)
        await self.connect()

    def run(self, token: Optional[str] = None, *, bot: bool = True) -> None:
        """A blocking call that abstracts away the event loop
        initialisation from you.

        If you want more control over the event loop, do not use this
        function. Use the :meth:`start` coroutine, or :meth:`create` +
        :meth:`login`, instead.

        Roughly Equivalent to: ::

            try:
                loop.run_until_complete(start(*args, **kwargs))
            except KeyboardInterrupt:
                loop.run_until_complete(close())
                # cancel all tasks lingering
            finally:
                loop.close()

        An exception raised by :meth:`start` (for example a failed login) is
        re-raised after the loop has been cleaned up.

        .. warning::

            This function must be the last function you call, because it
            blocks. Event registration, or anything else called after this
            function, will not run until it returns.

        Parameters
        -----------
        token: Optional[:class:`str`]
            The authentication token of the bot to login.
        bot: bool
            Indicates if the client is a bot account. Defaults to ``True``.
        """
        loop = self.loop

        try:
            loop.add_signal_handler(signal.SIGINT, loop.stop)
            loop.add_signal_handler(signal.SIGTERM, loop.stop)
        except NotImplementedError:
            pass

        async def runner() -> None:
            try:
                await self.start(token=token, bot=bot)
            finally:
                if not self.is_closed():
                    await self.close()

        def stop_loop_on_completion(f):
            loop.stop()

        future = asyncio.ensure_future(runner(), loop=loop)
        future.add_done_callback(stop_loop_on_completion)
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            logger.info("Received signal to terminate bot and event loop.")
        finally:
            future.remove_done_callback(stop_loop_on_completion)
            logger.info("Cleaning up tasks.")
            _cleanup_loop(loop)

        # A cancelled runner means we were asked to shut down. Otherwise,
        # surface any error from start() instead of silently dropping it.
        if future.done() and not future.cancelled():
            try:
                future.result()
            except KeyboardInterrupt:
                return None