from __future__ import annotations
import re
from collections import deque
from typing import Tuple, Any, Optional, TYPE_CHECKING
from urllib.parse import quote
import discord
from aiohttp.client_exceptions import ServerDisconnectedError
from yarl import URL
from . import log
from .enums import ExceptionSeverity, LoadType, PlayerState
from .tuples import PlaylistInfo
if TYPE_CHECKING:
from node import Node
from player import Player
# Names exported by ``from <this module> import *``; parse_timestamps is not exported.
__all__ = ["Track", "RESTClient", "playlist_info", "LoadResult"]
# noinspection PyPep8Naming
def playlist_info(name: Optional[str] = None, selectedTrack: Optional[int] = None) -> PlaylistInfo:
    """Build a :class:`PlaylistInfo`, substituting defaults for missing fields.

    This factory exists so ``None`` values can be preprocessed without pulling
    in dataclasses just for ``__post_init__``.

    Parameters
    ----------
    name : Optional[str]
        Playlist name; ``None`` is replaced by ``"Unknown"``.
    selectedTrack : Optional[int]
        Index of the selected track; ``None`` is replaced by ``-1``.

    Returns
    -------
    PlaylistInfo
    """
    resolved_name = "Unknown" if name is None else name
    resolved_selected = -1 if selectedTrack is None else selectedTrack
    return PlaylistInfo(name=resolved_name, selectedTrack=resolved_selected)
# Timestamp patterns consumed by parse_timestamps; every capture group is a run of digits.
# YouTube: "&t=90" / "?t=90s" -> group(1) is seconds.
_re_youtube_timestamp = re.compile(r"[&?]t=(\d+)s?")
# SoundCloud: "#t=1:30" -> group(1) minutes, group(2) seconds.
_re_soundcloud_timestamp = re.compile(r"#t=(\d+):(\d+)s?")
# Twitch: "?t=1h2m3s" -> hours, minutes, seconds. All three parts are required,
# so e.g. "?t=5m30s" does not match.
_re_twitch_timestamp = re.compile(r"\?t=(\d+)h(\d+)m(\d+)s")
def _query_start_seconds(query: str, query_url: URL) -> int:
    """Return the start offset, in seconds, encoded in ``query`` (0 if none).

    Handles YouTube, SoundCloud and Twitch URLs as well as ``ytsearch:`` /
    ``scsearch:`` queries. May raise ``AttributeError``/``IndexError`` on
    unexpected input; the caller treats that as "no timestamp".
    """
    is_search = "ytsearch:" in query or "scsearch:" in query
    if not (all([query_url.scheme, query_url.host, query_url.path]) or is_search):
        return 0

    # Compare on the last two host labels so "www.youtube.com",
    # "m.soundcloud.com", "www.twitch.tv" etc. are recognised, not just the
    # bare domains.
    domain = ".".join((query_url.host or "").split(".")[-2:])

    if (
        (domain in ("youtube.com", "youtu.be") or "ytsearch:" in query)
        and ("&t=" in query or "?t=" in query)
        and not ("playlist?" in query and "&list=" in query)
    ):
        match = _re_youtube_timestamp.search(query)
        return int(match.group(1)) if match else 0

    if (domain == "soundcloud.com" or "scsearch:" in query) and "#t=" in query:
        # Sets (playlists) only carry a timestamp when a track is opened from
        # within the set ("?in=").
        if "/sets/" not in query or "?in=" in query:
            match = _re_soundcloud_timestamp.search(query)
            if match:
                return int(match.group(1)) * 60 + int(match.group(2))
        return 0

    if domain == "twitch.tv" and "?t=" in query:
        match = _re_twitch_timestamp.search(query)
        if match:
            hours, minutes, seconds = (int(part) for part in match.groups())
            return hours * 3600 + minutes * 60 + seconds

    return 0


def parse_timestamps(data: dict[str, Any]) -> list[dict[str, Any]]:
    """Attach the start timestamp found in the query to every loaded track.

    The timestamp (YouTube ``t=``, SoundCloud ``#t=m:ss``, Twitch
    ``?t=XhYmZs``) is read from ``data["query"]`` and stored in milliseconds
    under ``track["info"]["timestamp"]``. The track dicts are mutated in place.

    Parameters
    ----------
    data : dict[str, Any]
        A loadtracks response that also carries the original ``"query"``.

    Returns
    -------
    list[dict[str, Any]]
        The tracks of ``data``. Playlist results and queries that yarl cannot
        parse are returned unchanged.
    """
    if data["loadType"] == LoadType.PLAYLIST_LOADED:
        return data["tracks"]

    query = data["query"]
    try:
        query_url = URL(query)
    except ValueError:
        return data["tracks"]

    # The offset depends only on the query, so compute it once for all tracks.
    try:
        start_time = _query_start_seconds(query, query_url)
    except (AttributeError, IndexError):
        # Best effort: an unexpected URL shape simply means "start from 0".
        start_time = 0

    new_tracks = []
    for track in data["tracks"]:
        track["info"]["timestamp"] = start_time * 1000
        new_tracks.append(track)
    return new_tracks
class Track:
    """
    Information about a Lavalink track.

    Attributes
    ----------
    requester : Optional[discord.User]
        The user who requested the track; ``None`` when the track is constructed.
    track_identifier : str
        Track identifier used by the Lavalink player to play tracks.
    identifier : str
        Identifier of the track on its source (e.g. the YouTube video id).
    source : Optional[str]
        Name of the source the track came from (``info["sourceName"]``).
    seekable : bool
        Boolean determining if seeking can be done on this track.
    author : str
        The author of this track.
    length : int
        The length of this track in milliseconds.
    is_stream : bool
        Determines whether Lavalink will stream this track.
    position : int
        Current seeked position to begin playback.
    title : str
        Title of this track.
    uri : str
        The playback url of this track.
    start_timestamp : int
        The track start time in milliseconds as provided by the query.
    extras : dict[str, Any]
        The ``"extras"`` mapping of the raw payload (empty if absent).
    """

    requester: Optional[discord.User]
    track_identifier: str
    identifier: str
    source: Optional[str]
    seekable: bool
    author: str
    length: int
    is_stream: bool
    position: int
    title: str
    uri: str
    start_timestamp: int
    extras: dict[str, Any]

    def __init__(self, data: dict[str, Any]):
        self.requester = None
        self.track_identifier = data.get("track")
        _info: dict = data.get("info", {})
        self.identifier = _info.get("identifier")
        self.source = _info.get("sourceName", None)
        self.seekable = _info.get("isSeekable", False)
        self.author = _info.get("author")
        self.length = _info.get("length", 0)
        self.is_stream = _info.get("isStream", False)
        self.position = _info.get("position")
        self.title = _info.get("title")
        self.uri = _info.get("uri")
        # Filled in by parse_timestamps when the query carried a start time.
        self.start_timestamp = _info.get("timestamp", 0)
        self.extras = data.get("extras", {})

    @property
    def thumbnail(self) -> Optional[str]:
        """Return a thumbnail URL for the track, or ``None`` if none is known."""
        if self.source == "youtube":
            return f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg"
        if self.source == "twitch":
            # The preview URL is keyed on the channel name; without it there
            # is nothing to build (and ``None.lower()`` would crash).
            if not self.author:
                return None
            return f"https://static-cdn.jtvnw.net/previews-ttv/live_user_{self.author.lower()}.jpg"
        if self.source == "soundcloud":
            # TODO: return a real thumbnail; this is a static SoundCloud logo.
            return "https://developers.soundcloud.com/assets/logo_big_black-4fbe88aa0bf28767bbfc65a08c828c76.png"
        return None

    def __eq__(self, other):
        """Two tracks are equal when their Lavalink track identifiers match."""
        if isinstance(other, Track):
            return self.track_identifier == other.track_identifier
        return NotImplemented

    # ``__ne__`` is not defined: Python 3's default inverts ``__eq__``.

    def __hash__(self):
        """Hash on exactly the field ``__eq__`` compares, keeping the contract."""
        return hash(self.track_identifier)

    def __repr__(self):
        return (
            "<Track: "
            f"track_identifier={self.track_identifier!r}, "
            f"author={self.author!r}, "
            f"length={self.length}, "
            f"is_stream={self.is_stream}, uri={self.uri!r}, title={self.title!r}>"
        )
class LoadResult:
    """
    The result of a load_tracks request.

    Attributes
    ----------
    load_type : LoadType
        The result of the loadtracks request
    is_playlist : bool
        Whether the response was flagged as a playlist (``isPlaylist``) or has
        load type ``PLAYLIST_LOADED``.
    playlist_info : Optional[PlaylistInfo]
        The playlist information detected by Lavalink; ``None`` unless
        ``is_playlist`` is true.
    tracks : tuple[Track, ...]
        The tracks that were loaded, if any
    """

    load_type: LoadType
    is_playlist: bool
    playlist_info: Optional[PlaylistInfo]
    tracks: tuple[Track, ...]

    def __init__(self, data: dict[str, Any]):
        # NOTE: ``data`` is kept as ``_raw`` and missing keys are filled into it
        # in place.
        self._raw = data
        _fallback = {
            "loadType": LoadType.LOAD_FAILED,
            "exception": {
                "message": "Lavalink API returned an unsupported response, Please report it.",
                "severity": ExceptionSeverity.SUSPICIOUS,
            },
            "playlistInfo": {},
            "tracks": [],
        }
        # Dict order matters: "loadType" is filled before "exception" reads it.
        for k, v in _fallback.items():
            if k in data:
                continue
            if k == "exception":
                # Only failed loads need a synthesized exception.
                if data.get("loadType", LoadType.LOAD_FAILED) != LoadType.LOAD_FAILED:
                    continue
                v["message"] = (
                    f"Timestamp: {self._raw.get('timestamp', 'Unknown')}\n"
                    f"Status Code: {self._raw.get('status', 'Unknown')}\n"
                    f"Error: {self._raw.get('error', 'Unknown')}\n"
                    f"Query: {self._raw.get('query', 'Unknown')}\n"
                    f"Load Type: {self._raw['loadType']}\n"
                    f"Message: {self._raw.get('message', v['message'])}"
                )
            self._raw[k] = v

        self.load_type = LoadType(self._raw["loadType"])
        # Normalise to a real bool so truthy non-bool flags don't fall through.
        self.is_playlist = bool(
            self._raw.get("isPlaylist") or self.load_type == LoadType.PLAYLIST_LOADED
        )
        if self.is_playlist:
            # Pass only the known fields: extra keys or a ``None`` value from
            # the server must not make playlist_info() raise TypeError.
            info = self._raw.get("playlistInfo") or {}
            self.playlist_info = playlist_info(
                name=info.get("name"), selectedTrack=info.get("selectedTrack")
            )
        else:
            self.playlist_info = None

        _tracks = parse_timestamps(self._raw) if self._raw.get("query") else self._raw["tracks"]
        self.tracks = tuple(Track(t) for t in _tracks)

    @property
    def has_error(self) -> bool:
        """``True`` when the load type is ``LOAD_FAILED``."""
        return self.load_type == LoadType.LOAD_FAILED

    @property
    def exception_message(self) -> Optional[str]:
        """
        On Lavalink V3, if there was an exception during a load or get tracks call
        this property will be populated with the error message.
        If there was no error this property will be ``None``.
        """
        if self.has_error:
            exception_data = self._raw.get("exception", {})
            return exception_data.get("message")
        return None

    @property
    def exception_severity(self) -> Optional[ExceptionSeverity]:
        """Severity of the load failure, or ``None`` if there was no error."""
        if self.has_error:
            exception_data = self._raw.get("exception", {})
            severity = exception_data.get("severity")
            if severity is not None:
                return ExceptionSeverity(severity)
        return None
class RESTClient:
    """
    Client class used to access the REST endpoints on a Lavalink node.

    Attributes
    ----------
    player : Player
        The player to use
    node : Node
        The node used by the player
    state : PlayerState
        The player state captured when the client was created
    """

    player: Player
    node: Node
    state: PlayerState

    def __init__(self, player: Player, ssl: bool = False):
        """
        Parameters
        ----------
        player: Player
            The player object to use
        ssl : bool
            Whether to use the `https://` protocol.
        """
        self.player = player
        self.node = player.node
        self._session = self.node.session
        scheme = "https" if ssl else "http"
        self._uri = f"{scheme}://{self.node.host}:{self.node.port}/loadtracks?identifier="
        self._headers = {"Authorization": self.node.password}
        # NOTE(review): snapshot of player.state; nothing in this class refreshes
        # it, so confirm it is kept in sync by the owner.
        self.state = player.state
        self._warned = False

    def __check_node_ready(self):
        """Raise ``RuntimeError`` unless ``self.state`` is ``PlayerState.READY``."""
        if self.state != PlayerState.READY:
            raise RuntimeError("Cannot execute REST request when node not ready.")

    async def _get(self, url: str) -> dict[str, Any]:
        """GET ``url`` with the node's Authorization header and decode the JSON body.

        If the server disconnects while the player is disconnecting, a
        synthetic ``LOAD_FAILED`` payload is returned instead of raising.
        """
        try:
            async with self._session.get(url, headers=self._headers) as resp:
                data = await resp.json(content_type=None)
        except ServerDisconnectedError:
            if self.state == PlayerState.DISCONNECTING:
                return {
                    "loadType": LoadType.LOAD_FAILED,
                    "exception": {
                        "message": "Load tracks interrupted by player disconnect.",
                        "severity": ExceptionSeverity.COMMON,
                    },
                    "tracks": [],
                }
            log.debug("Received server disconnected error when player state = %s", self.state.name)
            raise
        return data

    async def load_tracks(self, query: str) -> LoadResult:
        """
        Executes a loadtracks request. Only works on Lavalink V3.

        Parameters
        ----------
        query : str

        Returns
        -------
        LoadResult
            Always a LoadResult; an unrecognised response body (e.g. empty)
            yields a ``LOAD_FAILED`` result rather than ``None``.
        """
        self.__check_node_ready()
        query = str(query)
        url = self._uri + quote(query)
        data = await self._get(url)
        if isinstance(data, list):
            # Lavalink V2 answers with a bare list of tracks.
            data = {"loadType": LoadType.V2_COMPAT, "tracks": data}
        elif not isinstance(data, dict):
            # aiohttp returns None for an empty body; let LoadResult's fallback
            # report an unsupported response instead of returning None.
            data = {}
        data["query"] = query
        data["encodedquery"] = url
        return LoadResult(data)

    async def get_tracks(self, query: str) -> Tuple[Track, ...]:
        """
        Gets tracks from lavalink.

        Deprecated: use :meth:`load_tracks`.

        Parameters
        ----------
        query : str

        Returns
        -------
        Tuple[Track, ...]
        """
        if not self._warned:
            log.warning("get_tracks() is now deprecated. Please switch to using load_tracks().")
            self._warned = True
        result = await self.load_tracks(query)
        return result.tracks

    async def search_yt(self, query: str) -> LoadResult:
        """
        Gets track results from YouTube from Lavalink.

        Parameters
        ----------
        query : str

        Returns
        -------
        LoadResult
        """
        return await self.load_tracks(f"ytsearch:{query}")

    async def search_sc(self, query: str) -> LoadResult:
        """
        Gets track results from SoundCloud from Lavalink.

        Parameters
        ----------
        query : str

        Returns
        -------
        LoadResult
        """
        return await self.load_tracks(f"scsearch:{query}")