import asyncio
from asyncio import BaseEventLoop
from typing import Optional, Tuple
import discord
from discord.ext.commands import Bot
from . import enums, log, node, player
from .utils import Coroutine
# Public API of this module (what ``from <module> import *`` exports).
__all__ = [
    "initialize",
    "add_node",
    "connect",
    "get_player",
    "close",
    "register_event_listener",
    "unregister_event_listener",
    "register_update_listener",
    "unregister_update_listener",
    "register_stats_listener",
    "unregister_stats_listener",
    "all_players",
    "all_connected_players",
    "active_players",
]
# Coroutine functions added by register_event_listener(); dispatch() schedules
# each of them for every Lavalink EVENT message.
_event_listeners: list = []
# Coroutine functions added by register_update_listener(); dispatch() schedules
# each of them for every PLAYER_UPDATE message.
_update_listeners: list = []
# Coroutine functions added by register_stats_listener(); dispatch() schedules
# each of them for every STATS message.
_stats_listeners: list = []
# Event loop captured from the bot by initialize(). add_node() hands it to each
# Node and dispatch() schedules listener tasks on it; None until initialize() runs.
_loop: Optional[BaseEventLoop] = None
async def initialize(
    bot: Bot,
):
    """
    Hook this module into a bot.

    Stores the bot's event loop for later task scheduling and installs the
    internal event, player-update and guild-removal handlers.

    .. important::
        Only call this once the bot has received its "on_ready" event!

    Parameters
    ----------
    bot : discord.ext.commands.Bot
        The bot whose event loop and listener registry will be used.
    """
    global _loop
    _loop = bot.loop
    internal_handlers = (
        (register_event_listener, _handle_event),
        (register_update_listener, _handle_update),
    )
    for register, handler in internal_handlers:
        register(handler)
    bot.add_listener(_on_guild_remove, name="on_guild_remove")
async def add_node(
    bot: Bot,
    host: str,
    password: str,
    ws_port: int,
    timeout: Optional[int] = 30,
    resume_key: Optional[str] = None,
    resume_timeout: int = 60,
):
    """
    Create a new Lavalink node and connect it.

    .. important::
        This function must only be called AFTER :py:func:`initialize`, which
        sets the event loop the node is created with.

    Parameters
    ----------
    bot : discord.ext.commands.Bot
        An instance of a discord `Bot` object; its user id and shard count
        are passed to the node.
    host : str
        The hostname or IP address of the Lavalink node.
    password : str
        The password of the Lavalink node.
    ws_port : int
        The websocket port on the Lavalink Node.
    timeout : Optional[int]
        Amount of time to allow retries to occur, ``None`` is considered forever.
    resume_key : Optional[str]
        A resume key used for resuming a session upon re-establishing a WebSocket connection to Lavalink.
    resume_timeout : int
        How long the node should wait for a connection while disconnected before clearing all players.
    """
    lavalink_node = node.Node(
        _loop=_loop,
        event_handler=dispatch,
        host=host,
        password=password,
        port=ws_port,
        user_id=bot.user.id,
        # An unsharded bot reports shard_count as a falsy value; treat it as 1.
        num_shards=bot.shard_count or 1,
        resume_key=resume_key,
        resume_timeout=resume_timeout,
        bot=bot,
    )
    await lavalink_node.connect(timeout=timeout)
    # Presumably resets the node's reconnect-attempt counter now that the
    # initial connection succeeded. TODO confirm against node.Node.
    lavalink_node._retries = 0
async def connect(channel: discord.VoiceChannel, deafen: bool = False):
    """
    Join a discord voice channel and return the player created for it.

    This is the public entry point for connecting to voice; call
    :py:func:`initialize` before using it.

    Parameters
    ----------
    channel : discord.VoiceChannel
        The voice channel to join.
    deafen : bool
        If true, the bot joins deafened and does not hear other users.

    Returns
    -------
    Player
        The newly created Player object.

    Raises
    ------
    IndexError
        If there are no available lavalink nodes ready to connect to discord.
    """
    guild_node = node.get_node(channel.guild.id)
    return await guild_node.create_player(channel, deafen=deafen)
def get_player(guild_id: int) -> player.Player:
    """
    Look up the player belonging to a guild.

    Parameters
    ----------
    guild_id : int
        Id of the guild whose player is wanted.

    Returns
    -------
    Player
        The player currently associated with ``guild_id``.

    Raises
    ------
    IndexError, KeyError
        Propagated from the node layer. Callers in this module treat either
        one as "no player for this guild". Presumably IndexError means no
        usable node and KeyError means no player; confirm in the node module.
    """
    return node.get_node(guild_id).get_player(guild_id)
async def _on_guild_remove(guild: discord.Guild):
    """Disconnect the guild's player, if any, when the bot leaves a guild."""
    try:
        guild_player = get_player(guild.id)
    except (IndexError, KeyError):
        # No node or no player for this guild, so there is nothing to clean up.
        return
    await guild_player.disconnect()
def register_event_listener(coro: Coroutine):
    """
    Registers a coroutine to receive lavalink event information.

    The coroutine is always called with three positional arguments: a
    :py:class:`Player`, a :py:class:`LavalinkEvents` member, and an extra
    value whose meaning depends on the event type:

    * :py:attr:`LavalinkEvents.TRACK_END`: a :py:class:`TrackEndReason`.
    * :py:attr:`LavalinkEvents.TRACK_EXCEPTION`: a dictionary with ``message``,
      ``cause``, and ``severity`` keys (``severity`` is an
      :py:class:`ExceptionSeverity`).
    * :py:attr:`LavalinkEvents.TRACK_STUCK`: the threshold milliseconds that
      the track has been stuck for.
    * :py:attr:`LavalinkEvents.TRACK_START`: the ``track`` value from the
      Lavalink payload (a track identifier string).
    * :py:attr:`LavalinkEvents.WEBSOCKET_CLOSED`: a dictionary with ``code``,
      ``reason``, ``byRemote`` and ``channelID`` keys. ``channelID`` is
      ``None`` when the player has no channel.
    * Any other event type: ``None``. The third argument is still passed,
      so the coroutine must accept it.

    Registering the same coroutine more than once has no additional effect.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`
        A coroutine function that accepts the arguments listed above.

    Raises
    ------
    TypeError
        If ``coro`` is not a coroutine.
    """
    if not asyncio.iscoroutinefunction(coro):
        raise TypeError("Function is not a coroutine.")
    if coro not in _event_listeners:
        _event_listeners.append(coro)
async def _handle_event(player, data: enums.LavalinkEvents, extra):
    """Internal event listener: forward the event to the player's own handler."""
    forward = player.handle_event
    await forward(data, extra)
def _get_event_args(data: enums.LavalinkEvents, raw_data: dict):
    """
    Build the ``(player, event, extra)`` arguments for event listeners.

    Returns ``None`` when no player exists for the payload's guild, which
    tells :py:func:`dispatch` to drop the event.
    """
    guild_id = int(raw_data.get("guildId"))
    try:
        # Events are looked up even on nodes that are not marked ready.
        guild_node = node.get_node(guild_id, ignore_ready_status=True)
        guild_player = guild_node.get_player(guild_id)
    except (IndexError, KeyError):
        # A TRACK_END for a player that is already gone is not worth logging.
        if data != enums.LavalinkEvents.TRACK_END:
            log.debug(
                "Got an event for a guild that we have no player for."
                " This may be because of a forced voice channel"
                " disconnect."
            )
        return None

    if data == enums.LavalinkEvents.TRACK_END:
        extra = enums.TrackEndReason(raw_data.get("reason"))
    elif data == enums.LavalinkEvents.TRACK_EXCEPTION:
        details = raw_data.get("exception", {})
        extra = {
            "message": details.get(
                "message", "Something went wrong when decoding the track."
            ),
            "cause": details.get("cause", "Unhandled Exception"),
            "severity": enums.ExceptionSeverity(details.get("severity", "FATAL")),
        }
    elif data == enums.LavalinkEvents.TRACK_STUCK:
        extra = raw_data.get("thresholdMs")
    elif data == enums.LavalinkEvents.TRACK_START:
        extra = raw_data.get("track")
    elif data == enums.LavalinkEvents.WEBSOCKET_CLOSED:
        extra = {
            "code": raw_data.get("code"),
            "reason": raw_data.get("reason"),
            "byRemote": raw_data.get("byRemote"),
            "channelID": guild_player.channel.id if guild_player.channel else None,
        }
    else:
        # Other event types carry no extra information.
        extra = None
    return guild_player, data, extra
def unregister_event_listener(coro: Coroutine):
    """
    Stop delivering lavalink events to a previously registered coroutine.

    Passing a coroutine that was never registered is silently ignored.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`
    """
    if coro in _event_listeners:
        _event_listeners.remove(coro)
def register_update_listener(coro: Coroutine):
    """
    Registers a coroutine to receive lavalink player update information.

    The coroutine is called with three positional arguments: an instance of
    :py:class:`Player`, an instance of :py:class:`PlayerState`, and the raw
    player-update payload (a ``dict``) as received from Lavalink.

    Registering the same coroutine more than once has no additional effect.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`
        A coroutine function that accepts the arguments listed above.

    Raises
    ------
    TypeError
        If ``coro`` is not a coroutine.
    """
    if not asyncio.iscoroutinefunction(coro):
        raise TypeError("Function is not a coroutine.")
    if coro not in _update_listeners:
        _update_listeners.append(coro)
async def _handle_update(player, data: enums.PlayerState, raw_data: dict):
    """Internal update listener: pass the new state to the player."""
    # raw_data is accepted only to match the (player, state, raw_data)
    # listener call signature; the player needs just the parsed state.
    apply_update = player.handle_player_update
    await apply_update(data)
def _get_update_args(data: enums.PlayerState, raw_data: dict):
    """
    Build the ``(player, state, raw_data)`` arguments for update listeners.

    Returns ``None`` when no player exists for the payload's guild, which
    tells :py:func:`dispatch` to drop the update.
    """
    guild_id = int(raw_data.get("guildId"))
    # NOTE(review): unlike _get_event_args this goes through get_player(),
    # which does not pass ignore_ready_status=True; confirm whether updates
    # arriving while a node is not ready should really be dropped.
    try:
        guild_player = get_player(guild_id)
    except (KeyError, IndexError):
        log.debug(
            "Got a player update for a guild that we have no player for."
            " This may be because of a forced voice channel disconnect."
        )
        return None
    return guild_player, data, raw_data
def unregister_update_listener(coro: Coroutine):
    """
    Stop delivering player updates to a previously registered coroutine.

    Passing a coroutine that was never registered is silently ignored.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`
    """
    if coro in _update_listeners:
        _update_listeners.remove(coro)
def register_stats_listener(coro: Coroutine):
    """
    Subscribe a coroutine function to lavalink server stats messages.

    For every stats message the coroutine is called with a single argument,
    an instance of :py:class:`Stats`. Registering the same coroutine more
    than once has no additional effect.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`

    Raises
    ------
    TypeError
        If ``coro`` is not a coroutine.
    """
    if asyncio.iscoroutinefunction(coro):
        if coro not in _stats_listeners:
            _stats_listeners.append(coro)
        return
    raise TypeError("Function is not a coroutine.")
def unregister_stats_listener(coro: Coroutine):
    """
    Stop delivering server stats to a previously registered coroutine.

    Passing a coroutine that was never registered is silently ignored.

    Parameters
    ----------
    coro : :ref:`coroutine <coroutine>`
    """
    if coro in _stats_listeners:
        _stats_listeners.remove(coro)
# Strong references to the listener tasks created by dispatch(). The event
# loop only keeps weak references to tasks, so an otherwise unreferenced
# task could be garbage-collected before it finishes running.
_dispatch_tasks: set = set()


def dispatch(op: enums.LavalinkIncomingOp, data, raw_data: dict):
    """
    Schedule every listener registered for ``op`` with the matching arguments.

    Used as the ``event_handler`` of each node created by :py:func:`add_node`.

    Parameters
    ----------
    op : LavalinkIncomingOp
        Kind of incoming message; selects which listener list is used.
    data
        The parsed payload: an event type, a player state or a stats object,
        depending on ``op``.
    raw_data : dict
        The raw message, used to look up the guild's player.
    """
    listeners = []
    args = []
    if op == enums.LavalinkIncomingOp.EVENT:
        listeners = _event_listeners
        args = _get_event_args(data, raw_data)
    elif op == enums.LavalinkIncomingOp.PLAYER_UPDATE:
        listeners = _update_listeners
        args = _get_update_args(data, raw_data)
    elif op == enums.LavalinkIncomingOp.STATS:
        listeners = _stats_listeners
        args = [data]
    if args is None:
        # For example, no player because channel got removed.
        return
    for coro in listeners:
        task = _loop.create_task(coro(*args))
        # Keep the task alive until it completes, then drop the reference.
        _dispatch_tasks.add(task)
        task.add_done_callback(_dispatch_tasks.discard)
async def close(bot: Bot):
    """
    Shut down the lavalink integration completely.

    Undoes :py:func:`initialize` by removing the internal event,
    player-update and guild-removal handlers, then awaits
    ``node.disconnect()``.

    Parameters
    ----------
    bot: discord.ext.commands.Bot
        The bot the guild-removal listener was registered on.
    """
    internal_handlers = (
        (unregister_event_listener, _handle_event),
        (unregister_update_listener, _handle_update),
    )
    for unregister, handler in internal_handlers:
        unregister(handler)
    bot.remove_listener(_on_guild_remove, name="on_guild_remove")
    await node.disconnect()
# Helper functions for inspecting players across all nodes
def all_players() -> Tuple[player.Player, ...]:
    """
    Get all the players.

    Returns
    -------
    Tuple[Player, ...]
        Every player of every known node, whether or not it is connected.
    """
    return tuple(p for n in node._nodes for p in n.players)
def all_connected_players() -> Tuple[player.Player, ...]:
    """
    Get all the connected players.

    Returns
    -------
    Tuple[Player, ...]
        Every player across all nodes whose ``connected`` attribute is truthy.
    """
    return tuple(p for p in all_players() if p.connected)
def active_players() -> Tuple[player.Player, ...]:
    """
    Get all the active players.

    Returns
    -------
    Tuple[Player, ...]
        The connected players whose ``is_playing`` attribute is truthy.
    """
    return tuple(p for p in all_connected_players() if p.is_playing)